Asyncio Socket Pool
In a typical client-server setup over a socket, the server may take some time to process any given request. To let the client fire off multiple requests in parallel and speed things up, there are two common paths:
- We can make the server interface asynchronous so that it is able to take a whole series of requests and respond to them out of order.
- The server can support multiple socket connections, but each one is blocking and so we can parallelize communications across the multiple sockets.
The first is more elegant but requires additional work: the server must dispatch requests asynchronously while continuing to accept new ones, and the client must handle the out-of-order replies. We took that path with the rpcq client-server. For simpler cases, multiple sockets might be sufficient.
Here is a simple implementation of a client socket pool wrapped in an asynccontextmanager decorator so that we can nicely use async with to await the next available socket. To bound the number of sockets we use an asyncio Semaphore. The semaphore holds an integer that is initialized to the maximum number of sockets. Every time a socket is checked out from the pool we await the semaphore, which decrements the integer by one; once the integer has reached zero, any further checkout blocks until a socket is returned. When we are done with the socket we return it to the pool and the semaphore integer is incremented again. All this decrement/increment bookkeeping is wrapped up by the Python asyncio module in another asynchronous context manager, so the interface is simply async with self.semaphore. We’ll return the high-level asyncio StreamReader/StreamWriter pair to provide access to the socket.
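Before the pool itself, here is a minimal sketch of the semaphore behaviour on its own. The names peak_concurrency and worker are hypothetical, and the short sleep merely stands in for waiting on a server; the sketch only records how many workers are ever inside the async with block at the same time.

```python
import asyncio


async def peak_concurrency(num_workers: int, limit: int) -> int:
    """Run `num_workers` tasks and report how many were ever active at once."""
    semaphore = asyncio.Semaphore(limit)
    active = 0
    peak = 0

    async def worker() -> None:
        nonlocal active, peak
        # entering the block acquires the semaphore (decrement),
        # leaving it releases the semaphore (increment)
        async with semaphore:
            active += 1
            peak = max(peak, active)
            await asyncio.sleep(0.01)  # stand-in for waiting on a server
            active -= 1

    await asyncio.gather(*(worker() for _ in range(num_workers)))
    return peak


if __name__ == "__main__":
    print(asyncio.run(peak_concurrency(num_workers=10, limit=3)))
```

However many workers are started, the reported peak should never exceed the limit. The pool implementation that follows applies the same pattern around each socket checkout.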
import asyncio
import logging
from contextlib import asynccontextmanager
from typing import AsyncIterator, Optional, Tuple

_log = logging.getLogger(__name__)


class SocketPool:
    """
    A pool of socket connections for a client to access some server.
    """

    address: str
    """ Hostname for the server """

    port: int = 1234
    """ Port that the instrument is listening for connections on. """

    max_num_sockets: int = 10
    """ Maximum number of open sockets to the server """

    semaphore: asyncio.Semaphore
    """ Use a semaphore to bound the number of open connections to `max_num_sockets` """

    pool: list
    """ List of socket StreamReader, StreamWriter pairs """

    def __init__(self, address: str, max_num_sockets: Optional[int] = None, port: Optional[int] = None):
        self.address = address
        # if we're passed some `max_num_sockets` or `port` then override the defaults
        if max_num_sockets is not None:
            self.max_num_sockets = max_num_sockets
        if port is not None:
            self.port = port
        self.semaphore = asyncio.Semaphore(self.max_num_sockets)
        # a list to hold onto the (asyncio.StreamReader, asyncio.StreamWriter) tuples for each socket
        self.pool = []

    @asynccontextmanager
    async def get_socket(self) -> AsyncIterator[Tuple[asyncio.StreamReader, asyncio.StreamWriter]]:
        """
        Get the next available socket reader/writer from the pool
        """
        async with self.semaphore:
            try:
                # take the last socket in the pool
                reader, writer = self.pool.pop()
            except IndexError:
                # empty list so create a new socket
                _log.info(f"Adding socket connection to server at {self.address}")
                reader, writer = await asyncio.open_connection(self.address, self.port)
            try:
                # hand the reader/writer pair back to the caller
                yield reader, writer
            finally:
                # put the socket back in the pool
                self.pool.append((reader, writer))
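The checkout-and-return part of get_socket can be seen in isolation with a small sketch. Everything here is illustrative: checkout and failing_request are hypothetical names, and a plain string stands in for a (reader, writer) pair. It shows that the finally block returns the item to the pool even when the body of the async with raises.

```python
import asyncio
from contextlib import asynccontextmanager


@asynccontextmanager
async def checkout(pool: list):
    # take an item out of the pool for the duration of the `async with` body
    item = pool.pop()
    try:
        yield item
    finally:
        # runs on normal exit and also when the body raises
        pool.append(item)


async def failing_request():
    pool = ["conn-a"]  # a string stands in for a (reader, writer) pair
    during = None
    try:
        async with checkout(pool) as conn:
            during = list(pool)  # snapshot of the pool while `conn` is checked out
            raise RuntimeError(f"request on {conn} failed")
    except RuntimeError:
        pass
    return during, pool


if __name__ == "__main__":
    print(asyncio.run(failing_request()))
```

One consequence of this design is that a socket goes back into the pool even if the request on it failed, so the next caller may receive a connection in an unknown state.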
We can then test this out with a simple echo server that listens for connections and then takes some time to respond.
import asyncio


async def handle_echo(reader, writer):
    addr = writer.get_extra_info('peername')
    # keep serving this connection until the client closes it,
    # so that pooled sockets can be reused for later requests
    while True:
        data = await reader.read(100)
        if not data:
            break
        message = data.decode()
        print(f"Received {message!r} from {addr!r}")
        print("Taking a second to process something....")
        await asyncio.sleep(1.0)
        print("Sending back response.")
        writer.write(data)
        await writer.drain()
    print("Close the connection")
    writer.close()


async def main():
    server = await asyncio.start_server(handle_echo, "localhost", 1234)
    addr = server.sockets[0].getsockname()
    print(f"Serving on {addr}")
    async with server:
        await server.serve_forever()


asyncio.run(main())
And then a simple client-side function to use it.
import asyncio

from wherever.socket_pool import SocketPool


async def send_message(pool, msg):
    # check out a socket
    async with pool.get_socket() as (reader, writer):
        # send the message
        writer.write(msg.encode())
        await writer.drain()
        # wait for the echo response
        resp = await reader.read(100)
        return resp.decode()


pool = SocketPool("localhost", max_num_sockets=5)
(base) cryan@cryan-Precision-5510:~/repos/tomorrow-said-toad/content$ ipython
Python 3.7.4 (default, Aug 13 2019, 20:35:49)
Type 'copyright', 'credits' or 'license' for more information
IPython 7.8.0 -- An enhanced Interactive Python. Type '?' for help.
In [1]: %paste # import the client code
# send a single message and a single socket is created
In [2]: await send_message(pool, "Hello!")
Adding socket connection to server at localhost
Out[2]: 'Hello!'
# we can inspect the pool and see indeed there is one socket there
In [3]: pool.pool
Out[3]:
[(<StreamReader transport=<_SelectorSocketTransport fd=16 read=polling write=<idle, bufsize=0>>>,
<StreamWriter transport=<_SelectorSocketTransport fd=16 read=polling write=<idle, bufsize=0>> reader=<StreamReader transport=<_SelectorSocketTransport fd=16 read=polling write=<idle, bufsize=0>>>>)]
In [4]: len(pool.pool)
Out[4]: 1
# now send a slew of requests at once and it still takes only a second to respond
In [5]: await asyncio.wait([send_message(pool, str(ct)) for ct in range(5)])
Adding socket connection to server at localhost
Adding socket connection to server at localhost
Adding socket connection to server at localhost
Adding socket connection to server at localhost
Out[5]:
({<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='0'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='1'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='2'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='3'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='4'>},
set())
# and we have 5 sockets in the pool
In [6]: len(pool.pool)
Out[6]: 5
# finally send even more requests and it takes about 4 seconds to respond as only 5 requests are outstanding at any time
In [7]: await asyncio.wait([send_message(pool, str(ct)) for ct in range(20)])
Out[7]:
({<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='0'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='1'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='10'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='11'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='12'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='13'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='14'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='15'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='16'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='17'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='18'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='19'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='2'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='3'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='4'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='5'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='6'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='7'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='8'>,
<Task finished coro=<send_message() done, defined at /home/cryan/repos/tomorrow-said-toad/content/client.py:61> result='9'>},
set())
In [8]: len(pool.pool)
Out[8]: 5
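The session above passes bare coroutines to asyncio.wait, which Python 3.11 and later reject. As a rough, self-contained sketch of the same experiment, the script below uses asyncio.gather instead and runs the server and the clients in one process. MiniPool, slow_echo, demo and DELAY are hypothetical names; MiniPool is only a condensed stand-in for SocketPool, and the 0.2 second delay and the OS-assigned port are assumptions chosen to keep the run short.

```python
import asyncio
import time
from contextlib import asynccontextmanager

DELAY = 0.2  # assumed processing time per request, in seconds


class MiniPool:
    """Condensed stand-in for SocketPool (hypothetical, for this sketch only)."""

    def __init__(self, host: str, port: int, max_num_sockets: int):
        self.host, self.port = host, port
        self.semaphore = asyncio.Semaphore(max_num_sockets)
        self.pool = []

    @asynccontextmanager
    async def get_socket(self):
        async with self.semaphore:
            try:
                reader, writer = self.pool.pop()
            except IndexError:
                reader, writer = await asyncio.open_connection(self.host, self.port)
            try:
                yield reader, writer
            finally:
                self.pool.append((reader, writer))


async def slow_echo(reader, writer):
    # keep the connection open so that pooled sockets can be reused
    while True:
        data = await reader.read(100)
        if not data:
            break
        await asyncio.sleep(DELAY)
        writer.write(data)
        await writer.drain()
    writer.close()


async def send_message(pool, msg):
    async with pool.get_socket() as (reader, writer):
        writer.write(msg.encode())
        await writer.drain()
        return (await reader.read(100)).decode()


async def demo(num_messages: int = 6, max_num_sockets: int = 3):
    # port 0 asks the operating system for any free port
    server = await asyncio.start_server(slow_echo, "127.0.0.1", 0)
    port = server.sockets[0].getsockname()[1]
    pool = MiniPool("127.0.0.1", port, max_num_sockets)

    start = time.perf_counter()
    # gather wraps each coroutine in a task and returns results in call order
    replies = await asyncio.gather(
        *(send_message(pool, str(ct)) for ct in range(num_messages))
    )
    elapsed = time.perf_counter() - start
    num_sockets = len(pool.pool)

    # close the client sockets, then stop the server from accepting more
    for _, writer in pool.pool:
        writer.close()
        await writer.wait_closed()
    server.close()
    return replies, num_sockets, elapsed


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

With six messages and three sockets, the replies come back in two batches, so the elapsed time should be roughly twice the server delay, and the pool should end up holding three sockets.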