diff --git a/docs/backpressure.rst b/docs/backpressure.rst new file mode 100644 index 0000000..dc3a14f --- /dev/null +++ b/docs/backpressure.rst @@ -0,0 +1,38 @@ +Message Queues +============== + +.. currentmodule:: trio_websocket + +.. TODO This file will grow into a "backpressure" document once #65 is complete. + For now it is just deals with userspace buffers, since this is a related + topic. + +When a connection is open, it runs a background task that reads network data and +automatically handles certain types of events for you. For example, if the +background task receives a ping event, then it will automatically send back a +pong event. When the background task receives a message, it places that message +into an internal queue. When you call ``get_message()``, it returns the first +item from this queue. + +If this internal message queue does not have any size limits, then a remote +endpoint could rapidly send large messages and use up all of the memory on the +local machine! In almost all situations, the message queue needs to have size +limits, both in terms of the number of items and the size per message. These +limits create an upper bound for the amount of memory that can be used by a +single WebSocket connection. For example, if the queue size is 10 and the +maximum message size is 1 megabyte, then the connection will use at most 10 +megabytes of memory. + +When the message queue is full, the background task pauses and waits for the +user to remove a message, i.e. call ``get_message()``. When the background task +is paused, it stops processing background events like replying to ping events. +If a message is received that is larger than the maximum message size, then the +connection is automatically closed with code 1009 and the message is discarded. + +The library APIs each take arguments to configure the mesage buffer: +``message_queue_size`` and ``max_message_size``. By default the queue size is +one and the maximum message size is 1 MiB. If you set queue size to zero, then +the background task will block every time it receives a message until somebody +calls ``get_message()``. For an unbounded queue—which is strongly +discouraged—set the queue size to ``math.inf``. Likewise, the maximum message +size may also be disabled by setting it to ``math.inf``. diff --git a/docs/index.rst b/docs/index.rst index a6d7471..1258c4c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,6 +32,7 @@ Autobahn Test Suite `__. getting_started clients servers + backpressure timeouts api recipes diff --git a/docs/servers.rst b/docs/servers.rst index f662036..314b674 100644 --- a/docs/servers.rst +++ b/docs/servers.rst @@ -41,7 +41,9 @@ host/port to bind to. The handler function receives a :class:`WebSocketRequest` object, and it calls the request's :func:`~WebSocketRequest.accept` method to finish the handshake and obtain a :class:`WebSocketConnection` object. When the handler function exits, the -connection is automatically closed. +connection is automatically closed. If the handler function raises an +exception, the server will silently close the connection and cancel the +tasks belonging to it. .. autofunction:: serve_websocket diff --git a/tests/test_connection.py b/tests/test_connection.py index a4ce171..4bd70ec 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -53,6 +53,7 @@ HOST = '127.0.0.1' RESOURCE = '/resource' +DEFAULT_TEST_MAX_DURATION = 1 # Timeout tests follow a general pattern: one side waits TIMEOUT seconds for an # event. The other side delays for FORCE_TIMEOUT seconds to force the timeout @@ -60,7 +61,7 @@ # prevent a faulty test from hanging the entire suite. TIMEOUT = 1 FORCE_TIMEOUT = 2 -MAX_TIMEOUT_TEST_DURATION = 3 +TIMEOUT_TEST_MAX_DURATION = 3 @pytest.fixture @@ -89,12 +90,6 @@ async def echo_request_handler(request): Accept incoming request and then pass off to echo connection handler. ''' conn = await request.accept() - await echo_conn_handler(conn) - - -async def echo_conn_handler(conn): - ''' A connection handler that reads one message, sends back the same - message, then exits. ''' try: msg = await conn.get_message() await conn.send_message(msg) @@ -391,7 +386,7 @@ async def handler(stream): await client.send_message('Hello from client!') -@fail_after(MAX_TIMEOUT_TEST_DURATION) +@fail_after(TIMEOUT_TEST_MAX_DURATION) async def test_client_open_timeout(nursery, autojump_clock): ''' The client times out waiting for the server to complete the opening @@ -411,7 +406,7 @@ async def handler(request): pass -@fail_after(MAX_TIMEOUT_TEST_DURATION) +@fail_after(TIMEOUT_TEST_MAX_DURATION) async def test_client_close_timeout(nursery, autojump_clock): ''' This client times out waiting for the server to complete the closing @@ -430,7 +425,8 @@ async def handler(request): pytest.fail('Should not reach this line.') server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None, + message_queue_size=0)) with pytest.raises(trio.TooSlowError): async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, @@ -438,7 +434,7 @@ async def handler(request): await client_ws.send_message('test') -@fail_after(MAX_TIMEOUT_TEST_DURATION) +@fail_after(TIMEOUT_TEST_MAX_DURATION) async def test_server_open_timeout(autojump_clock): ''' The server times out waiting for the client to complete the opening @@ -470,7 +466,7 @@ async def handler(request): nursery.cancel_scope.cancel() -@fail_after(MAX_TIMEOUT_TEST_DURATION) +@fail_after(TIMEOUT_TEST_MAX_DURATION) async def test_server_close_timeout(autojump_clock): ''' The server times out waiting for the client to complete the closing @@ -488,7 +484,7 @@ async def handler(request): ws = await request.accept() # Send one message to block the client's reader task: await ws.send_message('test') - import logging + async with trio.open_nursery() as outer: server = await outer.start(partial(serve_websocket, handler, HOST, 0, ssl_context=None, handler_nursery=outer, @@ -523,7 +519,6 @@ async def handler(request): with pytest.raises(ConnectionClosed): await server_ws.get_message() server = await nursery.start(serve_websocket, handler, HOST, 0, None) - port = server.port stream = await trio.open_tcp_stream(HOST, server.port) client_ws = await wrap_client_stream(nursery, stream, HOST, RESOURCE) async with client_ws: @@ -566,12 +561,14 @@ async def handler(request): assert exc.reason.name == 'NORMAL_CLOSURE' -@pytest.mark.skip(reason='Hangs because channel size is hard coded to 0') +@fail_after(DEFAULT_TEST_MAX_DURATION) async def test_read_messages_after_remote_close(nursery): ''' When the remote endpoint closes, the local endpoint can still read all of the messages sent prior to closing. Any attempt to read beyond that will raise ConnectionClosed. + + This test also exercises the configuration of the queue size. ''' server_closed = trio.Event() @@ -585,7 +582,10 @@ async def handler(request): server = await nursery.start( partial(serve_websocket, handler, HOST, 0, ssl_context=None)) - async with open_websocket(HOST, server.port, '/', use_ssl=False) as client: + # The client needs a message queue of size 2 so that it can buffer both + # incoming messages without blocking the reader task. + async with open_websocket(HOST, server.port, '/', use_ssl=False, + message_queue_size=2) as client: await server_closed.wait() assert await client.get_message() == '1' assert await client.get_message() == '2' @@ -618,12 +618,49 @@ async def handler(request): client_closed.set() -async def test_client_cm_exit_with_pending_messages(echo_server, autojump_clock): +async def test_cm_exit_with_pending_messages(echo_server, autojump_clock): + ''' + Regression test for #74, where a context manager was not able to exit when + there were pending messages in the receive queue. + ''' with trio.fail_after(1): async with open_websocket(HOST, echo_server.port, RESOURCE, use_ssl=False) as ws: await ws.send_message('hello') # allow time for the server to respond await trio.sleep(.1) - # bug: context manager exit is blocked on unconsumed message - #await ws.get_message() + + +@fail_after(DEFAULT_TEST_MAX_DURATION) +async def test_max_message_size(nursery): + ''' + Set the client's max message size to 100 bytes. The client can send a + message larger than 100 bytes, but when it receives a message larger than + 100 bytes, it closes the connection with code 1009. + ''' + async def handler(request): + ''' Similar to the echo_request_handler fixture except it runs in a + loop. ''' + conn = await request.accept() + while True: + try: + msg = await conn.get_message() + await conn.send_message(msg) + except ConnectionClosed: + break + + server = await nursery.start( + partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + + async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, + max_message_size=100) as client: + # We can send and receive 100 bytes: + await client.send_message(b'A' * 100) + msg = await client.get_message() + assert len(msg) == 100 + # We can send 101 bytes but cannot receive 101 bytes: + await client.send_message(b'B' * 101) + with pytest.raises(ConnectionClosed): + await client.get_message() + assert client.closed + assert client.closed.code == 1009 diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index f79d7bb..bae7331 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -20,13 +20,16 @@ CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds -RECEIVE_BYTES = 4096 +MESSAGE_QUEUE_SIZE = 1 +MAX_MESSAGE_SIZE = 2 ** 20 # 1 MiB +RECEIVE_BYTES = 4 * 2 ** 10 # 4 KiB logger = logging.getLogger('trio-websocket') @asynccontextmanager @async_generator async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None, + message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE, connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT): ''' Open a WebSocket client connection to a host. @@ -44,6 +47,11 @@ async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None, :type use_ssl: bool or ssl.SSLContext :param subprotocols: An iterable of strings representing preferred subprotocols. + :param int message_queue_size: The maximum number of messages that will be + buffered in the library's internal message queue. + :param int max_message_size: The maximum message size as measured by + ``len()``. If a message is received that is larger than this size, + then the connection is closed with code 1009 (Message Too Big). :param float connect_timeout: The number of seconds to wait for the connection before timing out. :param float disconnect_timeout: The number of seconds to wait when closing @@ -53,7 +61,9 @@ async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None, async with trio.open_nursery() as new_nursery: with trio.fail_after(connect_timeout): connection = await connect_websocket(new_nursery, host, port, - resource, use_ssl=use_ssl, subprotocols=subprotocols) + resource, use_ssl=use_ssl, subprotocols=subprotocols, + message_queue_size=message_queue_size, + max_message_size=max_message_size) try: await yield_(connection) finally: @@ -62,7 +72,8 @@ async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None, async def connect_websocket(nursery, host, port, resource, *, use_ssl, - subprotocols=None): + subprotocols=None, message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE): ''' Return an open WebSocket client connection to a host. @@ -79,6 +90,11 @@ async def connect_websocket(nursery, host, port, resource, *, use_ssl, :type use_ssl: bool or ssl.SSLContext :param subprotocols: An iterable of strings representing preferred subprotocols. + :param int message_queue_size: The maximum number of messages that will be + buffered in the library's internal message queue. + :param int max_message_size: The maximum message size as measured by + ``len()``. If a message is received that is larger than this size, + then the connection is closed with code 1009 (Message Too Big). :rtype: WebSocketConnection ''' if use_ssl == True: @@ -103,13 +119,16 @@ async def connect_websocket(nursery, host, port, resource, *, use_ssl, host_header = '{}:{}'.format(host, port) wsproto = wsconnection.WSConnection(wsconnection.CLIENT, host=host_header, resource=resource, subprotocols=subprotocols) - connection = WebSocketConnection(stream, wsproto, path=resource) + connection = WebSocketConnection(stream, wsproto, path=resource, + message_queue_size=message_queue_size, + max_message_size=max_message_size) nursery.start_soon(connection._reader_task) await connection._open_handshake.wait() return connection def open_websocket_url(url, ssl_context=None, *, subprotocols=None, + message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE, connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT): ''' Open a WebSocket client connection to a URL. @@ -124,6 +143,11 @@ def open_websocket_url(url, ssl_context=None, *, subprotocols=None, :type ssl_context: ssl.SSLContext or None :param subprotocols: An iterable of strings representing preferred subprotocols. + :param int message_queue_size: The maximum number of messages that will be + buffered in the library's internal message queue. + :param int max_message_size: The maximum message size as measured by + ``len()``. If a message is received that is larger than this size, + then the connection is closed with code 1009 (Message Too Big). :param float connect_timeout: The number of seconds to wait for the connection before timing out. :param float disconnect_timeout: The number of seconds to wait when closing @@ -132,11 +156,13 @@ def open_websocket_url(url, ssl_context=None, *, subprotocols=None, ''' host, port, resource, ssl_context = _url_to_host(url, ssl_context) return open_websocket(host, port, resource, use_ssl=ssl_context, - subprotocols=subprotocols) + subprotocols=subprotocols, message_queue_size=message_queue_size, + max_message_size=max_message_size) async def connect_websocket_url(nursery, url, ssl_context=None, *, - subprotocols=None): + subprotocols=None, message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE): ''' Return an open WebSocket client connection to a URL. @@ -146,17 +172,24 @@ async def connect_websocket_url(nursery, url, ssl_context=None, *, If you don't need a custom nursery, you should probably use :func:`open_websocket_url` instead. + :param nursery: A nursery to run background tasks in. :param str url: A WebSocket URL. :param ssl_context: Optional SSL context used for ``wss:`` URLs. :type ssl_context: ssl.SSLContext or None - :param nursery: A nursery to run background tasks in. :param subprotocols: An iterable of strings representing preferred subprotocols. + :param int message_queue_size: The maximum number of messages that will be + buffered in the library's internal message queue. + :param int max_message_size: The maximum message size as measured by + ``len()``. If a message is received that is larger than this size, + then the connection is closed with code 1009 (Message Too Big). :rtype: WebSocketConnection ''' host, port, resource, ssl_context = _url_to_host(url, ssl_context) return await connect_websocket(nursery, host, port, resource, - use_ssl=ssl_context, subprotocols=subprotocols) + use_ssl=ssl_context, subprotocols=subprotocols, + message_queue_size=message_queue_size, + max_message_size=max_message_size) def _url_to_host(url, ssl_context): @@ -182,7 +215,8 @@ def _url_to_host(url, ssl_context): async def wrap_client_stream(nursery, stream, host, resource, *, - subprotocols=None): + subprotocols=None, message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE): ''' Wrap an arbitrary stream in a WebSocket connection. @@ -197,17 +231,25 @@ async def wrap_client_stream(nursery, stream, host, resource, *, accessed on the server. :param subprotocols: An iterable of strings representing preferred subprotocols. + :param int message_queue_size: The maximum number of messages that will be + buffered in the library's internal message queue. + :param int max_message_size: The maximum message size as measured by + ``len()``. If a message is received that is larger than this size, + then the connection is closed with code 1009 (Message Too Big). :rtype: WebSocketConnection ''' wsproto = wsconnection.WSConnection(wsconnection.CLIENT, host=host, resource=resource, subprotocols=subprotocols) - connection = WebSocketConnection(stream, wsproto, path=resource) + connection = WebSocketConnection(stream, wsproto, path=resource, + message_queue_size=message_queue_size, + max_message_size=max_message_size) nursery.start_soon(connection._reader_task) await connection._open_handshake.wait() return connection -async def wrap_server_stream(nursery, stream): +async def wrap_server_stream(nursery, stream, + message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): ''' Wrap an arbitrary stream in a server-side WebSocket. @@ -216,18 +258,26 @@ async def wrap_server_stream(nursery, stream): :param nursery: A nursery to run background tasks in. :param stream: A stream to be wrapped. + :param int message_queue_size: The maximum number of messages that will be + buffered in the library's internal message queue. + :param int max_message_size: The maximum message size as measured by + ``len()``. If a message is received that is larger than this size, + then the connection is closed with code 1009 (Message Too Big). :type stream: trio.abc.Stream :rtype: WebSocketConnection ''' wsproto = wsconnection.WSConnection(wsconnection.SERVER) - connection = WebSocketConnection(stream, wsproto) + connection = WebSocketConnection(stream, wsproto, + message_queue_size=message_queue_size, + max_message_size=max_message_size) nursery.start_soon(connection._reader_task) request = await connection._get_request() return request async def serve_websocket(handler, host, port, ssl_context, *, - handler_nursery=None, connect_timeout=CONN_TIMEOUT, + handler_nursery=None, message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT, task_status=trio.TASK_STATUS_IGNORED): ''' Serve a WebSocket over TCP. @@ -252,6 +302,11 @@ async def serve_websocket(handler, host, port, ssl_context, *, :type ssl_context: ssl.SSLContext or None :param handler_nursery: An optional nursery to spawn handlers and background tasks in. If not specified, a new nursery will be created internally. + :param int message_queue_size: The maximum number of messages that will be + buffered in the library's internal message queue. + :param int max_message_size: The maximum message size as measured by + ``len()``. If a message is received that is larger than this size, + then the connection is closed with code 1009 (Message Too Big). :param float connect_timeout: The number of seconds to wait for a client to finish connection handshake before timing out. :param float disconnect_timeout: The number of seconds to wait for a client @@ -266,7 +321,8 @@ async def serve_websocket(handler, host, port, ssl_context, *, ssl_context, host=host, https_compatible=True) listeners = await open_tcp_listeners() server = WebSocketServer(handler, listeners, - handler_nursery=handler_nursery, connect_timeout=connect_timeout, + handler_nursery=handler_nursery, message_queue_size=message_queue_size, + max_message_size=max_message_size, connect_timeout=connect_timeout, disconnect_timeout=disconnect_timeout) await server.run(task_status=task_status) @@ -442,7 +498,9 @@ class WebSocketConnection(trio.abc.AsyncResource): CONNECTION_ID = itertools.count() - def __init__(self, stream, wsproto, *, path=None): + def __init__(self, stream, wsproto, *, path=None, + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE): ''' Constructor. @@ -458,17 +516,25 @@ def __init__(self, stream, wsproto, *, path=None): :type wsproto: wsproto.connection.WSConnection :param str path: The URL path for this connection. Only used for server instances. + :param int message_queue_size: The maximum number of messages that will be + buffered in the library's internal message queue. + :param int max_message_size: The maximum message size as measured by + ``len()``. If a message is received that is larger than this size, + then the connection is closed with code 1009 (Message Too Big). ''' self._close_reason = None self._id = next(self.__class__.CONNECTION_ID) self._stream = stream self._stream_lock = trio.StrictFIFOLock() self._wsproto = wsproto + self._message_size = 0 self._message_parts = [] # type: List[bytes|str] + self._max_message_size = max_message_size self._reader_running = True self._path = path self._subprotocol = None - self._send_channel, self._recv_channel = trio.open_memory_channel(0) + self._send_channel, self._recv_channel = trio.open_memory_channel( + message_queue_size) self._pings = OrderedDict() # Set when the server has received a connection request event. This # future is never set on client connections. @@ -552,6 +618,7 @@ async def aclose(self, code=1000, reason=None): # stream is closed. await self._close_stream() + async def get_message(self): ''' Receive the next WebSocket message. @@ -760,10 +827,22 @@ async def _handle_data_received_event(self, event): :param event: ''' + self._message_size += len(event.data) self._message_parts.append(event.data) - if event.message_finished: + if self._message_size > self._max_message_size: + err = 'Exceeded maximum message size: {} bytes'.format( + self._max_message_size) + self._message_size = 0 + self._message_parts = [] + self._close_reason = CloseReason(1009, err) + self._wsproto.close(code=1009, reason=err) + await self._write_pending() + await self._recv_channel.aclose() + self._reader_running = False + elif event.message_finished: msg = (b'' if isinstance(event, BytesReceived) else '') \ .join(self._message_parts) + self._message_size = 0 self._message_parts = [] try: await self._send_channel.send(msg) @@ -860,7 +939,8 @@ async def _reader_task(self): break else: logger.debug('%s received %d bytes', self, len(data)) - self._wsproto.receive_bytes(data) + if not self._wsproto.closed: + self._wsproto.receive_bytes(data) logger.debug('%s reader task finished', self) @@ -907,7 +987,9 @@ class WebSocketServer: ''' def __init__(self, handler, listeners, *, handler_nursery=None, - connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT): + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, connect_timeout=CONN_TIMEOUT, + disconnect_timeout=CONN_TIMEOUT): ''' Constructor. @@ -933,6 +1015,8 @@ def __init__(self, handler, listeners, *, handler_nursery=None, self._handler = handler self._handler_nursery = handler_nursery self._listeners = listeners + self._message_queue_size = message_queue_size + self._max_message_size = max_message_size self._connect_timeout = connect_timeout self._disconnect_timeout = disconnect_timeout @@ -1013,7 +1097,9 @@ async def _handle_connection(self, stream): ''' async with trio.open_nursery() as nursery: wsproto = wsconnection.WSConnection(wsconnection.SERVER) - connection = WebSocketConnection(stream, wsproto) + connection = WebSocketConnection(stream, wsproto, + message_queue_size=self._message_queue_size, + max_message_size=self._max_message_size) nursery.start_soon(connection._reader_task) with trio.move_on_after(self._connect_timeout) as connect_scope: request = await connection._get_request()