From 7a9bf5f9f0717cb316b40a83484fb51c2cf0ea89 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Wed, 10 May 2023 14:02:02 -0700 Subject: [PATCH 01/11] Blacken --- tests/test_connection.py | 526 ++++++++++++++++------------ trio_websocket/_impl.py | 699 ++++++++++++++++++++++--------------- trio_websocket/_version.py | 2 +- 3 files changed, 727 insertions(+), 500 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 79bb9b4..1f29279 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,4 +1,4 @@ -''' +""" Unit tests for trio_websocket. Many of these tests involve networking, i.e. real TCP sockets. To maximize @@ -28,7 +28,7 @@ call ``ws.get_message()`` without actually sending it a message. This will cause the server to block until the client has sent the closing handshake. In other circumstances -''' +""" from functools import partial, wraps import ssl from unittest.mock import patch @@ -61,13 +61,13 @@ WebSocketServer, WebSocketRequest, wrap_client_stream, - wrap_server_stream + wrap_server_stream, ) -WS_PROTO_VERSION = tuple(map(int, wsproto.__version__.split('.'))) +WS_PROTO_VERSION = tuple(map(int, wsproto.__version__.split("."))) -HOST = '127.0.0.1' -RESOURCE = '/resource' +HOST = "127.0.0.1" +RESOURCE = "/resource" DEFAULT_TEST_MAX_DURATION = 1 # Timeout tests follow a general pattern: one side waits TIMEOUT seconds for an @@ -81,27 +81,25 @@ @pytest.fixture async def echo_server(nursery): - ''' A server that reads one message, sends back the same message, - then closes the connection. ''' - serve_fn = partial(serve_websocket, echo_request_handler, HOST, 0, - ssl_context=None) + """A server that reads one message, sends back the same message, + then closes the connection.""" + serve_fn = partial(serve_websocket, echo_request_handler, HOST, 0, ssl_context=None) server = await nursery.start(serve_fn) yield server @pytest.fixture async def echo_conn(echo_server): - ''' Return a client connection instance that is connected to an echo - server. ''' - async with open_websocket(HOST, echo_server.port, RESOURCE, - use_ssl=False) as conn: + """Return a client connection instance that is connected to an echo + server.""" + async with open_websocket(HOST, echo_server.port, RESOURCE, use_ssl=False) as conn: yield conn async def echo_request_handler(request): - ''' + """ Accept incoming request and then pass off to echo connection handler. - ''' + """ conn = await request.accept() try: msg = await conn.get_message() @@ -111,8 +109,9 @@ async def echo_request_handler(request): class fail_after: - ''' This decorator fails if the runtime of the decorated function (as - measured by the Trio clock) exceeds the specified value. ''' + """This decorator fails if the runtime of the decorated function (as + measured by the Trio clock) exceeds the specified value.""" + def __init__(self, seconds): self._seconds = seconds @@ -122,7 +121,10 @@ async def wrapper(*args, **kwargs): with trio.move_on_after(self._seconds) as cancel_scope: await fn(*args, **kwargs) if cancel_scope.cancelled_caught: - pytest.fail(f'Test runtime exceeded the maximum {self._seconds} seconds') + pytest.fail( + f"Test runtime exceeded the maximum {self._seconds} seconds" + ) + return wrapper @@ -154,41 +156,41 @@ async def aclose(self): async def test_endpoint_ipv4(): - e1 = Endpoint('10.105.0.2', 80, False) - assert e1.url == 'ws://10.105.0.2' + e1 = Endpoint("10.105.0.2", 80, False) + assert e1.url == "ws://10.105.0.2" assert str(e1) == 'Endpoint(address="10.105.0.2", port=80, is_ssl=False)' - e2 = Endpoint('127.0.0.1', 8000, False) - assert e2.url == 'ws://127.0.0.1:8000' + e2 = Endpoint("127.0.0.1", 8000, False) + assert e2.url == "ws://127.0.0.1:8000" assert str(e2) == 'Endpoint(address="127.0.0.1", port=8000, is_ssl=False)' - e3 = Endpoint('0.0.0.0', 443, True) - assert e3.url == 'wss://0.0.0.0' + e3 = Endpoint("0.0.0.0", 443, True) + assert e3.url == "wss://0.0.0.0" assert str(e3) == 'Endpoint(address="0.0.0.0", port=443, is_ssl=True)' async def test_listen_port_ipv6(): - e1 = Endpoint('2599:8807:6201:b7:16cf:bb9c:a6d3:51ab', 80, False) - assert e1.url == 'ws://[2599:8807:6201:b7:16cf:bb9c:a6d3:51ab]' - assert str(e1) == 'Endpoint(address="2599:8807:6201:b7:16cf:bb9c:a6d3' \ - ':51ab", port=80, is_ssl=False)' - e2 = Endpoint('::1', 8000, False) - assert e2.url == 'ws://[::1]:8000' + e1 = Endpoint("2599:8807:6201:b7:16cf:bb9c:a6d3:51ab", 80, False) + assert e1.url == "ws://[2599:8807:6201:b7:16cf:bb9c:a6d3:51ab]" + assert ( + str(e1) == 'Endpoint(address="2599:8807:6201:b7:16cf:bb9c:a6d3' + ':51ab", port=80, is_ssl=False)' + ) + e2 = Endpoint("::1", 8000, False) + assert e2.url == "ws://[::1]:8000" assert str(e2) == 'Endpoint(address="::1", port=8000, is_ssl=False)' - e3 = Endpoint('::', 443, True) - assert e3.url == 'wss://[::]' + e3 = Endpoint("::", 443, True) + assert e3.url == "wss://[::]" assert str(e3) == 'Endpoint(address="::", port=443, is_ssl=True)' async def test_server_has_listeners(nursery): - server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, - None) + server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, None) assert len(server.listeners) > 0 assert isinstance(server.listeners[0], Endpoint) async def test_serve(nursery): task = current_task() - server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, - None) + server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, None) port = server.port assert server.port != 0 # The server nursery begins with one task (server.listen). @@ -209,11 +211,11 @@ async def test_serve_ssl(nursery): cert = ca.issue_server_cert(HOST) cert.configure_cert(server_context) - server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, - server_context) + server = await nursery.start( + serve_websocket, echo_request_handler, HOST, 0, server_context + ) port = server.port - async with open_websocket(HOST, port, RESOURCE, use_ssl=client_context - ) as conn: + async with open_websocket(HOST, port, RESOURCE, use_ssl=client_context) as conn: assert not conn.closed assert conn.local.is_ssl assert conn.remote.is_ssl @@ -222,8 +224,14 @@ async def test_serve_ssl(nursery): async def test_serve_handler_nursery(nursery): task = current_task() async with trio.open_nursery() as handler_nursery: - serve_with_nursery = partial(serve_websocket, echo_request_handler, - HOST, 0, None, handler_nursery=handler_nursery) + serve_with_nursery = partial( + serve_websocket, + echo_request_handler, + HOST, + 0, + None, + handler_nursery=handler_nursery, + ) server = await nursery.start(serve_with_nursery) port = server.port # The server nursery begins with one task (server.listen). @@ -248,7 +256,7 @@ async def test_serve_non_tcp_listener(nursery): assert len(server.listeners) == 1 with pytest.raises(RuntimeError): server.port # pylint: disable=pointless-statement - assert server.listeners[0].startswith('MemoryListener(') + assert server.listeners[0].startswith("MemoryListener(") async def test_serve_multiple_listeners(nursery): @@ -265,74 +273,77 @@ async def test_serve_multiple_listeners(nursery): assert server.listeners[0].port != 0 # The second listener metadata is a string containing the repr() of a # MemoryListener object. - assert server.listeners[1].startswith('MemoryListener(') + assert server.listeners[1].startswith("MemoryListener(") async def test_client_open(echo_server): - async with open_websocket(HOST, echo_server.port, RESOURCE, use_ssl=False) \ - as conn: + async with open_websocket(HOST, echo_server.port, RESOURCE, use_ssl=False) as conn: assert not conn.closed assert conn.is_client - assert str(conn).startswith('client-') + assert str(conn).startswith("client-") -@pytest.mark.parametrize('path, expected_path', [ - ('/', '/'), - ('', '/'), - (RESOURCE + '/path', RESOURCE + '/path'), - (RESOURCE + '?foo=bar', RESOURCE + '?foo=bar') -]) +@pytest.mark.parametrize( + "path, expected_path", + [ + ("/", "/"), + ("", "/"), + (RESOURCE + "/path", RESOURCE + "/path"), + (RESOURCE + "?foo=bar", RESOURCE + "?foo=bar"), + ], +) async def test_client_open_url(path, expected_path, echo_server): - url = f'ws://{HOST}:{echo_server.port}{path}' + url = f"ws://{HOST}:{echo_server.port}{path}" async with open_websocket_url(url) as conn: assert conn.path == expected_path async def test_client_open_invalid_url(echo_server): with pytest.raises(ValueError): - async with open_websocket_url('http://foo.com/bar') as conn: + async with open_websocket_url("http://foo.com/bar") as conn: pass async def test_ascii_encoded_path_is_ok(echo_server): - path = '%D7%90%D7%91%D7%90?%D7%90%D7%9E%D7%90' - url = f'ws://{HOST}:{echo_server.port}{RESOURCE}/{path}' + path = "%D7%90%D7%91%D7%90?%D7%90%D7%9E%D7%90" + url = f"ws://{HOST}:{echo_server.port}{RESOURCE}/{path}" async with open_websocket_url(url) as conn: - assert conn.path == RESOURCE + '/' + path + assert conn.path == RESOURCE + "/" + path -@patch('trio_websocket._impl.open_websocket') +@patch("trio_websocket._impl.open_websocket") def test_client_open_url_options(open_websocket_mock): """open_websocket_url() must pass its options on to open_websocket()""" port = 1234 - url = f'ws://{HOST}:{port}{RESOURCE}' + url = f"ws://{HOST}:{port}{RESOURCE}" options = { - 'subprotocols': ['chat'], - 'extra_headers': [(b'X-Test-Header', b'My test header')], - 'message_queue_size': 9, - 'max_message_size': 333, - 'connect_timeout': 36, - 'disconnect_timeout': 37, + "subprotocols": ["chat"], + "extra_headers": [(b"X-Test-Header", b"My test header")], + "message_queue_size": 9, + "max_message_size": 333, + "connect_timeout": 36, + "disconnect_timeout": 37, } open_websocket_url(url, **options) _, call_args, call_kwargs = open_websocket_mock.mock_calls[0] assert call_args == (HOST, port, RESOURCE) - assert not call_kwargs.pop('use_ssl') + assert not call_kwargs.pop("use_ssl") assert call_kwargs == options - open_websocket_url(url.replace('ws:', 'wss:')) + open_websocket_url(url.replace("ws:", "wss:")) _, call_args, call_kwargs = open_websocket_mock.mock_calls[1] - assert call_kwargs['use_ssl'] + assert call_kwargs["use_ssl"] async def test_client_connect(echo_server, nursery): - conn = await connect_websocket(nursery, HOST, echo_server.port, RESOURCE, - use_ssl=False) + conn = await connect_websocket( + nursery, HOST, echo_server.port, RESOURCE, use_ssl=False + ) assert not conn.closed async def test_client_connect_url(echo_server, nursery): - url = f'ws://{HOST}:{echo_server.port}{RESOURCE}' + url = f"ws://{HOST}:{echo_server.port}{RESOURCE}" conn = await connect_websocket_url(nursery, url) assert not conn.closed @@ -361,21 +372,21 @@ async def handler(request): conn = await request.accept() server = await nursery.start(serve_websocket, handler, HOST, 0, None) - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False - ) as client_ws: + async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False) as client_ws: pass async def test_handshake_subprotocol(nursery): async def handler(request): - assert request.proposed_subprotocols == ('chat', 'file') - server_ws = await request.accept(subprotocol='chat') - assert server_ws.subprotocol == 'chat' + assert request.proposed_subprotocols == ("chat", "file") + server_ws = await request.accept(subprotocol="chat") + assert server_ws.subprotocol == "chat" server = await nursery.start(serve_websocket, handler, HOST, 0, None) - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, - subprotocols=('chat', 'file')) as client_ws: - assert client_ws.subprotocol == 'chat' + async with open_websocket( + HOST, server.port, RESOURCE, use_ssl=False, subprotocols=("chat", "file") + ) as client_ws: + assert client_ws.subprotocol == "chat" async def test_handshake_path(nursery): @@ -385,8 +396,12 @@ async def handler(request): assert server_ws.path == RESOURCE server = await nursery.start(serve_websocket, handler, HOST, 0, None) - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, - ) as client_ws: + async with open_websocket( + HOST, + server.port, + RESOURCE, + use_ssl=False, + ) as client_ws: assert client_ws.path == RESOURCE @@ -394,107 +409,118 @@ async def handler(request): async def test_handshake_client_headers(nursery): async def handler(request): headers = dict(request.headers) - assert b'x-test-header' in headers - assert headers[b'x-test-header'] == b'My test header' + assert b"x-test-header" in headers + assert headers[b"x-test-header"] == b"My test header" server_ws = await request.accept() - await server_ws.send_message('test') + await server_ws.send_message("test") server = await nursery.start(serve_websocket, handler, HOST, 0, None) - headers = [(b'X-Test-Header', b'My test header')] - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, - extra_headers=headers) as client_ws: + headers = [(b"X-Test-Header", b"My test header")] + async with open_websocket( + HOST, server.port, RESOURCE, use_ssl=False, extra_headers=headers + ) as client_ws: await client_ws.get_message() @fail_after(1) async def test_handshake_server_headers(nursery): async def handler(request): - headers = [('X-Test-Header', 'My test header')] + headers = [("X-Test-Header", "My test header")] server_ws = await request.accept(extra_headers=headers) server = await nursery.start(serve_websocket, handler, HOST, 0, None) - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False - ) as client_ws: + async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False) as client_ws: header_key, header_value = client_ws.handshake_headers[0] - assert header_key == b'x-test-header' - assert header_value == b'My test header' + assert header_key == b"x-test-header" + assert header_value == b"My test header" @fail_after(1) async def test_handshake_exception_before_accept(): - ''' In #107, a request handler that throws an exception before finishing the + """In #107, a request handler that throws an exception before finishing the handshake causes the task to hang. The proper behavior is to raise an - exception to the nursery as soon as possible. ''' + exception to the nursery as soon as possible.""" + async def handler(request): raise ValueError() with pytest.raises(ValueError): async with trio.open_nursery() as nursery: - server = await nursery.start(serve_websocket, handler, HOST, 0, - None) - async with open_websocket(HOST, server.port, RESOURCE, - use_ssl=False) as client_ws: + server = await nursery.start(serve_websocket, handler, HOST, 0, None) + async with open_websocket( + HOST, server.port, RESOURCE, use_ssl=False + ) as client_ws: pass @fail_after(1) async def test_reject_handshake(nursery): async def handler(request): - body = b'My body' + body = b"My body" await request.reject(400, body=body) server = await nursery.start(serve_websocket, handler, HOST, 0, None) with pytest.raises(ConnectionRejected) as exc_info: - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, - ) as client_ws: + async with open_websocket( + HOST, + server.port, + RESOURCE, + use_ssl=False, + ) as client_ws: pass exc = exc_info.value - assert exc.body == b'My body' + assert exc.body == b"My body" @fail_after(1) async def test_reject_handshake_invalid_info_status(nursery): - ''' + """ An informational status code that is not 101 should cause the client to reject the handshake. Since it is an informational response, there will not be a response body, so this test exercises a different code path. - ''' + """ + async def handler(stream): - await stream.send_all(b'HTTP/1.1 100 CONTINUE\r\n\r\n') + await stream.send_all(b"HTTP/1.1 100 CONTINUE\r\n\r\n") await stream.receive_some(max_bytes=1024) + serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST) listeners = await nursery.start(serve_fn) port = listeners[0].socket.getsockname()[1] with pytest.raises(ConnectionRejected) as exc_info: - async with open_websocket(HOST, port, RESOURCE, use_ssl=False, - ) as client_ws: + async with open_websocket( + HOST, + port, + RESOURCE, + use_ssl=False, + ) as client_ws: pass exc = exc_info.value assert exc.status_code == 100 - assert repr(exc) == 'ConnectionRejected' + assert repr(exc) == "ConnectionRejected" assert exc.body is None async def test_handshake_protocol_error(nursery, echo_server): - ''' + """ If a client connects to a trio-websocket server and tries to speak HTTP instead of WebSocket, the server should reject the connection. (If the server does not catch the protocol exception, it will raise an exception up to the nursery level and fail the test.) - ''' + """ client_stream = await trio.open_tcp_stream(HOST, echo_server.port) async with client_stream: - await client_stream.send_all(b'GET / HTTP/1.1\r\n\r\n') + await client_stream.send_all(b"GET / HTTP/1.1\r\n\r\n") response = await client_stream.receive_some(1024) - assert response.startswith(b'HTTP/1.1 400') + assert response.startswith(b"HTTP/1.1 400") async def test_client_send_and_receive(echo_conn): async with echo_conn: - await echo_conn.send_message('This is a test message.') + await echo_conn.send_message("This is a test message.") received_msg = await echo_conn.get_message() - assert received_msg == 'This is a test message.' + assert received_msg == "This is a test message." async def test_client_send_invalid_type(echo_conn): @@ -505,17 +531,19 @@ async def test_client_send_invalid_type(echo_conn): async def test_client_ping(echo_conn): async with echo_conn: - await echo_conn.ping(b'A') + await echo_conn.ping(b"A") with pytest.raises(ConnectionClosed): - await echo_conn.ping(b'B') + await echo_conn.ping(b"B") async def test_client_ping_two_payloads(echo_conn): pong_count = 0 + async def ping_and_count(): nonlocal pong_count await echo_conn.ping() pong_count += 1 + async with echo_conn: async with trio.open_nursery() as nursery: nursery.start_soon(ping_and_count) @@ -528,12 +556,14 @@ async def test_client_ping_same_payload(echo_conn): # same time. One of them should succeed and the other should get an # exception. exc_count = 0 + async def ping_and_catch(): nonlocal exc_count try: - await echo_conn.ping(b'A') + await echo_conn.ping(b"A") except ValueError: exc_count += 1 + async with echo_conn: async with trio.open_nursery() as nursery: nursery.start_soon(ping_and_catch) @@ -543,9 +573,9 @@ async def ping_and_catch(): async def test_client_pong(echo_conn): async with echo_conn: - await echo_conn.pong(b'A') + await echo_conn.pong(b"A") with pytest.raises(ConnectionClosed): - await echo_conn.pong(b'B') + await echo_conn.pong(b"B") async def test_client_default_close(echo_conn): @@ -553,16 +583,18 @@ async def test_client_default_close(echo_conn): assert not echo_conn.closed assert echo_conn.closed.code == 1000 assert echo_conn.closed.reason is None - assert repr(echo_conn.closed) == 'CloseReason' + assert ( + repr(echo_conn.closed) == "CloseReason" + ) async def test_client_nondefault_close(echo_conn): async with echo_conn: assert not echo_conn.closed - await echo_conn.aclose(code=1001, reason='test reason') + await echo_conn.aclose(code=1001, reason="test reason") assert echo_conn.closed.code == 1001 - assert echo_conn.closed.reason == 'test reason' + assert echo_conn.closed.reason == "test reason" async def test_wrap_client_stream(nursery): @@ -573,10 +605,10 @@ async def test_wrap_client_stream(nursery): conn = await wrap_client_stream(nursery, stream, HOST, RESOURCE) async with conn: assert not conn.closed - await conn.send_message('Hello from client!') + await conn.send_message("Hello from client!") msg = await conn.get_message() - assert msg == 'Hello from client!' - assert conn.local.startswith('StapledStream(') + assert msg == "Hello from client!" + assert conn.local.startswith("StapledStream(") assert conn.closed @@ -587,38 +619,42 @@ async def handler(stream): async with server_ws: assert not server_ws.closed msg = await server_ws.get_message() - assert msg == 'Hello from client!' + assert msg == "Hello from client!" assert server_ws.closed + serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST) listeners = await nursery.start(serve_fn) port = listeners[0].socket.getsockname()[1] async with open_websocket(HOST, port, RESOURCE, use_ssl=False) as client: - await client.send_message('Hello from client!') + await client.send_message("Hello from client!") @fail_after(TIMEOUT_TEST_MAX_DURATION) async def test_client_open_timeout(nursery, autojump_clock): - ''' + """ The client times out waiting for the server to complete the opening handshake. - ''' + """ + async def handler(request): await trio.sleep(FORCE_TIMEOUT) server_ws = await request.accept() - pytest.fail('Should not reach this line.') + pytest.fail("Should not reach this line.") server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) with pytest.raises(ConnectionTimeout): - async with open_websocket(HOST, server.port, '/', use_ssl=False, - connect_timeout=TIMEOUT) as client_ws: + async with open_websocket( + HOST, server.port, "/", use_ssl=False, connect_timeout=TIMEOUT + ) as client_ws: pass @fail_after(TIMEOUT_TEST_MAX_DURATION) async def test_client_close_timeout(nursery, autojump_clock): - ''' + """ This client times out waiting for the server to complete the closing handshake. @@ -626,68 +662,83 @@ async def test_client_close_timeout(nursery, autojump_clock): queue size is 0, and the client sends it exactly 1 message. This blocks the server's reader so it won't do the closing handshake for at least ``FORCE_TIMEOUT`` seconds. - ''' + """ + async def handler(request): server_ws = await request.accept() await trio.sleep(FORCE_TIMEOUT) # The next line should raise ConnectionClosed. await server_ws.get_message() - pytest.fail('Should not reach this line.') + pytest.fail("Should not reach this line.") server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None, - message_queue_size=0)) + partial( + serve_websocket, handler, HOST, 0, ssl_context=None, message_queue_size=0 + ) + ) with pytest.raises(DisconnectionTimeout): - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, - disconnect_timeout=TIMEOUT) as client_ws: - await client_ws.send_message('test') + async with open_websocket( + HOST, server.port, RESOURCE, use_ssl=False, disconnect_timeout=TIMEOUT + ) as client_ws: + await client_ws.send_message("test") async def test_client_connect_networking_error(): - with patch('trio_websocket._impl.connect_websocket') as \ - connect_websocket_mock: + with patch("trio_websocket._impl.connect_websocket") as connect_websocket_mock: connect_websocket_mock.side_effect = OSError() with pytest.raises(HandshakeError): - async with open_websocket(HOST, 0, '/', use_ssl=False) as client_ws: + async with open_websocket(HOST, 0, "/", use_ssl=False) as client_ws: pass @fail_after(TIMEOUT_TEST_MAX_DURATION) async def test_server_open_timeout(autojump_clock): - ''' + """ The server times out waiting for the client to complete the opening handshake. Server timeouts don't raise exceptions, because handler tasks are launched in an internal nursery and sending exceptions wouldn't be helpful. Instead, timed out tasks silently end. - ''' + """ + async def handler(request): - pytest.fail('This handler should not be called.') + pytest.fail("This handler should not be called.") async with trio.open_nursery() as nursery: - server = await nursery.start(partial(serve_websocket, handler, HOST, 0, - ssl_context=None, handler_nursery=nursery, connect_timeout=TIMEOUT)) + server = await nursery.start( + partial( + serve_websocket, + handler, + HOST, + 0, + ssl_context=None, + handler_nursery=nursery, + connect_timeout=TIMEOUT, + ) + ) old_task_count = len(nursery.child_tasks) # This stream is not a WebSocket, so it won't send a handshake: stream = await trio.open_tcp_stream(HOST, server.port) # Checkpoint so the server's handler task can spawn: await trio.sleep(0) - assert len(nursery.child_tasks) == old_task_count + 1, \ - "Server's reader task did not spawn" + assert ( + len(nursery.child_tasks) == old_task_count + 1 + ), "Server's reader task did not spawn" # Sleep long enough to trigger server's connect_timeout: await trio.sleep(FORCE_TIMEOUT) - assert len(nursery.child_tasks) == old_task_count, \ - "Server's reader task is still running" + assert ( + len(nursery.child_tasks) == old_task_count + ), "Server's reader task is still running" # Cancel the server task: nursery.cancel_scope.cancel() @fail_after(TIMEOUT_TEST_MAX_DURATION) async def test_server_close_timeout(autojump_clock): - ''' + """ The server times out waiting for the client to complete the closing handshake. @@ -698,33 +749,45 @@ async def test_server_close_timeout(autojump_clock): To prevent the client from doing the closing handshake, we make sure that its message queue size is 0 and the server sends it exactly 1 message. This blocks the client's reader and prevents it from doing the client handshake. - ''' + """ + async def handler(request): ws = await request.accept() # Send one message to block the client's reader task: - await ws.send_message('test') + await ws.send_message("test") async with trio.open_nursery() as outer: - server = await outer.start(partial(serve_websocket, handler, HOST, 0, - ssl_context=None, handler_nursery=outer, - disconnect_timeout=TIMEOUT)) + server = await outer.start( + partial( + serve_websocket, + handler, + HOST, + 0, + ssl_context=None, + handler_nursery=outer, + disconnect_timeout=TIMEOUT, + ) + ) old_task_count = len(outer.child_tasks) # Spawn client inside an inner nursery so that we can cancel it's reader # so that it won't do a closing handshake. async with trio.open_nursery() as inner: - ws = await connect_websocket(inner, HOST, server.port, RESOURCE, - use_ssl=False) + ws = await connect_websocket( + inner, HOST, server.port, RESOURCE, use_ssl=False + ) # Checkpoint so the server can spawn a handler task: await trio.sleep(0) - assert len(outer.child_tasks) == old_task_count + 1, \ - "Server's reader task did not spawn" + assert ( + len(outer.child_tasks) == old_task_count + 1 + ), "Server's reader task did not spawn" # The client waits long enough to trigger the server's disconnect # timeout: await trio.sleep(FORCE_TIMEOUT) # The server should have cancelled the handler: - assert len(outer.child_tasks) == old_task_count, \ - "Server's reader task is still running" + assert ( + len(outer.child_tasks) == old_task_count + ), "Server's reader task is still running" # Cancel the client's reader task: inner.cancel_scope.cancel() @@ -737,13 +800,14 @@ async def handler(request): server_ws = await request.accept() with pytest.raises(ConnectionClosed): await server_ws.get_message() + server = await nursery.start(serve_websocket, handler, HOST, 0, None) stream = await trio.open_tcp_stream(HOST, server.port) client_ws = await wrap_client_stream(nursery, stream, HOST, RESOURCE) async with client_ws: await stream.aclose() with pytest.raises(ConnectionClosed): - await client_ws.send_message('Hello from client!') + await client_ws.send_message("Hello from client!") async def test_server_sends_after_close(nursery): @@ -753,7 +817,7 @@ async def handler(request): server_ws = await request.accept() with pytest.raises(ConnectionClosed): while True: - await server_ws.send_message('Hello from server') + await server_ws.send_message("Hello from server") done.set() server = await nursery.start(serve_websocket, handler, HOST, 0, None) @@ -762,7 +826,7 @@ async def handler(request): async with client_ws: # pump a few messages for x in range(2): - await client_ws.send_message('Hello from client') + await client_ws.send_message("Hello from client") await stream.aclose() await done.wait() @@ -774,7 +838,8 @@ async def handler(stream): async with server_ws: await stream.aclose() with pytest.raises(ConnectionClosed): - await server_ws.send_message('Hello from client!') + await server_ws.send_message("Hello from client!") + serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST) listeners = await nursery.start(serve_fn) port = listeners[0].socket.getsockname()[1] @@ -789,69 +854,72 @@ async def handler(request): await trio.sleep(1) server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) # connection should close when server handler exits with trio.fail_after(2): - async with open_websocket( - HOST, server.port, '/', use_ssl=False) as connection: + async with open_websocket(HOST, server.port, "/", use_ssl=False) as connection: with pytest.raises(ConnectionClosed) as exc_info: await connection.get_message() exc = exc_info.value - assert exc.reason.name == 'NORMAL_CLOSURE' + assert exc.reason.name == "NORMAL_CLOSURE" @fail_after(DEFAULT_TEST_MAX_DURATION) async def test_read_messages_after_remote_close(nursery): - ''' + """ When the remote endpoint closes, the local endpoint can still read all of the messages sent prior to closing. Any attempt to read beyond that will raise ConnectionClosed. This test also exercises the configuration of the queue size. - ''' + """ server_closed = trio.Event() async def handler(request): server = await request.accept() async with server: - await server.send_message('1') - await server.send_message('2') + await server.send_message("1") + await server.send_message("2") server_closed.set() server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) # The client needs a message queue of size 2 so that it can buffer both # incoming messages without blocking the reader task. - async with open_websocket(HOST, server.port, '/', use_ssl=False, - message_queue_size=2) as client: + async with open_websocket( + HOST, server.port, "/", use_ssl=False, message_queue_size=2 + ) as client: await server_closed.wait() - assert await client.get_message() == '1' - assert await client.get_message() == '2' + assert await client.get_message() == "1" + assert await client.get_message() == "2" with pytest.raises(ConnectionClosed): await client.get_message() async def test_no_messages_after_local_close(nursery): - ''' + """ If the local endpoint initiates closing, then pending messages are discarded and any attempt to read a message will raise ConnectionClosed. - ''' + """ client_closed = trio.Event() async def handler(request): # The server sends some messages and then closes. server = await request.accept() async with server: - await server.send_message('1') - await server.send_message('2') + await server.send_message("1") + await server.send_message("2") await client_closed.wait() server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) - async with open_websocket(HOST, server.port, '/', use_ssl=False) as client: + async with open_websocket(HOST, server.port, "/", use_ssl=False) as client: pass with pytest.raises(ConnectionClosed): await client.get_message() @@ -859,28 +927,30 @@ async def handler(request): async def test_cm_exit_with_pending_messages(echo_server, autojump_clock): - ''' + """ Regression test for #74, where a context manager was not able to exit when there were pending messages in the receive queue. - ''' + """ with trio.fail_after(1): - async with open_websocket(HOST, echo_server.port, RESOURCE, - use_ssl=False) as ws: - await ws.send_message('hello') + async with open_websocket( + HOST, echo_server.port, RESOURCE, use_ssl=False + ) as ws: + await ws.send_message("hello") # allow time for the server to respond - await trio.sleep(.1) + await trio.sleep(0.1) @fail_after(DEFAULT_TEST_MAX_DURATION) async def test_max_message_size(nursery): - ''' + """ Set the client's max message size to 100 bytes. The client can send a message larger than 100 bytes, but when it receives a message larger than 100 bytes, it closes the connection with code 1009. - ''' + """ + async def handler(request): - ''' Similar to the echo_request_handler fixture except it runs in a - loop. ''' + """Similar to the echo_request_handler fixture except it runs in a + loop.""" conn = await request.accept() while True: try: @@ -890,16 +960,18 @@ async def handler(request): break server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, - max_message_size=100) as client: + async with open_websocket( + HOST, server.port, RESOURCE, use_ssl=False, max_message_size=100 + ) as client: # We can send and receive 100 bytes: - await client.send_message(b'A' * 100) + await client.send_message(b"A" * 100) msg = await client.get_message() assert len(msg) == 100 # We can send 101 bytes but cannot receive 101 bytes: - await client.send_message(b'B' * 101) + await client.send_message(b"B" * 101) with pytest.raises(ConnectionClosed): await client.get_message() assert client.closed @@ -912,19 +984,21 @@ async def test_server_close_client_disconnect_race(nursery, autojump_clock): async def handler(request: WebSocketRequest): ws = await request.accept() ws._for_testing_peer_closed_connection = trio.Event() - await ws.send_message('foo') + await ws.send_message("foo") await ws._for_testing_peer_closed_connection.wait() # with bug, this would raise ConnectionClosed from websocket internal task await trio.aclose_forcefully(ws._stream) server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) - connection = await connect_websocket(nursery, HOST, server.port, - RESOURCE, use_ssl=False) + connection = await connect_websocket( + nursery, HOST, server.port, RESOURCE, use_ssl=False + ) await connection.get_message() await connection.aclose() - await trio.sleep(.1) + await trio.sleep(0.1) async def test_remote_close_local_message_race(nursery, autojump_clock): @@ -944,15 +1018,17 @@ async def handler(request: WebSocketRequest): await ws.aclose() server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) - client = await connect_websocket(nursery, HOST, server.port, - RESOURCE, use_ssl=False) + client = await connect_websocket( + nursery, HOST, server.port, RESOURCE, use_ssl=False + ) client._for_testing_peer_closed_connection = trio.Event() - await client.send_message('foo') + await client.send_message("foo") await client._for_testing_peer_closed_connection.wait() with pytest.raises(ConnectionClosed): - await client.send_message('bar') + await client.send_message("bar") async def test_message_after_local_close_race(nursery): @@ -963,10 +1039,12 @@ async def handler(request: WebSocketRequest): await trio.sleep_forever() server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) - client = await connect_websocket(nursery, HOST, server.port, - RESOURCE, use_ssl=False) + client = await connect_websocket( + nursery, HOST, server.port, RESOURCE, use_ssl=False + ) orig_send = client._send close_sent = trio.Event() @@ -981,7 +1059,7 @@ async def _send_wrapper(event): await close_sent.wait() assert client.closed with pytest.raises(ConnectionClosed): - await client.send_message('hello') + await client.send_message("hello") @fail_after(DEFAULT_TEST_MAX_DURATION) @@ -999,9 +1077,11 @@ async def handle_connection(request): await trio.sleep_forever() server = await nursery.start( - partial(serve_websocket, handle_connection, HOST, 0, ssl_context=None)) - client = await connect_websocket(nursery, HOST, server.port, - RESOURCE, use_ssl=False) + partial(serve_websocket, handle_connection, HOST, 0, ssl_context=None) + ) + client = await connect_websocket( + nursery, HOST, server.port, RESOURCE, use_ssl=False + ) # send a CloseConnection event to server but leave client connected await client._send(CloseConnection(code=1000)) await server_stream_closed.wait() diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 9c8c90e..e774229 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -30,13 +30,13 @@ ) import wsproto.utilities -_TRIO_MULTI_ERROR = tuple(map(int, trio.__version__.split('.'))) < (0, 22, 0) +_TRIO_MULTI_ERROR = tuple(map(int, trio.__version__.split("."))) < (0, 22, 0) -CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds +CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds MESSAGE_QUEUE_SIZE = 1 -MAX_MESSAGE_SIZE = 2 ** 20 # 1 MiB -RECEIVE_BYTES = 4 * 2 ** 10 # 4 KiB -logger = logging.getLogger('trio-websocket') +MAX_MESSAGE_SIZE = 2**20 # 1 MiB +RECEIVE_BYTES = 4 * 2**10 # 4 KiB +logger = logging.getLogger("trio-websocket") def _ignore_cancel(exc): @@ -53,6 +53,7 @@ class _preserve_current_exception: https://github.com/python-trio/trio/issues/1559 https://gitter.im/python-trio/general?at=5faf2293d37a1a13d6a582cf """ + __slots__ = ("_armed",) def __init__(self): @@ -66,20 +67,33 @@ def __exit__(self, ty, value, tb): return False if _TRIO_MULTI_ERROR: - filtered_exception = trio.MultiError.filter(_ignore_cancel, value) # pylint: disable=no-member + filtered_exception = trio.MultiError.filter( + _ignore_cancel, value + ) # pylint: disable=no-member elif isinstance(value, BaseExceptionGroup): - filtered_exception = value.subgroup(lambda exc: not isinstance(exc, trio.Cancelled)) + filtered_exception = value.subgroup( + lambda exc: not isinstance(exc, trio.Cancelled) + ) else: filtered_exception = _ignore_cancel(value) return filtered_exception is None @asynccontextmanager -async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None, +async def open_websocket( + host, + port, + resource, + *, + use_ssl, + subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE, - connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT): - ''' + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, + connect_timeout=CONN_TIMEOUT, + disconnect_timeout=CONN_TIMEOUT, +): + """ Open a WebSocket client connection to a host. This async context manager connects when entering the context manager and @@ -110,15 +124,21 @@ async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None, :raises HandshakeError: for any networking error, client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`), or server rejection (:exc:`ConnectionRejected`) during handshakes. - ''' + """ async with trio.open_nursery() as new_nursery: try: with trio.fail_after(connect_timeout): - connection = await connect_websocket(new_nursery, host, port, - resource, use_ssl=use_ssl, subprotocols=subprotocols, + connection = await connect_websocket( + new_nursery, + host, + port, + resource, + use_ssl=use_ssl, + subprotocols=subprotocols, extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + ) except trio.TooSlowError: raise ConnectionTimeout from None except OSError as e: @@ -133,10 +153,19 @@ async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None, raise DisconnectionTimeout from None -async def connect_websocket(nursery, host, port, resource, *, use_ssl, - subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): - ''' +async def connect_websocket( + nursery, + host, + port, + resource, + *, + use_ssl, + subprotocols=None, + extra_headers=None, + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, +): + """ Return an open WebSocket client connection to a host. This function is used to specify a custom nursery to run connection @@ -164,7 +193,7 @@ async def connect_websocket(nursery, host, port, resource, *, use_ssl, ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). :rtype: WebSocketConnection - ''' + """ if use_ssl is True: ssl_context = ssl.create_default_context() elif use_ssl is False: @@ -172,36 +201,52 @@ async def connect_websocket(nursery, host, port, resource, *, use_ssl, elif isinstance(use_ssl, ssl.SSLContext): ssl_context = use_ssl else: - raise TypeError('`use_ssl` argument must be bool or ssl.SSLContext') - - logger.debug('Connecting to ws%s://%s:%d%s', - '' if ssl_context is None else 's', host, port, resource) + raise TypeError("`use_ssl` argument must be bool or ssl.SSLContext") + + logger.debug( + "Connecting to ws%s://%s:%d%s", + "" if ssl_context is None else "s", + host, + port, + resource, + ) if ssl_context is None: stream = await trio.open_tcp_stream(host, port) else: - stream = await trio.open_ssl_over_tcp_stream(host, port, - ssl_context=ssl_context, https_compatible=True) + stream = await trio.open_ssl_over_tcp_stream( + host, port, ssl_context=ssl_context, https_compatible=True + ) if port in (80, 443): host_header = host else: - host_header = f'{host}:{port}' - connection = WebSocketConnection(stream, + host_header = f"{host}:{port}" + connection = WebSocketConnection( + stream, WSConnection(ConnectionType.CLIENT), host=host_header, path=resource, - client_subprotocols=subprotocols, client_extra_headers=extra_headers, + client_subprotocols=subprotocols, + client_extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + ) nursery.start_soon(connection._reader_task) await connection._open_handshake.wait() return connection -def open_websocket_url(url, ssl_context=None, *, subprotocols=None, +def open_websocket_url( + url, + ssl_context=None, + *, + subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE, - connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT): - ''' + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, + connect_timeout=CONN_TIMEOUT, + disconnect_timeout=CONN_TIMEOUT, +): + """ Open a WebSocket client connection to a URL. This async context manager connects when entering the context manager and @@ -230,19 +275,33 @@ def open_websocket_url(url, ssl_context=None, *, subprotocols=None, :raises HandshakeError: for any networking error, client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`), or server rejection (:exc:`ConnectionRejected`) during handshakes. - ''' + """ host, port, resource, ssl_context = _url_to_host(url, ssl_context) - return open_websocket(host, port, resource, use_ssl=ssl_context, - subprotocols=subprotocols, extra_headers=extra_headers, + return open_websocket( + host, + port, + resource, + use_ssl=ssl_context, + subprotocols=subprotocols, + extra_headers=extra_headers, message_queue_size=message_queue_size, max_message_size=max_message_size, - connect_timeout=connect_timeout, disconnect_timeout=disconnect_timeout) + connect_timeout=connect_timeout, + disconnect_timeout=disconnect_timeout, + ) -async def connect_websocket_url(nursery, url, ssl_context=None, *, - subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): - ''' +async def connect_websocket_url( + nursery, + url, + ssl_context=None, + *, + subprotocols=None, + extra_headers=None, + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, +): + """ Return an open WebSocket client connection to a URL. This function is used to specify a custom nursery to run connection @@ -267,16 +326,23 @@ async def connect_websocket_url(nursery, url, ssl_context=None, *, ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). :rtype: WebSocketConnection - ''' + """ host, port, resource, ssl_context = _url_to_host(url, ssl_context) - return await connect_websocket(nursery, host, port, resource, - use_ssl=ssl_context, subprotocols=subprotocols, - extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + return await connect_websocket( + nursery, + host, + port, + resource, + use_ssl=ssl_context, + subprotocols=subprotocols, + extra_headers=extra_headers, + message_queue_size=message_queue_size, + max_message_size=max_message_size, + ) def _url_to_host(url, ssl_context): - ''' + """ Convert a WebSocket URL to a (host,port,resource) tuple. The returned ``ssl_context`` is either the same object that was passed in, @@ -286,15 +352,15 @@ def _url_to_host(url, ssl_context): :param str url: A WebSocket URL. :type ssl_context: ssl.SSLContext or None :returns: A tuple of ``(host, port, resource, ssl_context)``. - ''' + """ url = str(url) # For backward compat with isinstance(url, yarl.URL). parts = urllib.parse.urlsplit(url) - if parts.scheme not in ('ws', 'wss'): + if parts.scheme not in ("ws", "wss"): raise ValueError('WebSocket URL scheme must be "ws:" or "wss:"') if ssl_context is None: - ssl_context = parts.scheme == 'wss' - elif parts.scheme == 'ws': - raise ValueError('SSL context must be None for ws: URL scheme') + ssl_context = parts.scheme == "wss" + elif parts.scheme == "ws": + raise ValueError("SSL context must be None for ws: URL scheme") host = parts.hostname if parts.port is not None: port = parts.port @@ -305,16 +371,24 @@ def _url_to_host(url, ssl_context): # If the target URI's path component is empty, the client MUST # send "/" as the path within the origin-form of request-target. if not path_qs: - path_qs = '/' - if '?' in url: - path_qs += '?' + parts.query + path_qs = "/" + if "?" in url: + path_qs += "?" + parts.query return host, port, path_qs, ssl_context -async def wrap_client_stream(nursery, stream, host, resource, *, - subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): - ''' +async def wrap_client_stream( + nursery, + stream, + host, + resource, + *, + subprotocols=None, + extra_headers=None, + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, +): + """ Wrap an arbitrary stream in a WebSocket connection. This is a low-level function only needed in rare cases. In most cases, you @@ -338,21 +412,29 @@ async def wrap_client_stream(nursery, stream, host, resource, *, ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). :rtype: WebSocketConnection - ''' - connection = WebSocketConnection(stream, + """ + connection = WebSocketConnection( + stream, WSConnection(ConnectionType.CLIENT), - host=host, path=resource, - client_subprotocols=subprotocols, client_extra_headers=extra_headers, + host=host, + path=resource, + client_subprotocols=subprotocols, + client_extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + ) nursery.start_soon(connection._reader_task) await connection._open_handshake.wait() return connection -async def wrap_server_stream(nursery, stream, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): - ''' +async def wrap_server_stream( + nursery, + stream, + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, +): + """ Wrap an arbitrary stream in a server-side WebSocket. This is a low-level function only needed in rare cases. In most cases, you @@ -367,21 +449,32 @@ async def wrap_server_stream(nursery, stream, then the connection is closed with code 1009 (Message Too Big). :type stream: trio.abc.Stream :rtype: WebSocketRequest - ''' - connection = WebSocketConnection(stream, + """ + connection = WebSocketConnection( + stream, WSConnection(ConnectionType.SERVER), message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + ) nursery.start_soon(connection._reader_task) request = await connection._get_request() return request -async def serve_websocket(handler, host, port, ssl_context, *, - handler_nursery=None, message_queue_size=MESSAGE_QUEUE_SIZE, - max_message_size=MAX_MESSAGE_SIZE, connect_timeout=CONN_TIMEOUT, - disconnect_timeout=CONN_TIMEOUT, task_status=trio.TASK_STATUS_IGNORED): - ''' +async def serve_websocket( + handler, + host, + port, + ssl_context, + *, + handler_nursery=None, + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, + connect_timeout=CONN_TIMEOUT, + disconnect_timeout=CONN_TIMEOUT, + task_status=trio.TASK_STATUS_IGNORED, +): + """ Serve a WebSocket over TCP. This function supports the Trio nursery start protocol: ``server = await @@ -415,64 +508,79 @@ async def serve_websocket(handler, host, port, ssl_context, *, to finish the closing handshake before timing out. :param task_status: Part of Trio nursery start protocol. :returns: This function runs until cancelled. - ''' + """ if ssl_context is None: open_tcp_listeners = partial(trio.open_tcp_listeners, port, host=host) else: - open_tcp_listeners = partial(trio.open_ssl_over_tcp_listeners, port, - ssl_context, host=host, https_compatible=True) + open_tcp_listeners = partial( + trio.open_ssl_over_tcp_listeners, + port, + ssl_context, + host=host, + https_compatible=True, + ) listeners = await open_tcp_listeners() - server = WebSocketServer(handler, listeners, - handler_nursery=handler_nursery, message_queue_size=message_queue_size, - max_message_size=max_message_size, connect_timeout=connect_timeout, - disconnect_timeout=disconnect_timeout) + server = WebSocketServer( + handler, + listeners, + handler_nursery=handler_nursery, + message_queue_size=message_queue_size, + max_message_size=max_message_size, + connect_timeout=connect_timeout, + disconnect_timeout=disconnect_timeout, + ) await server.run(task_status=task_status) class HandshakeError(Exception): - ''' + """ There was an error during connection or disconnection with the websocket server. - ''' + """ + class ConnectionTimeout(HandshakeError): - '''There was a timeout when connecting to the websocket server.''' + """There was a timeout when connecting to the websocket server.""" + class DisconnectionTimeout(HandshakeError): - '''There was a timeout when disconnecting from the websocket server.''' + """There was a timeout when disconnecting from the websocket server.""" + class ConnectionClosed(Exception): - ''' + """ A WebSocket operation cannot be completed because the connection is closed or in the process of closing. - ''' + """ + def __init__(self, reason): - ''' + """ Constructor. :param reason: :type reason: CloseReason - ''' + """ super().__init__() self.reason = reason def __repr__(self): - ''' Return representation. ''' - return f'{self.__class__.__name__}<{self.reason}>' + """Return representation.""" + return f"{self.__class__.__name__}<{self.reason}>" class ConnectionRejected(HandshakeError): - ''' + """ A WebSocket connection could not be established because the server rejected the connection attempt. - ''' + """ + def __init__(self, status_code, headers, body): - ''' + """ Constructor. :param reason: :type reason: CloseReason - ''' + """ super().__init__() #: a 3 digit HTTP status code self.status_code = status_code @@ -482,144 +590,149 @@ def __init__(self, status_code, headers, body): self.body = body def __repr__(self): - ''' Return representation. ''' - return f'{self.__class__.__name__}' + """Return representation.""" + return f"{self.__class__.__name__}" class CloseReason: - ''' Contains information about why a WebSocket was closed. ''' + """Contains information about why a WebSocket was closed.""" + def __init__(self, code, reason): - ''' + """ Constructor. :param int code: :param Optional[str] reason: - ''' + """ self._code = code try: self._name = wsframeproto.CloseReason(code).name except ValueError: if 1000 <= code <= 2999: - self._name = 'RFC_RESERVED' + self._name = "RFC_RESERVED" elif 3000 <= code <= 3999: - self._name = 'IANA_RESERVED' + self._name = "IANA_RESERVED" elif 4000 <= code <= 4999: - self._name = 'PRIVATE_RESERVED' + self._name = "PRIVATE_RESERVED" else: - self._name = 'INVALID_CODE' + self._name = "INVALID_CODE" self._reason = reason @property def code(self): - ''' (Read-only) The numeric close code. ''' + """(Read-only) The numeric close code.""" return self._code @property def name(self): - ''' (Read-only) The human-readable close code. ''' + """(Read-only) The human-readable close code.""" return self._name @property def reason(self): - ''' (Read-only) An arbitrary reason string. ''' + """(Read-only) An arbitrary reason string.""" return self._reason def __repr__(self): - ''' Show close code, name, and reason. ''' - return f'{self.__class__.__name__}' \ - f'' + """Show close code, name, and reason.""" + return ( + f"{self.__class__.__name__}" + f"" + ) class Future: - ''' Represents a value that will be available in the future. ''' + """Represents a value that will be available in the future.""" + def __init__(self): - ''' Constructor. ''' + """Constructor.""" self._value = None self._value_event = trio.Event() def set_value(self, value): - ''' + """ Set a value, which will notify any waiters. :param value: - ''' + """ self._value = value self._value_event.set() async def wait_value(self): - ''' + """ Wait for this future to have a value, then return it. :returns: The value set by ``set_value()``. - ''' + """ await self._value_event.wait() return self._value class WebSocketRequest: - ''' + """ Represents a handshake presented by a client to a server. The server may modify the handshake or leave it as is. The server should call ``accept()`` to finish the handshake and obtain a connection object. - ''' + """ + def __init__(self, connection, event): - ''' + """ Constructor. :param WebSocketConnection connection: :type event: wsproto.events.Request - ''' + """ self._connection = connection self._event = event @property def headers(self): - ''' + """ HTTP headers represented as a list of (name, value) pairs. :rtype: list[tuple] - ''' + """ return self._event.extra_headers @property def path(self): - ''' + """ The requested URL path. :rtype: str - ''' + """ return self._event.target @property def proposed_subprotocols(self): - ''' + """ A tuple of protocols proposed by the client. :rtype: tuple[str] - ''' + """ return tuple(self._event.subprotocols) @property def local(self): - ''' + """ The connection's local endpoint. :rtype: Endpoint or str - ''' + """ return self._connection.local @property def remote(self): - ''' + """ The connection's remote endpoint. :rtype: Endpoint or str - ''' + """ return self._connection.remote async def accept(self, *, subprotocol=None, extra_headers=None): - ''' + """ Accept the request and return a connection object. :param subprotocol: The selected subprotocol for this connection. @@ -628,14 +741,14 @@ async def accept(self, *, subprotocol=None, extra_headers=None): send as HTTP headers. :type extra_headers: list[tuple[bytes,bytes]] or None :rtype: WebSocketConnection - ''' + """ if extra_headers is None: extra_headers = [] await self._connection._accept(self._event, subprotocol, extra_headers) return self._connection async def reject(self, status_code, *, extra_headers=None, body=None): - ''' + """ Reject the handshake. :param int status_code: The 3 digit HTTP status code. In order to be @@ -646,14 +759,14 @@ async def reject(self, status_code, *, extra_headers=None, body=None): :param body: If provided, this data will be sent in the response body, otherwise no response body will be sent. :type body: bytes or None - ''' + """ extra_headers = extra_headers or [] - body = body or b'' + body = body or b"" await self._connection._reject(status_code, extra_headers, body) def _get_stream_endpoint(stream, *, local): - ''' + """ Construct an endpoint from a stream. :param trio.Stream stream: @@ -661,7 +774,7 @@ def _get_stream_endpoint(stream, *, local): :returns: An endpoint instance or ``repr()`` for streams that cannot be represented as an endpoint. :rtype: Endpoint or str - ''' + """ socket, is_ssl = None, False if isinstance(stream, trio.SocketStream): socket = stream.socket @@ -677,15 +790,23 @@ def _get_stream_endpoint(stream, *, local): class WebSocketConnection(trio.abc.AsyncResource): - ''' A WebSocket connection. ''' + """A WebSocket connection.""" CONNECTION_ID = itertools.count() - def __init__(self, stream, ws_connection, *, host=None, path=None, - client_subprotocols=None, client_extra_headers=None, + def __init__( + self, + stream, + ws_connection, + *, + host=None, + path=None, + client_subprotocols=None, + client_extra_headers=None, message_queue_size=MESSAGE_QUEUE_SIZE, - max_message_size=MAX_MESSAGE_SIZE): - ''' + max_message_size=MAX_MESSAGE_SIZE, + ): + """ Constructor. Generally speaking, users are discouraged from directly instantiating a @@ -710,7 +831,7 @@ def __init__(self, stream, ws_connection, *, host=None, path=None, :param int max_message_size: The maximum message size as measured by ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). - ''' + """ # NOTE: The implementation uses _close_reason for more than an advisory # purpose. It's critical internal state, indicating when the # connection is closed or closing. @@ -724,9 +845,12 @@ def __init__(self, stream, ws_connection, *, host=None, path=None, self._max_message_size = max_message_size self._reader_running = True if ws_connection.client: - self._initial_request: Optional[Request] = Request(host=host, target=path, + self._initial_request: Optional[Request] = Request( + host=host, + target=path, subprotocols=client_subprotocols, - extra_headers=client_extra_headers or []) + extra_headers=client_extra_headers or [], + ) else: self._initial_request = None self._path = path @@ -734,9 +858,10 @@ def __init__(self, stream, ws_connection, *, host=None, path=None, self._handshake_headers = tuple() self._reject_status = 0 self._reject_headers = tuple() - self._reject_body = b'' + self._reject_body = b"" self._send_channel, self._recv_channel = trio.open_memory_channel( - message_queue_size) + message_queue_size + ) self._pings = OrderedDict() # Set when the server has received a connection request event. This # future is never set on client connections. @@ -753,77 +878,77 @@ def __init__(self, stream, ws_connection, *, host=None, path=None, @property def closed(self): - ''' + """ (Read-only) The reason why the connection was or is being closed, else ``None``. :rtype: Optional[CloseReason] - ''' + """ return self._close_reason @property def is_client(self): - ''' (Read-only) Is this a client instance? ''' + """(Read-only) Is this a client instance?""" return self._wsproto.client @property def is_server(self): - ''' (Read-only) Is this a server instance? ''' + """(Read-only) Is this a server instance?""" return not self._wsproto.client @property def local(self): - ''' + """ The local endpoint of the connection. :rtype: Endpoint or str - ''' + """ return _get_stream_endpoint(self._stream, local=True) @property def remote(self): - ''' + """ The remote endpoint of the connection. :rtype: Endpoint or str - ''' + """ return _get_stream_endpoint(self._stream, local=False) @property def path(self): - ''' + """ The requested URL path. For clients, this is set when the connection is instantiated. For servers, it is set after the handshake completes. :rtype: str - ''' + """ return self._path @property def subprotocol(self): - ''' + """ (Read-only) The negotiated subprotocol, or ``None`` if there is no subprotocol. This is only valid after the opening handshake is complete. :rtype: str or None - ''' + """ return self._subprotocol @property def handshake_headers(self): - ''' + """ The HTTP headers that were sent by the remote during the handshake, stored as 2-tuples containing key/value pairs. Header keys are always lower case. :rtype: tuple[tuple[str,str]] - ''' + """ return self._handshake_headers async def aclose(self, code=1000, reason=None): # pylint: disable=arguments-differ - ''' + """ Close the WebSocket connection. This sends a closing frame and suspends until the connection is closed. @@ -836,7 +961,7 @@ async def aclose(self, code=1000, reason=None): # pylint: disable=arguments-dif :param int code: A 4-digit code number indicating the type of closure. :param str reason: An optional string describing the closure. - ''' + """ with _preserve_current_exception(): await self._aclose(code, reason) @@ -851,8 +976,10 @@ async def _aclose(self, code, reason): # event to peer, while setting the local close reason to normal. self._close_reason = CloseReason(1000, None) await self._send(CloseConnection(code=code, reason=reason)) - elif self._wsproto.state in (ConnectionState.CONNECTING, - ConnectionState.REJECTING): + elif self._wsproto.state in ( + ConnectionState.CONNECTING, + ConnectionState.REJECTING, + ): self._close_handshake.set() # TODO: shouldn't the receive channel be closed earlier, so that # get_message() during send of the CloseConneciton event fails? @@ -867,7 +994,7 @@ async def _aclose(self, code, reason): await self._close_stream() async def get_message(self): - ''' + """ Receive the next WebSocket message. If no message is available immediately, then this function blocks until @@ -882,7 +1009,7 @@ async def get_message(self): :rtype: str or bytes :raises ConnectionClosed: if the connection is closed. - ''' + """ try: message = await self._recv_channel.receive() except (trio.ClosedResourceError, trio.EndOfChannel): @@ -890,7 +1017,7 @@ async def get_message(self): return message async def ping(self, payload=None): - ''' + """ Send WebSocket ping to remote endpoint and wait for a correspoding pong. Each in-flight ping must include a unique payload. This function sends @@ -908,39 +1035,39 @@ async def ping(self, payload=None): :raises ConnectionClosed: if connection is closed. :raises ValueError: if ``payload`` is identical to another in-flight ping. - ''' + """ if self._close_reason: raise ConnectionClosed(self._close_reason) if payload in self._pings: - raise ValueError(f'Payload value {payload} is already in flight.') + raise ValueError(f"Payload value {payload} is already in flight.") if payload is None: - payload = struct.pack('!I', random.getrandbits(32)) + payload = struct.pack("!I", random.getrandbits(32)) event = trio.Event() self._pings[payload] = event await self._send(Ping(payload=payload)) await event.wait() async def pong(self, payload=None): - ''' + """ Send an unsolicted pong. :param payload: The pong's payload. If ``None``, then no payload is sent. :type payload: bytes or None :raises ConnectionClosed: if connection is closed - ''' + """ if self._close_reason: raise ConnectionClosed(self._close_reason) await self._send(Pong(payload=payload)) async def send_message(self, message): - ''' + """ Send a WebSocket message. :param message: The message to send. :type message: str or bytes :raises ConnectionClosed: if connection is closed, or being closed - ''' + """ if self._close_reason: raise ConnectionClosed(self._close_reason) if isinstance(message, str): @@ -948,16 +1075,16 @@ async def send_message(self, message): elif isinstance(message, bytes): event = BytesMessage(data=message) else: - raise ValueError('message must be str or bytes') + raise ValueError("message must be str or bytes") await self._send(event) def __str__(self): - ''' Connection ID and type. ''' - type_ = 'client' if self.is_client else 'server' - return f'{type_}-{self._id}' + """Connection ID and type.""" + type_ = "client" if self.is_client else "server" + return f"{type_}-{self._id}" async def _accept(self, request, subprotocol, extra_headers): - ''' + """ Accept the handshake. This method is only applicable to server-side connections. @@ -967,15 +1094,16 @@ async def _accept(self, request, subprotocol, extra_headers): :type subprotocol: str or None :param list[tuple[bytes,bytes]] extra_headers: A list of 2-tuples containing key/value pairs to send as HTTP headers. - ''' + """ self._subprotocol = subprotocol self._path = request.target - await self._send(AcceptConnection(subprotocol=self._subprotocol, - extra_headers=extra_headers)) + await self._send( + AcceptConnection(subprotocol=self._subprotocol, extra_headers=extra_headers) + ) self._open_handshake.set() async def _reject(self, status_code, headers, body): - ''' + """ Reject the handshake. :param int status_code: The 3 digit HTTP status code. In order to be @@ -984,25 +1112,26 @@ async def _reject(self, status_code, headers, body): :param list[tuple[bytes,bytes]] headers: A list of 2-tuples containing key/value pairs to send as HTTP headers. :param bytes body: An optional response body. - ''' + """ if body: - headers.append(('Content-length', str(len(body)).encode('ascii'))) - reject_conn = RejectConnection(status_code=status_code, headers=headers, - has_body=bool(body)) + headers.append(("Content-length", str(len(body)).encode("ascii"))) + reject_conn = RejectConnection( + status_code=status_code, headers=headers, has_body=bool(body) + ) await self._send(reject_conn) if body: reject_body = RejectData(data=body) await self._send(reject_body) - self._close_reason = CloseReason(1006, 'Rejected WebSocket handshake') + self._close_reason = CloseReason(1006, "Rejected WebSocket handshake") self._close_handshake.set() async def _abort_web_socket(self): - ''' + """ If a stream is closed outside of this class, e.g. due to network conditions or because some other code closed our stream object, then we cannot perform the close handshake. We just need to clean up internal state. - ''' + """ close_reason = wsframeproto.CloseReason.ABNORMAL_CLOSURE if self._wsproto.state == ConnectionState.OPEN: self._wsproto.send(CloseConnection(code=close_reason.value)) @@ -1014,7 +1143,7 @@ async def _abort_web_socket(self): self._close_handshake.set() async def _close_stream(self): - ''' Close the TCP connection. ''' + """Close the TCP connection.""" self._reader_running = False try: with _preserve_current_exception(): @@ -1024,85 +1153,89 @@ async def _close_stream(self): pass async def _close_web_socket(self, code, reason=None): - ''' + """ Mark the WebSocket as closed. Close the message channel so that if any tasks are suspended in get_message(), they will wake up with a ConnectionClosed exception. - ''' + """ self._close_reason = CloseReason(code, reason) exc = ConnectionClosed(self._close_reason) - logger.debug('%s websocket closed %r', self, exc) + logger.debug("%s websocket closed %r", self, exc) await self._send_channel.aclose() async def _get_request(self): - ''' + """ Return a proposal for a WebSocket handshake. This method can only be called on server connections and it may only be called one time. :rtype: WebSocketRequest - ''' + """ if not self.is_server: - raise RuntimeError('This method is only valid for server connections.') + raise RuntimeError("This method is only valid for server connections.") if self._connection_proposal is None: - raise RuntimeError('No proposal available. Did you call this method' - ' multiple times or at the wrong time?') + raise RuntimeError( + "No proposal available. Did you call this method" + " multiple times or at the wrong time?" + ) proposal = await self._connection_proposal.wait_value() self._connection_proposal = None return proposal async def _handle_request_event(self, event): - ''' + """ Handle a connection request. This method is async even though it never awaits, because the event dispatch requires an async function. :param event: - ''' + """ proposal = WebSocketRequest(self, event) self._connection_proposal.set_value(proposal) async def _handle_accept_connection_event(self, event): - ''' + """ Handle an AcceptConnection event. :param wsproto.eventsAcceptConnection event: - ''' + """ self._subprotocol = event.subprotocol self._handshake_headers = tuple(event.extra_headers) self._open_handshake.set() async def _handle_reject_connection_event(self, event): - ''' + """ Handle a RejectConnection event. :param event: - ''' + """ self._reject_status = event.status_code self._reject_headers = tuple(event.headers) if not event.has_body: - raise ConnectionRejected(self._reject_status, self._reject_headers, - body=None) + raise ConnectionRejected( + self._reject_status, self._reject_headers, body=None + ) async def _handle_reject_data_event(self, event): - ''' + """ Handle a RejectData event. :param event: - ''' + """ self._reject_body += event.data if event.body_finished: - raise ConnectionRejected(self._reject_status, self._reject_headers, - body=self._reject_body) + raise ConnectionRejected( + self._reject_status, self._reject_headers, body=self._reject_body + ) async def _handle_close_connection_event(self, event): - ''' + """ Handle a close event. :param wsproto.events.CloseConnection event: - ''' + """ if self._wsproto.state == ConnectionState.REMOTE_CLOSING: # Set _close_reason in advance, so that send_message() will raise # ConnectionClosed during the close handshake. @@ -1119,16 +1252,16 @@ async def _handle_close_connection_event(self, event): await self._close_stream() async def _handle_message_event(self, event): - ''' + """ Handle a message event. :param event: :type event: wsproto.events.BytesMessage or wsproto.events.TextMessage - ''' + """ self._message_size += len(event.data) self._message_parts.append(event.data) if self._message_size > self._max_message_size: - err = f'Exceeded maximum message size: {self._max_message_size} bytes' + err = f"Exceeded maximum message size: {self._max_message_size} bytes" self._message_size = 0 self._message_parts = [] self._close_reason = CloseReason(1009, err) @@ -1136,8 +1269,9 @@ async def _handle_message_event(self, event): await self._recv_channel.aclose() self._reader_running = False elif event.message_finished: - msg = (b'' if isinstance(event, BytesMessage) else '') \ - .join(self._message_parts) + msg = (b"" if isinstance(event, BytesMessage) else "").join( + self._message_parts + ) self._message_size = 0 self._message_parts = [] try: @@ -1149,19 +1283,19 @@ async def _handle_message_event(self, event): pass async def _handle_ping_event(self, event): - ''' + """ Handle a PingReceived event. Wsproto queues a pong frame automatically, so this handler just needs to send it. :param wsproto.events.Ping event: - ''' - logger.debug('%s ping %r', self, event.payload) + """ + logger.debug("%s ping %r", self, event.payload) await self._send(event.response()) async def _handle_pong_event(self, event): - ''' + """ Handle a PongReceived event. When a pong is received, check if we have any ping requests waiting for @@ -1173,7 +1307,7 @@ async def _handle_pong_event(self, event): complicated if some handlers were sync. :param event: - ''' + """ payload = bytes(event.payload) try: event = self._pings[payload] @@ -1183,14 +1317,14 @@ async def _handle_pong_event(self, event): return while self._pings: key, event = self._pings.popitem(0) - skipped = ' [skipped] ' if payload != key else ' ' - logger.debug('%s pong%s%r', self, skipped, key) + skipped = " [skipped] " if payload != key else " " + logger.debug("%s pong%s%r", self, skipped, key) event.set() if payload == key: break async def _reader_task(self): - ''' A background task that reads network data and generates events. ''' + """A background task that reads network data and generates events.""" handlers = { AcceptConnection: self._handle_accept_connection_event, BytesMessage: self._handle_message_event, @@ -1216,12 +1350,12 @@ async def _reader_task(self): event_type = type(event) try: handler = handlers[event_type] - logger.debug('%s received event: %s', self, - event_type) + logger.debug("%s received event: %s", self, event_type) await handler(event) except KeyError: - logger.warning('%s received unknown event type: "%s"', self, - event_type) + logger.warning( + '%s received unknown event type: "%s"', self, event_type + ) except ConnectionClosed: self._reader_running = False break @@ -1233,27 +1367,26 @@ async def _reader_task(self): await self._abort_web_socket() break if len(data) == 0: - logger.debug('%s received zero bytes (connection closed)', - self) + logger.debug("%s received zero bytes (connection closed)", self) # If TCP closed before WebSocket, then record it as an abnormal # closure. if self._wsproto.state != ConnectionState.CLOSED: await self._abort_web_socket() break - logger.debug('%s received %d bytes', self, len(data)) + logger.debug("%s received %d bytes", self, len(data)) if self._wsproto.state != ConnectionState.CLOSED: try: self._wsproto.receive_data(data) except wsproto.utilities.RemoteProtocolError as err: - logger.debug('%s remote protocol error: %s', self, err) + logger.debug("%s remote protocol error: %s", self, err) if err.event_hint: await self._send(err.event_hint) await self._close_stream() - logger.debug('%s reader task finished', self) + logger.debug("%s reader task finished", self) async def _send(self, event): - ''' + """ Send an event to the remote WebSocket. The reader task and one or more writers might try to send messages at @@ -1261,10 +1394,10 @@ async def _send(self, event): requests to send data. :param wsproto.events.Event event: - ''' + """ data = self._wsproto.send(event) async with self._stream_lock: - logger.debug('%s sending %d bytes', self, len(data)) + logger.debug("%s sending %d bytes", self, len(data)) try: await self._stream.send_all(data) except (trio.BrokenResourceError, trio.ClosedResourceError): @@ -1273,7 +1406,8 @@ async def _send(self, event): class Endpoint: - ''' Represents a connection endpoint. ''' + """Represents a connection endpoint.""" + def __init__(self, address, port, is_ssl): #: IP address :class:`ipaddress.ip_address` self.address = ip_address(address) @@ -1284,37 +1418,43 @@ def __init__(self, address, port, is_ssl): @property def url(self): - ''' Return a URL representation of a TCP endpoint, e.g. - ``ws://127.0.0.1:80``. ''' - scheme = 'wss' if self.is_ssl else 'ws' - if (self.port == 80 and not self.is_ssl) or \ - (self.port == 443 and self.is_ssl): - port_str = '' + """Return a URL representation of a TCP endpoint, e.g. + ``ws://127.0.0.1:80``.""" + scheme = "wss" if self.is_ssl else "ws" + if (self.port == 80 and not self.is_ssl) or (self.port == 443 and self.is_ssl): + port_str = "" else: - port_str = ':' + str(self.port) + port_str = ":" + str(self.port) if self.address.version == 4: - return f'{scheme}://{self.address}{port_str}' - return f'{scheme}://[{self.address}]{port_str}' + return f"{scheme}://{self.address}{port_str}" + return f"{scheme}://[{self.address}]{port_str}" def __repr__(self): - ''' Return endpoint info as string. ''' + """Return endpoint info as string.""" return f'Endpoint(address="{self.address}", port={self.port}, is_ssl={self.is_ssl})' class WebSocketServer: - ''' + """ WebSocket server. The server class handles incoming connections on one or more ``Listener`` objects. For each incoming connection, it creates a ``WebSocketConnection`` instance and starts some background tasks, - ''' + """ - def __init__(self, handler, listeners, *, handler_nursery=None, + def __init__( + self, + handler, + listeners, + *, + handler_nursery=None, message_queue_size=MESSAGE_QUEUE_SIZE, - max_message_size=MAX_MESSAGE_SIZE, connect_timeout=CONN_TIMEOUT, - disconnect_timeout=CONN_TIMEOUT): - ''' + max_message_size=MAX_MESSAGE_SIZE, + connect_timeout=CONN_TIMEOUT, + disconnect_timeout=CONN_TIMEOUT, + ): + """ Constructor. Note that if ``host`` is ``None`` and ``port`` is zero, then you may get @@ -1333,9 +1473,9 @@ def __init__(self, handler, listeners, *, handler_nursery=None, to finish connection handshake before timing out. :param float disconnect_timeout: The number of seconds to wait for a client to finish the closing handshake before timing out. - ''' + """ if len(listeners) == 0: - raise ValueError('Listeners must contain at least one item.') + raise ValueError("Listeners must contain at least one item.") self._handler = handler self._handler_nursery = handler_nursery self._listeners = listeners @@ -1357,24 +1497,27 @@ def port(self): listener must be socket-based. """ if len(self._listeners) > 1: - raise RuntimeError('Cannot get port because this server has' - ' more than 1 listeners.') + raise RuntimeError( + "Cannot get port because this server has" " more than 1 listeners." + ) listener = self.listeners[0] try: return listener.port except AttributeError: - raise RuntimeError(f'This socket does not have a port: {repr(listener)}') from None + raise RuntimeError( + f"This socket does not have a port: {repr(listener)}" + ) from None @property def listeners(self): - ''' + """ Return a list of listener metadata. Each TCP listener is represented as an ``Endpoint`` instance. Other listener types are represented by their ``repr()``. :returns: Listeners :rtype list[Endpoint or str]: - ''' + """ listeners = [] for listener in self._listeners: socket, is_ssl = None, False @@ -1391,7 +1534,7 @@ def listeners(self): return listeners async def run(self, *, task_status=trio.TASK_STATUS_IGNORED): - ''' + """ Start serving incoming connections requests. This method supports the Trio nursery start protocol: ``server = await @@ -1400,30 +1543,34 @@ async def run(self, *, task_status=trio.TASK_STATUS_IGNORED): :param task_status: Part of the Trio nursery start protocol. :returns: This method never returns unless cancelled. - ''' + """ async with trio.open_nursery() as nursery: - serve_listeners = partial(trio.serve_listeners, - self._handle_connection, self._listeners, - handler_nursery=self._handler_nursery) + serve_listeners = partial( + trio.serve_listeners, + self._handle_connection, + self._listeners, + handler_nursery=self._handler_nursery, + ) await nursery.start(serve_listeners) - logger.debug('Listening on %s', - ','.join([str(l) for l in self.listeners])) + logger.debug("Listening on %s", ",".join([str(l) for l in self.listeners])) task_status.started(self) await trio.sleep_forever() async def _handle_connection(self, stream): - ''' + """ Handle an incoming connection by spawning a connection background task and a handler task inside a new nursery. :param stream: :type stream: trio.abc.Stream - ''' + """ async with trio.open_nursery() as nursery: - connection = WebSocketConnection(stream, + connection = WebSocketConnection( + stream, WSConnection(ConnectionType.SERVER), message_queue_size=self._message_queue_size, - max_message_size=self._max_message_size) + max_message_size=self._max_message_size, + ) nursery.start_soon(connection._reader_task) with trio.move_on_after(self._connect_timeout) as connect_scope: request = await connection._get_request() diff --git a/trio_websocket/_version.py b/trio_websocket/_version.py index d1109e1..ce724d8 100644 --- a/trio_websocket/_version.py +++ b/trio_websocket/_version.py @@ -1 +1 @@ -__version__ = '0.11.0-dev' +__version__ = "0.11.0-dev" From 35c82c367770262048de8bbe487273ceb26c7393 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Wed, 10 May 2023 14:11:47 -0700 Subject: [PATCH 02/11] Fix formatting on a string constant --- trio_websocket/_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index e774229..89783f0 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -1498,7 +1498,7 @@ def port(self): """ if len(self._listeners) > 1: raise RuntimeError( - "Cannot get port because this server has" " more than 1 listeners." + "Cannot get port because this server has more than 1 listeners." ) listener = self.listeners[0] try: From 65ef37639347b80cfa9bab798343a9f01dfcc769 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Wed, 10 May 2023 14:12:02 -0700 Subject: [PATCH 03/11] Work around spurious pylint error trio.MultiError is deprecated, and for technical reasons involving how the deprecation is implemented, this means pylint can't see it and thinks it doesn't exist. It does exist. --- trio_websocket/_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 89783f0..718af7a 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -67,7 +67,7 @@ def __exit__(self, ty, value, tb): return False if _TRIO_MULTI_ERROR: - filtered_exception = trio.MultiError.filter( + filtered_exception = trio.MultiError.filter( # pylint: disable=no-member _ignore_cancel, value ) # pylint: disable=no-member elif isinstance(value, BaseExceptionGroup): From 2970fb5d429b6e9d234bf0920ec8396707239fc7 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 26 Oct 2023 18:38:41 +0200 Subject: [PATCH 04/11] add black to 'make lint', and run it on all files --- Makefile | 1 + autobahn/client.py | 55 ++++++++++++++++++------------ autobahn/server.py | 41 ++++++++++++---------- examples/client.py | 71 ++++++++++++++++++++------------------- examples/generate-cert.py | 19 ++++++----- examples/server.py | 41 ++++++++++++---------- requirements-extras.in | 1 + tests/test_connection.py | 3 +- trio_websocket/_impl.py | 60 ++++++--------------------------- 9 files changed, 138 insertions(+), 154 deletions(-) diff --git a/Makefile b/Makefile index 2efced1..ff64f05 100644 --- a/Makefile +++ b/Makefile @@ -11,6 +11,7 @@ test: $(PYTHON) -m pytest --cov=trio_websocket --no-cov-on-fail lint: + $(PYTHON) -m black trio_websocket/ tests/ autobahn/ examples/ $(PYTHON) -m pylint trio_websocket/ tests/ autobahn/ examples/ publish: diff --git a/autobahn/client.py b/autobahn/client.py index d93be1c..1537009 100644 --- a/autobahn/client.py +++ b/autobahn/client.py @@ -1,7 +1,7 @@ -''' +""" This test client runs against the Autobahn test server. It is based on the test_client.py in wsproto. -''' +""" import argparse import json import logging @@ -11,28 +11,28 @@ from trio_websocket import open_websocket_url, ConnectionClosed -AGENT = 'trio-websocket' +AGENT = "trio-websocket" MAX_MESSAGE_SIZE = 16 * 1024 * 1024 logging.basicConfig(level=logging.INFO) -logger = logging.getLogger('client') +logger = logging.getLogger("client") async def get_case_count(url): - url = url + '/getCaseCount' + url = url + "/getCaseCount" async with open_websocket_url(url) as conn: case_count = await conn.get_message() - logger.info('Case count=%s', case_count) + logger.info("Case count=%s", case_count) return int(case_count) async def get_case_info(url, case): - url = f'{url}/getCaseInfo?case={case}' + url = f"{url}/getCaseInfo?case={case}" async with open_websocket_url(url) as conn: return json.loads(await conn.get_message()) async def run_case(url, case): - url = f'{url}/runCase?case={case}&agent={AGENT}' + url = f"{url}/runCase?case={case}&agent={AGENT}" try: async with open_websocket_url(url, max_message_size=MAX_MESSAGE_SIZE) as conn: while True: @@ -43,7 +43,7 @@ async def run_case(url, case): async def update_reports(url): - url = f'{url}/updateReports?agent={AGENT}' + url = f"{url}/updateReports?agent={AGENT}" async with open_websocket_url(url) as conn: # This command runs as soon as we connect to it, so we don't need to # send any messages. @@ -51,7 +51,7 @@ async def update_reports(url): async def run_tests(args): - logger = logging.getLogger('trio-websocket') + logger = logging.getLogger("trio-websocket") if args.debug_cases: # Don't fetch case count when debugging a subset of test cases. It adds # noise to the debug logging. @@ -62,7 +62,7 @@ async def run_tests(args): test_cases = list(range(1, case_count + 1)) exception_cases = [] for case in test_cases: - case_id = (await get_case_info(args.url, case))['id'] + case_id = (await get_case_info(args.url, case))["id"] if case_count: logger.info("Running test case %s (%d of %d)", case_id, case, case_count) else: @@ -71,28 +71,39 @@ async def run_tests(args): try: await run_case(args.url, case) except Exception: # pylint: disable=broad-exception-caught - logger.exception(' runtime exception during test case %s (%d)', case_id, case) + logger.exception( + " runtime exception during test case %s (%d)", case_id, case + ) exception_cases.append(case_id) logger.setLevel(logging.INFO) - logger.info('Updating report') + logger.info("Updating report") await update_reports(args.url) if exception_cases: - logger.error('Runtime exception in %d of %d test cases: %s', - len(exception_cases), len(test_cases), exception_cases) + logger.error( + "Runtime exception in %d of %d test cases: %s", + len(exception_cases), + len(test_cases), + exception_cases, + ) sys.exit(1) def parse_args(): - ''' Parse command line arguments. ''' - parser = argparse.ArgumentParser(description='Autobahn client for' - ' trio-websocket') - parser.add_argument('url', help='WebSocket URL for server') + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Autobahn client for" " trio-websocket" + ) + parser.add_argument("url", help="WebSocket URL for server") # TODO: accept case ID's rather than indices - parser.add_argument('debug_cases', type=int, nargs='*', help='Run' - ' individual test cases with debug logging (optional)') + parser.add_argument( + "debug_cases", + type=int, + nargs="*", + help="Run" " individual test cases with debug logging (optional)", + ) return parser.parse_args() -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() trio.run(run_tests, args) diff --git a/autobahn/server.py b/autobahn/server.py index ff23846..9941445 100644 --- a/autobahn/server.py +++ b/autobahn/server.py @@ -1,4 +1,4 @@ -''' +""" This simple WebSocket server responds to text messages by reversing each message string and sending it back. @@ -7,34 +7,35 @@ To use SSL/TLS: install the `trustme` package from PyPI and run the `generate-cert.py` script in this directory. -''' +""" import argparse import logging import trio from trio_websocket import serve_websocket, ConnectionClosed, WebSocketRequest -BIND_IP = '0.0.0.0' +BIND_IP = "0.0.0.0" BIND_PORT = 9000 MAX_MESSAGE_SIZE = 16 * 1024 * 1024 logging.basicConfig() -logger = logging.getLogger('client') +logger = logging.getLogger("client") logger.setLevel(logging.INFO) connection_count = 0 async def main(): - ''' Main entry point. ''' - logger.info('Starting websocket server on ws://%s:%d', BIND_IP, BIND_PORT) - await serve_websocket(handler, BIND_IP, BIND_PORT, ssl_context=None, - max_message_size=MAX_MESSAGE_SIZE) + """Main entry point.""" + logger.info("Starting websocket server on ws://%s:%d", BIND_IP, BIND_PORT) + await serve_websocket( + handler, BIND_IP, BIND_PORT, ssl_context=None, max_message_size=MAX_MESSAGE_SIZE + ) async def handler(request: WebSocketRequest): - ''' Reverse incoming websocket messages and send them back. ''' + """Reverse incoming websocket messages and send them back.""" global connection_count # pylint: disable=global-statement connection_count += 1 - logger.info('Connection #%d', connection_count) + logger.info("Connection #%d", connection_count) ws = await request.accept() while True: try: @@ -43,20 +44,24 @@ async def handler(request: WebSocketRequest): except ConnectionClosed: break except Exception: # pylint: disable=broad-exception-caught - logger.exception(' runtime exception handling connection #%d', connection_count) + logger.exception( + " runtime exception handling connection #%d", connection_count + ) def parse_args(): - ''' Parse command line arguments. ''' - parser = argparse.ArgumentParser(description='Autobahn server for' - ' trio-websocket') - parser.add_argument('-d', '--debug', action='store_true', - help='WebSocket URL for server') + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Autobahn server for" " trio-websocket" + ) + parser.add_argument( + "-d", "--debug", action="store_true", help="WebSocket URL for server" + ) return parser.parse_args() -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() if args.debug: - logging.getLogger('trio-websocket').setLevel(logging.DEBUG) + logging.getLogger("trio-websocket").setLevel(logging.DEBUG) trio.run(main) diff --git a/examples/client.py b/examples/client.py index 030c12b..c17a830 100644 --- a/examples/client.py +++ b/examples/client.py @@ -1,10 +1,10 @@ -''' +""" This interactive WebSocket client allows the user to send frames to a WebSocket server, including text message, ping, and close frames. To use SSL/TLS: install the `trustme` package from PyPI and run the `generate-cert.py` script in this directory. -''' +""" import argparse import logging import pathlib @@ -21,49 +21,51 @@ def commands(): - ''' Print the supported commands. ''' - print('Commands: ') - print('send -> send message') - print('ping -> send ping with payload') - print('close [] -> politely close connection with optional reason') + """Print the supported commands.""" + print("Commands: ") + print("send -> send message") + print("ping -> send ping with payload") + print("close [] -> politely close connection with optional reason") print() def parse_args(): - ''' Parse command line arguments. ''' - parser = argparse.ArgumentParser(description='Example trio-websocket client') - parser.add_argument('--heartbeat', action='store_true', - help='Create a heartbeat task') - parser.add_argument('url', help='WebSocket URL to connect to') + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Example trio-websocket client") + parser.add_argument( + "--heartbeat", action="store_true", help="Create a heartbeat task" + ) + parser.add_argument("url", help="WebSocket URL to connect to") return parser.parse_args() async def main(args): - ''' Main entry point, returning False in the case of logged error. ''' - if urllib.parse.urlsplit(args.url).scheme == 'wss': + """Main entry point, returning False in the case of logged error.""" + if urllib.parse.urlsplit(args.url).scheme == "wss": # Configure SSL context to handle our self-signed certificate. Most # clients won't need to do this. try: ssl_context = ssl.create_default_context() - ssl_context.load_verify_locations(here / 'fake.ca.pem') + ssl_context.load_verify_locations(here / "fake.ca.pem") except FileNotFoundError: - logging.error('Did not find file "fake.ca.pem". You need to run' - ' generate-cert.py') + logging.error( + 'Did not find file "fake.ca.pem". You need to run generate-cert.py' + ) return False else: ssl_context = None try: - logging.debug('Connecting to WebSocket…') + logging.debug("Connecting to WebSocket…") async with open_websocket_url(args.url, ssl_context) as conn: await handle_connection(conn, args.heartbeat) except HandshakeError as e: - logging.error('Connection attempt failed: %s', e) + logging.error("Connection attempt failed: %s", e) return False async def handle_connection(ws, use_heartbeat): - ''' Handle the connection. ''' - logging.debug('Connected!') + """Handle the connection.""" + logging.debug("Connected!") try: async with trio.open_nursery() as nursery: if use_heartbeat: @@ -71,12 +73,12 @@ async def handle_connection(ws, use_heartbeat): nursery.start_soon(get_commands, ws) nursery.start_soon(get_messages, ws) except ConnectionClosed as cc: - reason = '' if cc.reason.reason is None else f'"{cc.reason.reason}"' - print(f'Closed: {cc.reason.code}/{cc.reason.name} {reason}') + reason = "" if cc.reason.reason is None else f'"{cc.reason.reason}"' + print(f"Closed: {cc.reason.code}/{cc.reason.name} {reason}") async def heartbeat(ws, timeout, interval): - ''' + """ Send periodic pings on WebSocket ``ws``. Wait up to ``timeout`` seconds to send a ping and receive a pong. Raises @@ -92,7 +94,7 @@ async def heartbeat(ws, timeout, interval): :raises: ``ConnectionClosed`` if ``ws`` is closed. :raises: ``TooSlowError`` if the timeout expires. :returns: This function runs until cancelled. - ''' + """ while True: with trio.fail_after(timeout): await ws.ping() @@ -100,20 +102,19 @@ async def heartbeat(ws, timeout, interval): async def get_commands(ws): - ''' In a loop: get a command from the user and execute it. ''' + """In a loop: get a command from the user and execute it.""" while True: - cmd = await trio.to_thread.run_sync(input, 'cmd> ', - cancellable=True) - if cmd.startswith('ping'): - payload = cmd[5:].encode('utf8') or None + cmd = await trio.to_thread.run_sync(input, "cmd> ", cancellable=True) + if cmd.startswith("ping"): + payload = cmd[5:].encode("utf8") or None await ws.ping(payload) - elif cmd.startswith('send'): + elif cmd.startswith("send"): message = cmd[5:] or None if message is None: logging.error('The "send" command requires a message.') else: await ws.send_message(message) - elif cmd.startswith('close'): + elif cmd.startswith("close"): reason = cmd[6:] or None await ws.aclose(code=1000, reason=reason) break @@ -124,13 +125,13 @@ async def get_commands(ws): async def get_messages(ws): - ''' In a loop: get a WebSocket message and print it out. ''' + """In a loop: get a WebSocket message and print it out.""" while True: message = await ws.get_message() - print(f'message: {message}') + print(f"message: {message}") -if __name__ == '__main__': +if __name__ == "__main__": try: if not trio.run(main, parse_args()): sys.exit(1) diff --git a/examples/generate-cert.py b/examples/generate-cert.py index cc21698..4f0e6ff 100644 --- a/examples/generate-cert.py +++ b/examples/generate-cert.py @@ -3,22 +3,23 @@ import trustme + def main(): here = pathlib.Path(__file__).parent - ca_path = here / 'fake.ca.pem' - server_path = here / 'fake.server.pem' + ca_path = here / "fake.ca.pem" + server_path = here / "fake.server.pem" if ca_path.exists() and server_path.exists(): - print('The CA ceritificate and server certificate already exist.') + print("The CA ceritificate and server certificate already exist.") sys.exit(1) - print('Creating self-signed certificate for localhost/127.0.0.1:') + print("Creating self-signed certificate for localhost/127.0.0.1:") ca_cert = trustme.CA() ca_cert.cert_pem.write_to_path(ca_path) - print(f' * CA certificate: {ca_path}') - server_cert = ca_cert.issue_server_cert('localhost', '127.0.0.1') + print(f" * CA certificate: {ca_path}") + server_cert = ca_cert.issue_server_cert("localhost", "127.0.0.1") server_cert.private_key_and_cert_chain_pem.write_to_path(server_path) - print(f' * Server certificate: {server_path}') - print('Done') + print(f" * Server certificate: {server_path}") + print("Done") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/server.py b/examples/server.py index 611d89b..e77afb0 100644 --- a/examples/server.py +++ b/examples/server.py @@ -1,4 +1,4 @@ -''' +""" This simple WebSocket server responds to text messages by reversing each message string and sending it back. @@ -7,7 +7,7 @@ To use SSL/TLS: install the `trustme` package from PyPI and run the `generate-cert.py` script in this directory. -''' +""" import argparse import logging import pathlib @@ -23,33 +23,38 @@ def parse_args(): - ''' Parse command line arguments. ''' - parser = argparse.ArgumentParser(description='Example trio-websocket client') - parser.add_argument('--ssl', action='store_true', help='Use SSL') - parser.add_argument('host', help='Host interface to bind. If omitted, ' - 'then bind all interfaces.', nargs='?') - parser.add_argument('port', type=int, help='Port to bind.') + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Example trio-websocket client") + parser.add_argument("--ssl", action="store_true", help="Use SSL") + parser.add_argument( + "host", + help="Host interface to bind. If omitted, " "then bind all interfaces.", + nargs="?", + ) + parser.add_argument("port", type=int, help="Port to bind.") return parser.parse_args() async def main(args): - ''' Main entry point. ''' - logging.info('Starting websocket server…') + """Main entry point.""" + logging.info("Starting websocket server…") if args.ssl: ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) try: - ssl_context.load_cert_chain(here / 'fake.server.pem') + ssl_context.load_cert_chain(here / "fake.server.pem") except FileNotFoundError: - logging.error('Did not find file "fake.server.pem". You need to run' - ' generate-cert.py') + logging.error( + 'Did not find file "fake.server.pem". You need to run' + " generate-cert.py" + ) else: ssl_context = None - host = None if args.host == '*' else args.host + host = None if args.host == "*" else args.host await serve_websocket(handler, host, args.port, ssl_context) async def handler(request): - ''' Reverse incoming websocket messages and send them back. ''' + """Reverse incoming websocket messages and send them back.""" logging.info('Handler starting on path "%s"', request.path) ws = await request.accept() while True: @@ -57,12 +62,12 @@ async def handler(request): message = await ws.get_message() await ws.send_message(message[::-1]) except ConnectionClosed: - logging.info('Connection closed') + logging.info("Connection closed") break - logging.info('Handler exiting') + logging.info("Handler exiting") -if __name__ == '__main__': +if __name__ == "__main__": try: trio.run(main, parse_args()) except KeyboardInterrupt: diff --git a/requirements-extras.in b/requirements-extras.in index 9f4d0c5..1abe99b 100644 --- a/requirements-extras.in +++ b/requirements-extras.in @@ -1,4 +1,5 @@ # requirements for `make lint/docs/publish` +black pylint sphinx sphinxcontrib-trio diff --git a/tests/test_connection.py b/tests/test_connection.py index d608375..f101878 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1112,7 +1112,7 @@ async def test_remote_close_rude(): async def client(): client_conn = await wrap_client_stream(nursery, client_stream, HOST, RESOURCE) assert not client_conn.closed - await client_conn.send_message('Hello from client!') + await client_conn.send_message("Hello from client!") with pytest.raises(ConnectionClosed): await client_conn.get_message() @@ -1131,7 +1131,6 @@ async def server(): # pump the messages over memory_stream_pump(server_stream.send_stream, client_stream.receive_stream) - async with trio.open_nursery() as nursery: nursery.start_soon(server) nursery.start_soon(client) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 7cb9865..b577251 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -13,7 +13,6 @@ import trio import trio.abc -from exceptiongroup import BaseExceptionGroup from wsproto import ConnectionType, WSConnection from wsproto.connection import ConnectionState import wsproto.frame_protocol as wsframeproto @@ -30,6 +29,9 @@ ) import wsproto.utilities +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin + _TRIO_MULTI_ERROR = tuple(map(int, trio.__version__.split(".")[:2])) < (0, 22) CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds @@ -1344,46 +1346,6 @@ async def _reader_task(self): except ConnectionClosed: self._reader_running = False -<<<<<<< HEAD - while self._reader_running: - # Process events. - for event in self._wsproto.events(): - event_type = type(event) - try: - handler = handlers[event_type] - logger.debug("%s received event: %s", self, event_type) - await handler(event) - except KeyError: - logger.warning( - '%s received unknown event type: "%s"', self, event_type - ) - except ConnectionClosed: - self._reader_running = False - break - - # Get network data. - try: - data = await self._stream.receive_some(RECEIVE_BYTES) - except (trio.BrokenResourceError, trio.ClosedResourceError): - await self._abort_web_socket() - break - if len(data) == 0: - logger.debug("%s received zero bytes (connection closed)", self) - # If TCP closed before WebSocket, then record it as an abnormal - # closure. - if self._wsproto.state != ConnectionState.CLOSED: - await self._abort_web_socket() - break - logger.debug("%s received %d bytes", self, len(data)) - if self._wsproto.state != ConnectionState.CLOSED: - try: - self._wsproto.receive_data(data) - except wsproto.utilities.RemoteProtocolError as err: - logger.debug("%s remote protocol error: %s", self, err) - if err.event_hint: - await self._send(err.event_hint) - await self._close_stream() -======= async with self._send_channel: while self._reader_running: # Process events. @@ -1391,12 +1353,12 @@ async def _reader_task(self): event_type = type(event) try: handler = handlers[event_type] - logger.debug('%s received event: %s', self, - event_type) + logger.debug("%s received event: %s", self, event_type) await handler(event) except KeyError: - logger.warning('%s received unknown event type: "%s"', self, - event_type) + logger.warning( + '%s received unknown event type: "%s"', self, event_type + ) except ConnectionClosed: self._reader_running = False break @@ -1408,23 +1370,21 @@ async def _reader_task(self): await self._abort_web_socket() break if len(data) == 0: - logger.debug('%s received zero bytes (connection closed)', - self) + logger.debug("%s received zero bytes (connection closed)", self) # If TCP closed before WebSocket, then record it as an abnormal # closure. if self._wsproto.state != ConnectionState.CLOSED: await self._abort_web_socket() break - logger.debug('%s received %d bytes', self, len(data)) + logger.debug("%s received %d bytes", self, len(data)) if self._wsproto.state != ConnectionState.CLOSED: try: self._wsproto.receive_data(data) except wsproto.utilities.RemoteProtocolError as err: - logger.debug('%s remote protocol error: %s', self, err) + logger.debug("%s remote protocol error: %s", self, err) if err.event_hint: await self._send(err.event_hint) await self._close_stream() ->>>>>>> origin/HEAD logger.debug("%s reader task finished", self) From 903dfc35a23c64e2f5d03565f48a076e9b6f667c Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Thu, 13 Jun 2024 10:42:01 -0500 Subject: [PATCH 05/11] Run new version of black on all files --- autobahn/client.py | 1 + autobahn/server.py | 1 + examples/client.py | 1 + examples/server.py | 1 + setup.py | 52 +++++++++++++++++++------------------- tests/test_connection.py | 3 ++- trio_websocket/_version.py | 2 +- 7 files changed, 33 insertions(+), 28 deletions(-) diff --git a/autobahn/client.py b/autobahn/client.py index 1537009..dc0e890 100644 --- a/autobahn/client.py +++ b/autobahn/client.py @@ -2,6 +2,7 @@ This test client runs against the Autobahn test server. It is based on the test_client.py in wsproto. """ + import argparse import json import logging diff --git a/autobahn/server.py b/autobahn/server.py index 9941445..5263306 100644 --- a/autobahn/server.py +++ b/autobahn/server.py @@ -8,6 +8,7 @@ To use SSL/TLS: install the `trustme` package from PyPI and run the `generate-cert.py` script in this directory. """ + import argparse import logging diff --git a/examples/client.py b/examples/client.py index c17a830..08610cd 100644 --- a/examples/client.py +++ b/examples/client.py @@ -5,6 +5,7 @@ To use SSL/TLS: install the `trustme` package from PyPI and run the `generate-cert.py` script in this directory. """ + import argparse import logging import pathlib diff --git a/examples/server.py b/examples/server.py index e77afb0..0bcca25 100644 --- a/examples/server.py +++ b/examples/server.py @@ -8,6 +8,7 @@ To use SSL/TLS: install the `trustme` package from PyPI and run the `generate-cert.py` script in this directory. """ + import argparse import logging import pathlib diff --git a/setup.py b/setup.py index a5040a6..b743461 100644 --- a/setup.py +++ b/setup.py @@ -10,43 +10,43 @@ # Get description -with (here / 'README.md').open(encoding='utf-8') as f: +with (here / "README.md").open(encoding="utf-8") as f: long_description = f.read() setup( - name='trio-websocket', - version=version['__version__'], - description='WebSocket library for Trio', + name="trio-websocket", + version=version["__version__"], + description="WebSocket library for Trio", long_description=long_description, - long_description_content_type='text/markdown', - url='https://github.com/python-trio/trio-websocket', - author='Mark E. Haase', - author_email='mehaase@gmail.com', + long_description_content_type="text/markdown", + url="https://github.com/python-trio/trio-websocket", + author="Mark E. Haase", + author_email="mehaase@gmail.com", classifiers=[ # See https://pypi.org/classifiers/ - 'Development Status :: 3 - Alpha', - 'Intended Audience :: Developers', - 'Topic :: Software Development :: Libraries', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3.12', - 'Programming Language :: Python :: Implementation :: CPython', - 'Programming Language :: Python :: Implementation :: PyPy', + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Topic :: Software Development :: Libraries", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", ], python_requires=">=3.7", - keywords='websocket client server trio', - packages=find_packages(exclude=['docs', 'examples', 'tests']), + keywords="websocket client server trio", + packages=find_packages(exclude=["docs", "examples", "tests"]), install_requires=[ 'exceptiongroup; python_version<"3.11"', - 'trio>=0.11', - 'wsproto>=0.14', + "trio>=0.11", + "wsproto>=0.14", ], project_urls={ - 'Bug Reports': 'https://github.com/python-trio/trio-websocket/issues', - 'Source': 'https://github.com/python-trio/trio-websocket', + "Bug Reports": "https://github.com/python-trio/trio-websocket/issues", + "Source": "https://github.com/python-trio/trio-websocket", }, ) diff --git a/tests/test_connection.py b/tests/test_connection.py index f0e5434..a426109 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -29,6 +29,7 @@ the server to block until the client has sent the closing handshake. In other circumstances """ + from functools import partial, wraps import ssl from unittest.mock import patch @@ -422,7 +423,7 @@ async def handler(request): @fail_after(1) async def test_handshake_server_headers(nursery): async def handler(request): - headers = [('X-Test-Header', 'My test header')] + headers = [("X-Test-Header", "My test header")] server_ws = await request.accept(extra_headers=headers) server = await nursery.start(serve_websocket, handler, HOST, 0, None) diff --git a/trio_websocket/_version.py b/trio_websocket/_version.py index 2320701..5c47800 100644 --- a/trio_websocket/_version.py +++ b/trio_websocket/_version.py @@ -1 +1 @@ -__version__ = '0.12.0-dev' +__version__ = "0.12.0-dev" From d64908b15d35cd061084cb6ec5532cfb1eefc736 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Thu, 13 Jun 2024 10:42:39 -0500 Subject: [PATCH 06/11] Run black on docs config with manual spacing fixes --- docs/conf.py | 67 ++++++++++++++++++++++++++++------------------------ 1 file changed, 36 insertions(+), 31 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 649051b..88a2596 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,11 +19,12 @@ # -- Project information ----------------------------------------------------- -project = 'Trio WebSocket' -copyright = '2018, Hyperion Gray' -author = 'Hyperion Gray' +project = "Trio WebSocket" +copyright = "2018, Hyperion Gray" +author = "Hyperion Gray" from trio_websocket._version import __version__ as version + release = version @@ -37,22 +38,22 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'sphinxcontrib_trio', + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinxcontrib_trio", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -64,7 +65,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The name of the Pygments (syntax highlighting) style to use. pygments_style = None @@ -75,7 +76,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -86,7 +87,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Custom sidebar templates, must be a dictionary that maps document names # to template names. @@ -102,26 +103,22 @@ # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. -htmlhelp_basename = 'TrioWebSocketdoc' +htmlhelp_basename = "TrioWebSocketdoc" # -- Options for LaTeX output ------------------------------------------------ latex_elements = { # The paper size ('letterpaper' or 'a4paper'). - # # 'papersize': 'letterpaper', - - # The font size ('10pt', '11pt' or '12pt'). # + # The font size ('10pt', '11pt' or '12pt'). # 'pointsize': '10pt', - - # Additional stuff for the LaTeX preamble. # + # Additional stuff for the LaTeX preamble. # 'preamble': '', - - # Latex figure (float) alignment # + # Latex figure (float) alignment # 'figure_align': 'htbp', } @@ -129,8 +126,13 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'TrioWebSocket.tex', 'Trio WebSocket Documentation', - 'Hyperion Gray', 'manual'), + ( + master_doc, + "TrioWebSocket.tex", + "Trio WebSocket Documentation", + "Hyperion Gray", + "manual", + ), ] @@ -138,10 +140,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'triowebsocket', 'Trio WebSocket Documentation', - [author], 1) -] +man_pages = [(master_doc, "triowebsocket", "Trio WebSocket Documentation", [author], 1)] # -- Options for Texinfo output ---------------------------------------------- @@ -150,9 +149,15 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'TrioWebSocket', 'Trio WebSocket Documentation', - author, 'TrioWebSocket', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "TrioWebSocket", + "Trio WebSocket Documentation", + author, + "TrioWebSocket", + "One line description of project.", + "Miscellaneous", + ), ] @@ -171,10 +176,10 @@ # epub_uid = '' # A list of files that should not be packed into the epub file. -epub_exclude_files = ['search.html'] +epub_exclude_files = ["search.html"] # -- Extension configuration ------------------------------------------------- intersphinx_mapping = { - 'trio': ('https://trio.readthedocs.io/en/stable/', None), + "trio": ("https://trio.readthedocs.io/en/stable/", None), } From e7dd16b6ba28628ddb5f9fa941bdb0ff3ace3d3f Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Thu, 13 Jun 2024 10:46:34 -0500 Subject: [PATCH 07/11] Fix broken merge commit --- trio_websocket/_impl.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 753746f..7cbbd95 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -69,15 +69,10 @@ def __exit__(self, ty, value, tb): if value is None or not self._armed: return False -<<<<<<< HEAD - if _TRIO_MULTI_ERROR: - filtered_exception = trio.MultiError.filter( # pylint: disable=no-member + if _TRIO_MULTI_ERROR: # pragma: no cover + filtered_exception = trio.MultiError.filter( _ignore_cancel, value ) # pylint: disable=no-member -======= - if _TRIO_MULTI_ERROR: # pragma: no cover - filtered_exception = trio.MultiError.filter(_ignore_cancel, value) # pylint: disable=no-member ->>>>>>> origin/master elif isinstance(value, BaseExceptionGroup): filtered_exception = value.subgroup( lambda exc: not isinstance(exc, trio.Cancelled) From 8fca0593a1d1bebc304e0b78a2c55ad76511971e Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Thu, 13 Jun 2024 10:49:57 -0500 Subject: [PATCH 08/11] Add midding black dev dependency --- requirements-dev-full.txt | 2 ++ requirements-dev.in | 1 + 2 files changed, 3 insertions(+) diff --git a/requirements-dev-full.txt b/requirements-dev-full.txt index 6ad3f76..c5b35d0 100644 --- a/requirements-dev-full.txt +++ b/requirements-dev-full.txt @@ -17,6 +17,8 @@ attrs==22.2.0 # trio babel==2.12.1 # via sphinx +black==24.4.2 + # via -r requirements-dev.in bleach==6.0.0 # via readme-renderer build==0.10.0 diff --git a/requirements-dev.in b/requirements-dev.in index 922fb76..30907fd 100644 --- a/requirements-dev.in +++ b/requirements-dev.in @@ -1,5 +1,6 @@ # requirements for `make test` and dependency management attrs>=19.2.0 +black>=24.4.2 pip-tools>=5.5.0 pytest>=4.6 pytest-cov From 32e09cdba5d5f5a2d84c6902b3568c5ba1ea7431 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Fri, 2 Aug 2024 01:36:40 -0500 Subject: [PATCH 09/11] Ignore more pylint issues and re-run black --- trio_websocket/_impl.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index f2dcfc3..f95ed6a 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -72,10 +72,13 @@ def __exit__(self, ty, value, tb): return False if _TRIO_MULTI_ERROR: # pragma: no cover - filtered_exception = trio.MultiError.filter( + filtered_exception = trio.MultiError.filter( # pylint: disable=no-member _ignore_cancel, value ) # pylint: disable=no-member - elif isinstance(value, BaseExceptionGroup): # pylint: disable=possibly-used-before-assignment + elif isinstance( + value, + BaseExceptionGroup # pylint: disable=possibly-used-before-assignment + ): filtered_exception = value.subgroup( lambda exc: not isinstance(exc, trio.Cancelled) ) @@ -92,7 +95,7 @@ async def open_websocket( *, use_ssl: Union[bool, ssl.SSLContext], subprotocols: Optional[Iterable[str]] = None, - extra_headers: Optional[list[tuple[bytes,bytes]]] = None, + extra_headers: Optional[list[tuple[bytes, bytes]]] = None, message_queue_size: int = MESSAGE_QUEUE_SIZE, max_message_size: int = MAX_MESSAGE_SIZE, connect_timeout: float = CONN_TIMEOUT, @@ -861,9 +864,9 @@ def __init__( self._initial_request = None self._path = path self._subprotocol: Optional[str] = None - self._handshake_headers: tuple[tuple[str,str], ...] = tuple() + self._handshake_headers: tuple[tuple[str, str], ...] = tuple() self._reject_status = 0 - self._reject_headers: tuple[tuple[str,str], ...] = tuple() + self._reject_headers: tuple[tuple[str, str], ...] = tuple() self._reject_body = b"" self._send_channel, self._recv_channel = trio.open_memory_channel[ Union[bytes, str] From 6f6da213f62b9341bb308cc858d049ab0b45702d Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Fri, 2 Aug 2024 01:38:39 -0500 Subject: [PATCH 10/11] Re-run `black tests` --- tests/test_connection.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 526ddf7..53f54cf 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -29,6 +29,7 @@ the server to block until the client has sent the closing handshake. In other circumstances """ + from __future__ import annotations from functools import partial, wraps @@ -304,13 +305,20 @@ async def test_client_open_invalid_url(echo_server): async with open_websocket_url("http://foo.com/bar") as conn: pass + async def test_client_open_invalid_ssl(echo_server, nursery): - with pytest.raises(TypeError, match='`use_ssl` argument must be bool or ssl.SSLContext'): + with pytest.raises( + TypeError, match="`use_ssl` argument must be bool or ssl.SSLContext" + ): await connect_websocket(nursery, HOST, echo_server.port, RESOURCE, use_ssl=1) - url = f'ws://{HOST}:{echo_server.port}{RESOURCE}' - with pytest.raises(ValueError, match='^SSL context must be None for ws: URL scheme$' ): - await connect_websocket_url(nursery, url, ssl_context=ssl.SSLContext(ssl.PROTOCOL_SSLv23)) + url = f"ws://{HOST}:{echo_server.port}{RESOURCE}" + with pytest.raises( + ValueError, match="^SSL context must be None for ws: URL scheme$" + ): + await connect_websocket_url( + nursery, url, ssl_context=ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ) async def test_ascii_encoded_path_is_ok(echo_server): From e43a087dde696de4734bb261ba850dc1dd61409f Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Fri, 2 Aug 2024 01:39:36 -0500 Subject: [PATCH 11/11] Re-run full black again `black trio_websocket/ tests/ autobahn/ examples/` --- trio_websocket/_impl.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index f95ed6a..6c012f5 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -76,8 +76,7 @@ def __exit__(self, ty, value, tb): _ignore_cancel, value ) # pylint: disable=no-member elif isinstance( - value, - BaseExceptionGroup # pylint: disable=possibly-used-before-assignment + value, BaseExceptionGroup # pylint: disable=possibly-used-before-assignment ): filtered_exception = value.subgroup( lambda exc: not isinstance(exc, trio.Cancelled)