diff --git a/Makefile b/Makefile index f97c321..e4e3e68 100644 --- a/Makefile +++ b/Makefile @@ -11,6 +11,7 @@ test: $(PYTHON) -m pytest --cov=trio_websocket --cov-report=term-missing --no-cov-on-fail lint: + $(PYTHON) -m black trio_websocket/ tests/ autobahn/ examples/ $(PYTHON) -m pylint trio_websocket/ tests/ autobahn/ examples/ typecheck: diff --git a/autobahn/client.py b/autobahn/client.py index d93be1c..dc0e890 100644 --- a/autobahn/client.py +++ b/autobahn/client.py @@ -1,7 +1,8 @@ -''' +""" This test client runs against the Autobahn test server. It is based on the test_client.py in wsproto. -''' +""" + import argparse import json import logging @@ -11,28 +12,28 @@ from trio_websocket import open_websocket_url, ConnectionClosed -AGENT = 'trio-websocket' +AGENT = "trio-websocket" MAX_MESSAGE_SIZE = 16 * 1024 * 1024 logging.basicConfig(level=logging.INFO) -logger = logging.getLogger('client') +logger = logging.getLogger("client") async def get_case_count(url): - url = url + '/getCaseCount' + url = url + "/getCaseCount" async with open_websocket_url(url) as conn: case_count = await conn.get_message() - logger.info('Case count=%s', case_count) + logger.info("Case count=%s", case_count) return int(case_count) async def get_case_info(url, case): - url = f'{url}/getCaseInfo?case={case}' + url = f"{url}/getCaseInfo?case={case}" async with open_websocket_url(url) as conn: return json.loads(await conn.get_message()) async def run_case(url, case): - url = f'{url}/runCase?case={case}&agent={AGENT}' + url = f"{url}/runCase?case={case}&agent={AGENT}" try: async with open_websocket_url(url, max_message_size=MAX_MESSAGE_SIZE) as conn: while True: @@ -43,7 +44,7 @@ async def run_case(url, case): async def update_reports(url): - url = f'{url}/updateReports?agent={AGENT}' + url = f"{url}/updateReports?agent={AGENT}" async with open_websocket_url(url) as conn: # This command runs as soon as we connect to it, so we don't need to # send any messages. @@ -51,7 +52,7 @@ async def update_reports(url): async def run_tests(args): - logger = logging.getLogger('trio-websocket') + logger = logging.getLogger("trio-websocket") if args.debug_cases: # Don't fetch case count when debugging a subset of test cases. It adds # noise to the debug logging. @@ -62,7 +63,7 @@ async def run_tests(args): test_cases = list(range(1, case_count + 1)) exception_cases = [] for case in test_cases: - case_id = (await get_case_info(args.url, case))['id'] + case_id = (await get_case_info(args.url, case))["id"] if case_count: logger.info("Running test case %s (%d of %d)", case_id, case, case_count) else: @@ -71,28 +72,39 @@ async def run_tests(args): try: await run_case(args.url, case) except Exception: # pylint: disable=broad-exception-caught - logger.exception(' runtime exception during test case %s (%d)', case_id, case) + logger.exception( + " runtime exception during test case %s (%d)", case_id, case + ) exception_cases.append(case_id) logger.setLevel(logging.INFO) - logger.info('Updating report') + logger.info("Updating report") await update_reports(args.url) if exception_cases: - logger.error('Runtime exception in %d of %d test cases: %s', - len(exception_cases), len(test_cases), exception_cases) + logger.error( + "Runtime exception in %d of %d test cases: %s", + len(exception_cases), + len(test_cases), + exception_cases, + ) sys.exit(1) def parse_args(): - ''' Parse command line arguments. ''' - parser = argparse.ArgumentParser(description='Autobahn client for' - ' trio-websocket') - parser.add_argument('url', help='WebSocket URL for server') + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Autobahn client for" " trio-websocket" + ) + parser.add_argument("url", help="WebSocket URL for server") # TODO: accept case ID's rather than indices - parser.add_argument('debug_cases', type=int, nargs='*', help='Run' - ' individual test cases with debug logging (optional)') + parser.add_argument( + "debug_cases", + type=int, + nargs="*", + help="Run" " individual test cases with debug logging (optional)", + ) return parser.parse_args() -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() trio.run(run_tests, args) diff --git a/autobahn/server.py b/autobahn/server.py index ff23846..5263306 100644 --- a/autobahn/server.py +++ b/autobahn/server.py @@ -1,4 +1,4 @@ -''' +""" This simple WebSocket server responds to text messages by reversing each message string and sending it back. @@ -7,34 +7,36 @@ To use SSL/TLS: install the `trustme` package from PyPI and run the `generate-cert.py` script in this directory. -''' +""" + import argparse import logging import trio from trio_websocket import serve_websocket, ConnectionClosed, WebSocketRequest -BIND_IP = '0.0.0.0' +BIND_IP = "0.0.0.0" BIND_PORT = 9000 MAX_MESSAGE_SIZE = 16 * 1024 * 1024 logging.basicConfig() -logger = logging.getLogger('client') +logger = logging.getLogger("client") logger.setLevel(logging.INFO) connection_count = 0 async def main(): - ''' Main entry point. ''' - logger.info('Starting websocket server on ws://%s:%d', BIND_IP, BIND_PORT) - await serve_websocket(handler, BIND_IP, BIND_PORT, ssl_context=None, - max_message_size=MAX_MESSAGE_SIZE) + """Main entry point.""" + logger.info("Starting websocket server on ws://%s:%d", BIND_IP, BIND_PORT) + await serve_websocket( + handler, BIND_IP, BIND_PORT, ssl_context=None, max_message_size=MAX_MESSAGE_SIZE + ) async def handler(request: WebSocketRequest): - ''' Reverse incoming websocket messages and send them back. ''' + """Reverse incoming websocket messages and send them back.""" global connection_count # pylint: disable=global-statement connection_count += 1 - logger.info('Connection #%d', connection_count) + logger.info("Connection #%d", connection_count) ws = await request.accept() while True: try: @@ -43,20 +45,24 @@ async def handler(request: WebSocketRequest): except ConnectionClosed: break except Exception: # pylint: disable=broad-exception-caught - logger.exception(' runtime exception handling connection #%d', connection_count) + logger.exception( + " runtime exception handling connection #%d", connection_count + ) def parse_args(): - ''' Parse command line arguments. ''' - parser = argparse.ArgumentParser(description='Autobahn server for' - ' trio-websocket') - parser.add_argument('-d', '--debug', action='store_true', - help='WebSocket URL for server') + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Autobahn server for" " trio-websocket" + ) + parser.add_argument( + "-d", "--debug", action="store_true", help="WebSocket URL for server" + ) return parser.parse_args() -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() if args.debug: - logging.getLogger('trio-websocket').setLevel(logging.DEBUG) + logging.getLogger("trio-websocket").setLevel(logging.DEBUG) trio.run(main) diff --git a/docs/conf.py b/docs/conf.py index 649051b..88a2596 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,11 +19,12 @@ # -- Project information ----------------------------------------------------- -project = 'Trio WebSocket' -copyright = '2018, Hyperion Gray' -author = 'Hyperion Gray' +project = "Trio WebSocket" +copyright = "2018, Hyperion Gray" +author = "Hyperion Gray" from trio_websocket._version import __version__ as version + release = version @@ -37,22 +38,22 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'sphinxcontrib_trio', + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinxcontrib_trio", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -64,7 +65,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The name of the Pygments (syntax highlighting) style to use. pygments_style = None @@ -75,7 +76,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -86,7 +87,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Custom sidebar templates, must be a dictionary that maps document names # to template names. @@ -102,26 +103,22 @@ # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. -htmlhelp_basename = 'TrioWebSocketdoc' +htmlhelp_basename = "TrioWebSocketdoc" # -- Options for LaTeX output ------------------------------------------------ latex_elements = { # The paper size ('letterpaper' or 'a4paper'). - # # 'papersize': 'letterpaper', - - # The font size ('10pt', '11pt' or '12pt'). # + # The font size ('10pt', '11pt' or '12pt'). # 'pointsize': '10pt', - - # Additional stuff for the LaTeX preamble. # + # Additional stuff for the LaTeX preamble. # 'preamble': '', - - # Latex figure (float) alignment # + # Latex figure (float) alignment # 'figure_align': 'htbp', } @@ -129,8 +126,13 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'TrioWebSocket.tex', 'Trio WebSocket Documentation', - 'Hyperion Gray', 'manual'), + ( + master_doc, + "TrioWebSocket.tex", + "Trio WebSocket Documentation", + "Hyperion Gray", + "manual", + ), ] @@ -138,10 +140,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'triowebsocket', 'Trio WebSocket Documentation', - [author], 1) -] +man_pages = [(master_doc, "triowebsocket", "Trio WebSocket Documentation", [author], 1)] # -- Options for Texinfo output ---------------------------------------------- @@ -150,9 +149,15 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'TrioWebSocket', 'Trio WebSocket Documentation', - author, 'TrioWebSocket', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "TrioWebSocket", + "Trio WebSocket Documentation", + author, + "TrioWebSocket", + "One line description of project.", + "Miscellaneous", + ), ] @@ -171,10 +176,10 @@ # epub_uid = '' # A list of files that should not be packed into the epub file. -epub_exclude_files = ['search.html'] +epub_exclude_files = ["search.html"] # -- Extension configuration ------------------------------------------------- intersphinx_mapping = { - 'trio': ('https://trio.readthedocs.io/en/stable/', None), + "trio": ("https://trio.readthedocs.io/en/stable/", None), } diff --git a/examples/client.py b/examples/client.py index 030c12b..08610cd 100644 --- a/examples/client.py +++ b/examples/client.py @@ -1,10 +1,11 @@ -''' +""" This interactive WebSocket client allows the user to send frames to a WebSocket server, including text message, ping, and close frames. To use SSL/TLS: install the `trustme` package from PyPI and run the `generate-cert.py` script in this directory. -''' +""" + import argparse import logging import pathlib @@ -21,49 +22,51 @@ def commands(): - ''' Print the supported commands. ''' - print('Commands: ') - print('send -> send message') - print('ping -> send ping with payload') - print('close [] -> politely close connection with optional reason') + """Print the supported commands.""" + print("Commands: ") + print("send -> send message") + print("ping -> send ping with payload") + print("close [] -> politely close connection with optional reason") print() def parse_args(): - ''' Parse command line arguments. ''' - parser = argparse.ArgumentParser(description='Example trio-websocket client') - parser.add_argument('--heartbeat', action='store_true', - help='Create a heartbeat task') - parser.add_argument('url', help='WebSocket URL to connect to') + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Example trio-websocket client") + parser.add_argument( + "--heartbeat", action="store_true", help="Create a heartbeat task" + ) + parser.add_argument("url", help="WebSocket URL to connect to") return parser.parse_args() async def main(args): - ''' Main entry point, returning False in the case of logged error. ''' - if urllib.parse.urlsplit(args.url).scheme == 'wss': + """Main entry point, returning False in the case of logged error.""" + if urllib.parse.urlsplit(args.url).scheme == "wss": # Configure SSL context to handle our self-signed certificate. Most # clients won't need to do this. try: ssl_context = ssl.create_default_context() - ssl_context.load_verify_locations(here / 'fake.ca.pem') + ssl_context.load_verify_locations(here / "fake.ca.pem") except FileNotFoundError: - logging.error('Did not find file "fake.ca.pem". You need to run' - ' generate-cert.py') + logging.error( + 'Did not find file "fake.ca.pem". You need to run generate-cert.py' + ) return False else: ssl_context = None try: - logging.debug('Connecting to WebSocket…') + logging.debug("Connecting to WebSocket…") async with open_websocket_url(args.url, ssl_context) as conn: await handle_connection(conn, args.heartbeat) except HandshakeError as e: - logging.error('Connection attempt failed: %s', e) + logging.error("Connection attempt failed: %s", e) return False async def handle_connection(ws, use_heartbeat): - ''' Handle the connection. ''' - logging.debug('Connected!') + """Handle the connection.""" + logging.debug("Connected!") try: async with trio.open_nursery() as nursery: if use_heartbeat: @@ -71,12 +74,12 @@ async def handle_connection(ws, use_heartbeat): nursery.start_soon(get_commands, ws) nursery.start_soon(get_messages, ws) except ConnectionClosed as cc: - reason = '' if cc.reason.reason is None else f'"{cc.reason.reason}"' - print(f'Closed: {cc.reason.code}/{cc.reason.name} {reason}') + reason = "" if cc.reason.reason is None else f'"{cc.reason.reason}"' + print(f"Closed: {cc.reason.code}/{cc.reason.name} {reason}") async def heartbeat(ws, timeout, interval): - ''' + """ Send periodic pings on WebSocket ``ws``. Wait up to ``timeout`` seconds to send a ping and receive a pong. Raises @@ -92,7 +95,7 @@ async def heartbeat(ws, timeout, interval): :raises: ``ConnectionClosed`` if ``ws`` is closed. :raises: ``TooSlowError`` if the timeout expires. :returns: This function runs until cancelled. - ''' + """ while True: with trio.fail_after(timeout): await ws.ping() @@ -100,20 +103,19 @@ async def heartbeat(ws, timeout, interval): async def get_commands(ws): - ''' In a loop: get a command from the user and execute it. ''' + """In a loop: get a command from the user and execute it.""" while True: - cmd = await trio.to_thread.run_sync(input, 'cmd> ', - cancellable=True) - if cmd.startswith('ping'): - payload = cmd[5:].encode('utf8') or None + cmd = await trio.to_thread.run_sync(input, "cmd> ", cancellable=True) + if cmd.startswith("ping"): + payload = cmd[5:].encode("utf8") or None await ws.ping(payload) - elif cmd.startswith('send'): + elif cmd.startswith("send"): message = cmd[5:] or None if message is None: logging.error('The "send" command requires a message.') else: await ws.send_message(message) - elif cmd.startswith('close'): + elif cmd.startswith("close"): reason = cmd[6:] or None await ws.aclose(code=1000, reason=reason) break @@ -124,13 +126,13 @@ async def get_commands(ws): async def get_messages(ws): - ''' In a loop: get a WebSocket message and print it out. ''' + """In a loop: get a WebSocket message and print it out.""" while True: message = await ws.get_message() - print(f'message: {message}') + print(f"message: {message}") -if __name__ == '__main__': +if __name__ == "__main__": try: if not trio.run(main, parse_args()): sys.exit(1) diff --git a/examples/generate-cert.py b/examples/generate-cert.py index cc21698..4f0e6ff 100644 --- a/examples/generate-cert.py +++ b/examples/generate-cert.py @@ -3,22 +3,23 @@ import trustme + def main(): here = pathlib.Path(__file__).parent - ca_path = here / 'fake.ca.pem' - server_path = here / 'fake.server.pem' + ca_path = here / "fake.ca.pem" + server_path = here / "fake.server.pem" if ca_path.exists() and server_path.exists(): - print('The CA ceritificate and server certificate already exist.') + print("The CA ceritificate and server certificate already exist.") sys.exit(1) - print('Creating self-signed certificate for localhost/127.0.0.1:') + print("Creating self-signed certificate for localhost/127.0.0.1:") ca_cert = trustme.CA() ca_cert.cert_pem.write_to_path(ca_path) - print(f' * CA certificate: {ca_path}') - server_cert = ca_cert.issue_server_cert('localhost', '127.0.0.1') + print(f" * CA certificate: {ca_path}") + server_cert = ca_cert.issue_server_cert("localhost", "127.0.0.1") server_cert.private_key_and_cert_chain_pem.write_to_path(server_path) - print(f' * Server certificate: {server_path}') - print('Done') + print(f" * Server certificate: {server_path}") + print("Done") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/server.py b/examples/server.py index 611d89b..0bcca25 100644 --- a/examples/server.py +++ b/examples/server.py @@ -1,4 +1,4 @@ -''' +""" This simple WebSocket server responds to text messages by reversing each message string and sending it back. @@ -7,7 +7,8 @@ To use SSL/TLS: install the `trustme` package from PyPI and run the `generate-cert.py` script in this directory. -''' +""" + import argparse import logging import pathlib @@ -23,33 +24,38 @@ def parse_args(): - ''' Parse command line arguments. ''' - parser = argparse.ArgumentParser(description='Example trio-websocket client') - parser.add_argument('--ssl', action='store_true', help='Use SSL') - parser.add_argument('host', help='Host interface to bind. If omitted, ' - 'then bind all interfaces.', nargs='?') - parser.add_argument('port', type=int, help='Port to bind.') + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Example trio-websocket client") + parser.add_argument("--ssl", action="store_true", help="Use SSL") + parser.add_argument( + "host", + help="Host interface to bind. If omitted, " "then bind all interfaces.", + nargs="?", + ) + parser.add_argument("port", type=int, help="Port to bind.") return parser.parse_args() async def main(args): - ''' Main entry point. ''' - logging.info('Starting websocket server…') + """Main entry point.""" + logging.info("Starting websocket server…") if args.ssl: ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) try: - ssl_context.load_cert_chain(here / 'fake.server.pem') + ssl_context.load_cert_chain(here / "fake.server.pem") except FileNotFoundError: - logging.error('Did not find file "fake.server.pem". You need to run' - ' generate-cert.py') + logging.error( + 'Did not find file "fake.server.pem". You need to run' + " generate-cert.py" + ) else: ssl_context = None - host = None if args.host == '*' else args.host + host = None if args.host == "*" else args.host await serve_websocket(handler, host, args.port, ssl_context) async def handler(request): - ''' Reverse incoming websocket messages and send them back. ''' + """Reverse incoming websocket messages and send them back.""" logging.info('Handler starting on path "%s"', request.path) ws = await request.accept() while True: @@ -57,12 +63,12 @@ async def handler(request): message = await ws.get_message() await ws.send_message(message[::-1]) except ConnectionClosed: - logging.info('Connection closed') + logging.info("Connection closed") break - logging.info('Handler exiting') + logging.info("Handler exiting") -if __name__ == '__main__': +if __name__ == "__main__": try: trio.run(main, parse_args()) except KeyboardInterrupt: diff --git a/requirements-dev-full.txt b/requirements-dev-full.txt index dbc8570..c554717 100644 --- a/requirements-dev-full.txt +++ b/requirements-dev-full.txt @@ -15,6 +15,10 @@ attrs==23.2.0 # trio babel==2.15.0 # via sphinx +black==24.4.2 + # via -r requirements-dev.in +bleach==6.0.0 + # via readme-renderer backports-tarfile==1.2.0 # via jaraco-context build==1.2.1 diff --git a/requirements-dev.in b/requirements-dev.in index 922fb76..30907fd 100644 --- a/requirements-dev.in +++ b/requirements-dev.in @@ -1,5 +1,6 @@ # requirements for `make test` and dependency management attrs>=19.2.0 +black>=24.4.2 pip-tools>=5.5.0 pytest>=4.6 pytest-cov diff --git a/requirements-extras.in b/requirements-extras.in index eb2cd30..10fe1ef 100644 --- a/requirements-extras.in +++ b/requirements-extras.in @@ -1,4 +1,5 @@ # requirements for `make lint/docs/publish` +black mypy pylint sphinx diff --git a/setup.py b/setup.py index 17a21f9..96187d2 100644 --- a/setup.py +++ b/setup.py @@ -10,42 +10,42 @@ # Get description -with (here / 'README.md').open(encoding='utf-8') as f: +with (here / "README.md").open(encoding="utf-8") as f: long_description = f.read() setup( - name='trio-websocket', - version=version['__version__'], - description='WebSocket library for Trio', + name="trio-websocket", + version=version["__version__"], + description="WebSocket library for Trio", long_description=long_description, - long_description_content_type='text/markdown', - url='https://github.com/python-trio/trio-websocket', - author='Mark E. Haase', - author_email='mehaase@gmail.com', + long_description_content_type="text/markdown", + url="https://github.com/python-trio/trio-websocket", + author="Mark E. Haase", + author_email="mehaase@gmail.com", classifiers=[ # See https://pypi.org/classifiers/ - 'Development Status :: 3 - Alpha', - 'Intended Audience :: Developers', - 'Topic :: Software Development :: Libraries', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3.12', - 'Programming Language :: Python :: Implementation :: CPython', - 'Programming Language :: Python :: Implementation :: PyPy', + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Topic :: Software Development :: Libraries", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", ], python_requires=">=3.8", - keywords='websocket client server trio', - packages=find_packages(exclude=['docs', 'examples', 'tests']), + keywords="websocket client server trio", + packages=find_packages(exclude=["docs", "examples", "tests"]), install_requires=[ 'exceptiongroup; python_version<"3.11"', - 'trio>=0.11', - 'wsproto>=0.14', + "trio>=0.11", + "wsproto>=0.14", ], project_urls={ - 'Bug Reports': 'https://github.com/python-trio/trio-websocket/issues', - 'Source': 'https://github.com/python-trio/trio-websocket', + "Bug Reports": "https://github.com/python-trio/trio-websocket/issues", + "Source": "https://github.com/python-trio/trio-websocket", }, ) diff --git a/tests/test_connection.py b/tests/test_connection.py index 0837aa5..75bdda4 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,4 +1,4 @@ -''' +""" Unit tests for trio_websocket. Many of these tests involve networking, i.e. real TCP sockets. To maximize @@ -28,7 +28,8 @@ call ``ws.get_message()`` without actually sending it a message. This will cause the server to block until the client has sent the closing handshake. In other circumstances -''' +""" + from __future__ import annotations import copy @@ -74,7 +75,7 @@ WebSocketServer, WebSocketRequest, wrap_client_stream, - wrap_server_stream + wrap_server_stream, ) from trio_websocket._impl import _TRIO_EXC_GROUP_TYPE @@ -82,10 +83,10 @@ if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin -WS_PROTO_VERSION = tuple(map(int, wsproto.__version__.split('.'))) +WS_PROTO_VERSION = tuple(map(int, wsproto.__version__.split("."))) -HOST = '127.0.0.1' -RESOURCE = '/resource' +HOST = "127.0.0.1" +RESOURCE = "/resource" DEFAULT_TEST_MAX_DURATION = 1 # Timeout tests follow a general pattern: one side waits TIMEOUT seconds for an @@ -99,27 +100,25 @@ @pytest.fixture async def echo_server(nursery): - ''' A server that reads one message, sends back the same message, - then closes the connection. ''' - serve_fn = partial(serve_websocket, echo_request_handler, HOST, 0, - ssl_context=None) + """A server that reads one message, sends back the same message, + then closes the connection.""" + serve_fn = partial(serve_websocket, echo_request_handler, HOST, 0, ssl_context=None) server = await nursery.start(serve_fn) yield server @pytest.fixture async def echo_conn(echo_server): - ''' Return a client connection instance that is connected to an echo - server. ''' - async with open_websocket(HOST, echo_server.port, RESOURCE, - use_ssl=False) as conn: + """Return a client connection instance that is connected to an echo + server.""" + async with open_websocket(HOST, echo_server.port, RESOURCE, use_ssl=False) as conn: yield conn async def echo_request_handler(request): - ''' + """ Accept incoming request and then pass off to echo connection handler. - ''' + """ conn = await request.accept() try: msg = await conn.get_message() @@ -129,8 +128,9 @@ async def echo_request_handler(request): class fail_after: - ''' This decorator fails if the runtime of the decorated function (as - measured by the Trio clock) exceeds the specified value. ''' + """This decorator fails if the runtime of the decorated function (as + measured by the Trio clock) exceeds the specified value.""" + def __init__(self, seconds): self._seconds = seconds @@ -140,7 +140,10 @@ async def wrapper(*args, **kwargs): with trio.move_on_after(self._seconds) as cancel_scope: await fn(*args, **kwargs) if cancel_scope.cancelled_caught: - pytest.fail(f'Test runtime exceeded the maximum {self._seconds} seconds') + pytest.fail( + f"Test runtime exceeded the maximum {self._seconds} seconds" + ) + return wrapper @@ -174,41 +177,41 @@ async def aclose(self): async def test_endpoint_ipv4(): - e1 = Endpoint('10.105.0.2', 80, False) - assert e1.url == 'ws://10.105.0.2' + e1 = Endpoint("10.105.0.2", 80, False) + assert e1.url == "ws://10.105.0.2" assert str(e1) == 'Endpoint(address="10.105.0.2", port=80, is_ssl=False)' - e2 = Endpoint('127.0.0.1', 8000, False) - assert e2.url == 'ws://127.0.0.1:8000' + e2 = Endpoint("127.0.0.1", 8000, False) + assert e2.url == "ws://127.0.0.1:8000" assert str(e2) == 'Endpoint(address="127.0.0.1", port=8000, is_ssl=False)' - e3 = Endpoint('0.0.0.0', 443, True) - assert e3.url == 'wss://0.0.0.0' + e3 = Endpoint("0.0.0.0", 443, True) + assert e3.url == "wss://0.0.0.0" assert str(e3) == 'Endpoint(address="0.0.0.0", port=443, is_ssl=True)' async def test_listen_port_ipv6(): - e1 = Endpoint('2599:8807:6201:b7:16cf:bb9c:a6d3:51ab', 80, False) - assert e1.url == 'ws://[2599:8807:6201:b7:16cf:bb9c:a6d3:51ab]' - assert str(e1) == 'Endpoint(address="2599:8807:6201:b7:16cf:bb9c:a6d3' \ - ':51ab", port=80, is_ssl=False)' - e2 = Endpoint('::1', 8000, False) - assert e2.url == 'ws://[::1]:8000' + e1 = Endpoint("2599:8807:6201:b7:16cf:bb9c:a6d3:51ab", 80, False) + assert e1.url == "ws://[2599:8807:6201:b7:16cf:bb9c:a6d3:51ab]" + assert ( + str(e1) == 'Endpoint(address="2599:8807:6201:b7:16cf:bb9c:a6d3' + ':51ab", port=80, is_ssl=False)' + ) + e2 = Endpoint("::1", 8000, False) + assert e2.url == "ws://[::1]:8000" assert str(e2) == 'Endpoint(address="::1", port=8000, is_ssl=False)' - e3 = Endpoint('::', 443, True) - assert e3.url == 'wss://[::]' + e3 = Endpoint("::", 443, True) + assert e3.url == "wss://[::]" assert str(e3) == 'Endpoint(address="::", port=443, is_ssl=True)' async def test_server_has_listeners(nursery): - server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, - None) + server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, None) assert len(server.listeners) > 0 assert isinstance(server.listeners[0], Endpoint) async def test_serve(nursery): task = current_task() - server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, - None) + server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, None) port = server.port assert server.port != 0 # The server nursery begins with one task (server.listen). @@ -229,11 +232,11 @@ async def test_serve_ssl(nursery): cert = ca.issue_server_cert(HOST) cert.configure_cert(server_context) - server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, - server_context) + server = await nursery.start( + serve_websocket, echo_request_handler, HOST, 0, server_context + ) port = server.port - async with open_websocket(HOST, port, RESOURCE, use_ssl=client_context - ) as conn: + async with open_websocket(HOST, port, RESOURCE, use_ssl=client_context) as conn: assert not conn.closed assert conn.local.is_ssl assert conn.remote.is_ssl @@ -241,8 +244,14 @@ async def test_serve_ssl(nursery): async def test_serve_handler_nursery(nursery): async with trio.open_nursery() as handler_nursery: - serve_with_nursery = partial(serve_websocket, echo_request_handler, - HOST, 0, None, handler_nursery=handler_nursery) + serve_with_nursery = partial( + serve_websocket, + echo_request_handler, + HOST, + 0, + None, + handler_nursery=handler_nursery, + ) server = await nursery.start(serve_with_nursery) port = server.port # The server nursery begins with one task (server.listen). @@ -265,7 +274,7 @@ async def test_serve_non_tcp_listener(nursery): assert len(server.listeners) == 1 with pytest.raises(RuntimeError): server.port # pylint: disable=pointless-statement - assert server.listeners[0].startswith('MemoryListener(') + assert server.listeners[0].startswith("MemoryListener(") async def test_serve_multiple_listeners(nursery): @@ -282,82 +291,92 @@ async def test_serve_multiple_listeners(nursery): assert server.listeners[0].port != 0 # The second listener metadata is a string containing the repr() of a # MemoryListener object. - assert server.listeners[1].startswith('MemoryListener(') + assert server.listeners[1].startswith("MemoryListener(") async def test_client_open(echo_server): - async with open_websocket(HOST, echo_server.port, RESOURCE, use_ssl=False) \ - as conn: + async with open_websocket(HOST, echo_server.port, RESOURCE, use_ssl=False) as conn: assert not conn.closed assert conn.is_client - assert str(conn).startswith('client-') + assert str(conn).startswith("client-") -@pytest.mark.parametrize('path, expected_path', [ - ('/', '/'), - ('', '/'), - (RESOURCE + '/path', RESOURCE + '/path'), - (RESOURCE + '?foo=bar', RESOURCE + '?foo=bar') -]) +@pytest.mark.parametrize( + "path, expected_path", + [ + ("/", "/"), + ("", "/"), + (RESOURCE + "/path", RESOURCE + "/path"), + (RESOURCE + "?foo=bar", RESOURCE + "?foo=bar"), + ], +) async def test_client_open_url(path, expected_path, echo_server): - url = f'ws://{HOST}:{echo_server.port}{path}' + url = f"ws://{HOST}:{echo_server.port}{path}" async with open_websocket_url(url) as conn: assert conn.path == expected_path async def test_client_open_invalid_url(echo_server): with pytest.raises(ValueError): - async with open_websocket_url('http://foo.com/bar'): + async with open_websocket_url("http://foo.com/bar") as conn: pass + async def test_client_open_invalid_ssl(echo_server, nursery): - with pytest.raises(TypeError, match='`use_ssl` argument must be bool or ssl.SSLContext'): + with pytest.raises( + TypeError, match="`use_ssl` argument must be bool or ssl.SSLContext" + ): await connect_websocket(nursery, HOST, echo_server.port, RESOURCE, use_ssl=1) - url = f'ws://{HOST}:{echo_server.port}{RESOURCE}' - with pytest.raises(ValueError, match='^SSL context must be None for ws: URL scheme$' ): - await connect_websocket_url(nursery, url, ssl_context=ssl.SSLContext(ssl.PROTOCOL_SSLv23)) + url = f"ws://{HOST}:{echo_server.port}{RESOURCE}" + with pytest.raises( + ValueError, match="^SSL context must be None for ws: URL scheme$" + ): + await connect_websocket_url( + nursery, url, ssl_context=ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ) async def test_ascii_encoded_path_is_ok(echo_server): - path = '%D7%90%D7%91%D7%90?%D7%90%D7%9E%D7%90' - url = f'ws://{HOST}:{echo_server.port}{RESOURCE}/{path}' + path = "%D7%90%D7%91%D7%90?%D7%90%D7%9E%D7%90" + url = f"ws://{HOST}:{echo_server.port}{RESOURCE}/{path}" async with open_websocket_url(url) as conn: - assert conn.path == RESOURCE + '/' + path + assert conn.path == RESOURCE + "/" + path -@patch('trio_websocket._impl.open_websocket') +@patch("trio_websocket._impl.open_websocket") def test_client_open_url_options(open_websocket_mock): """open_websocket_url() must pass its options on to open_websocket()""" port = 1234 - url = f'ws://{HOST}:{port}{RESOURCE}' + url = f"ws://{HOST}:{port}{RESOURCE}" options = { - 'subprotocols': ['chat'], - 'extra_headers': [(b'X-Test-Header', b'My test header')], - 'message_queue_size': 9, - 'max_message_size': 333, - 'connect_timeout': 36, - 'disconnect_timeout': 37, + "subprotocols": ["chat"], + "extra_headers": [(b"X-Test-Header", b"My test header")], + "message_queue_size": 9, + "max_message_size": 333, + "connect_timeout": 36, + "disconnect_timeout": 37, } open_websocket_url(url, **options) _, call_args, call_kwargs = open_websocket_mock.mock_calls[0] assert call_args == (HOST, port, RESOURCE) - assert not call_kwargs.pop('use_ssl') + assert not call_kwargs.pop("use_ssl") assert call_kwargs == options - open_websocket_url(url.replace('ws:', 'wss:')) + open_websocket_url(url.replace("ws:", "wss:")) _, call_args, call_kwargs = open_websocket_mock.mock_calls[1] - assert call_kwargs['use_ssl'] + assert call_kwargs["use_ssl"] async def test_client_connect(echo_server, nursery): - conn = await connect_websocket(nursery, HOST, echo_server.port, RESOURCE, - use_ssl=False) + conn = await connect_websocket( + nursery, HOST, echo_server.port, RESOURCE, use_ssl=False + ) assert not conn.closed async def test_client_connect_url(echo_server, nursery): - url = f'ws://{HOST}:{echo_server.port}{RESOURCE}' + url = f"ws://{HOST}:{echo_server.port}{RESOURCE}" conn = await connect_websocket_url(nursery, url) assert not conn.closed @@ -386,20 +405,21 @@ async def handler(request): await request.accept() server = await nursery.start(serve_websocket, handler, HOST, 0, None) - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): + async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False) as client_ws: pass async def test_handshake_subprotocol(nursery): async def handler(request): - assert request.proposed_subprotocols == ('chat', 'file') - server_ws = await request.accept(subprotocol='chat') - assert server_ws.subprotocol == 'chat' + assert request.proposed_subprotocols == ("chat", "file") + server_ws = await request.accept(subprotocol="chat") + assert server_ws.subprotocol == "chat" server = await nursery.start(serve_websocket, handler, HOST, 0, None) - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, - subprotocols=('chat', 'file')) as client_ws: - assert client_ws.subprotocol == 'chat' + async with open_websocket( + HOST, server.port, RESOURCE, use_ssl=False, subprotocols=("chat", "file") + ) as client_ws: + assert client_ws.subprotocol == "chat" async def test_handshake_path(nursery): @@ -409,8 +429,12 @@ async def handler(request): assert server_ws.path == RESOURCE server = await nursery.start(serve_websocket, handler, HOST, 0, None) - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, - ) as client_ws: + async with open_websocket( + HOST, + server.port, + RESOURCE, + use_ssl=False, + ) as client_ws: assert client_ws.path == RESOURCE @@ -418,32 +442,30 @@ async def handler(request): async def test_handshake_client_headers(nursery): async def handler(request): headers = dict(request.headers) - assert b'x-test-header' in headers - assert headers[b'x-test-header'] == b'My test header' + assert b"x-test-header" in headers + assert headers[b"x-test-header"] == b"My test header" server_ws = await request.accept() - await server_ws.send_message('test') + await server_ws.send_message("test") server = await nursery.start(serve_websocket, handler, HOST, 0, None) - headers = [(b'X-Test-Header', b'My test header')] - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, - extra_headers=headers) as client_ws: + headers = [(b"X-Test-Header", b"My test header")] + async with open_websocket( + HOST, server.port, RESOURCE, use_ssl=False, extra_headers=headers + ) as client_ws: await client_ws.get_message() @fail_after(1) async def test_handshake_server_headers(nursery): async def handler(request): - headers = [('X-Test-Header', 'My test header')] - await request.accept(extra_headers=headers) + headers = [("X-Test-Header", "My test header")] + server_ws = await request.accept(extra_headers=headers) server = await nursery.start(serve_websocket, handler, HOST, 0, None) - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False - ) as client_ws: + async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False) as client_ws: header_key, header_value = client_ws.handshake_headers[0] - assert header_key == b'x-test-header' - assert header_value == b'My test header' - - + assert header_key == b"x-test-header" + assert header_value == b"My test header" @fail_after(5) @@ -452,9 +474,14 @@ async def test_open_websocket_internal_ki(nursery, monkeypatch, autojump_clock): user code also raises exception. Make sure that KI is delivered, and the user exception is in the __cause__ exceptiongroup """ + async def ki_raising_ping_handler(*args, **kwargs) -> None: raise KeyboardInterrupt - monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", ki_raising_ping_handler) + + monkeypatch.setattr( + WebSocketConnection, "_handle_ping_event", ki_raising_ping_handler + ) + async def handler(request): server_ws = await request.accept() await server_ws.ping(b"a") @@ -470,6 +497,7 @@ async def handler(request): assert isinstance(e_cause, _TRIO_EXC_GROUP_TYPE) assert any(isinstance(e, trio.TooSlowError) for e in e_cause.exceptions) + @fail_after(5) async def test_open_websocket_internal_exc(nursery, monkeypatch, autojump_clock): """_reader_task._handle_ping_event triggers ValueError. @@ -480,10 +508,12 @@ async def test_open_websocket_internal_exc(nursery, monkeypatch, autojump_clock) internal_error.__context__ = TypeError() user_error = NameError() user_error_context = KeyError() + async def raising_ping_event(*args, **kwargs) -> None: raise internal_error monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", raising_ping_event) + async def handler(request): server_ws = await request.accept() await server_ws.ping(b"a") @@ -497,26 +527,31 @@ async def handler(request): assert exc_info.value is user_error e_context = exc_info.value.__context__ - assert isinstance(e_context, BaseExceptionGroup) # pylint: disable=possibly-used-before-assignment + assert isinstance( + e_context, + BaseExceptionGroup, # pylint: disable=possibly-used-before-assignment + ) assert internal_error in e_context.exceptions assert user_error_context in e_context.exceptions + @fail_after(5) async def test_open_websocket_cancellations(nursery, monkeypatch, autojump_clock): """Both user code and _reader_task raise Cancellation. Check that open_websocket reraises the one from user code for traceback reasons. """ - async def sleeping_ping_event(*args, **kwargs) -> None: await trio.sleep_forever() # We monkeypatch WebSocketConnection._handle_ping_event to ensure it will actually # raise Cancelled upon being cancelled. For some reason it doesn't otherwise. monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", sleeping_ping_event) + async def handler(request): server_ws = await request.accept() await server_ws.ping(b"a") + user_cancelled = None user_cancelled_cause = None user_cancelled_context = None @@ -537,25 +572,32 @@ async def handler(request): assert exc_info.value.__cause__ is user_cancelled_cause assert exc_info.value.__context__ is user_cancelled_context + def _trio_default_non_strict_exception_groups() -> bool: - assert re.match(r'^0\.\d\d\.', trio.__version__), "unexpected trio versioning scheme" + assert re.match( + r"^0\.\d\d\.", trio.__version__ + ), "unexpected trio versioning scheme" return int(trio.__version__[2:4]) < 25 + @fail_after(1) async def test_handshake_exception_before_accept() -> None: - ''' In #107, a request handler that throws an exception before finishing the + """In #107, a request handler that throws an exception before finishing the handshake causes the task to hang. The proper behavior is to raise an - exception to the nursery as soon as possible. ''' + exception to the nursery as soon as possible.""" + async def handler(request): raise ValueError() # pylint fails to resolve that BaseExceptionGroup will always be available - with pytest.raises((BaseExceptionGroup, ValueError)) as exc: # pylint: disable=possibly-used-before-assignment + with pytest.raises( + (BaseExceptionGroup, ValueError) + ) as exc: # pylint: disable=possibly-used-before-assignment async with trio.open_nursery() as nursery: - server = await nursery.start(serve_websocket, handler, HOST, 0, - None) - async with open_websocket(HOST, server.port, RESOURCE, - use_ssl=False): + server = await nursery.start(serve_websocket, handler, HOST, 0, None) + async with open_websocket( + HOST, server.port, RESOURCE, use_ssl=False + ) as client_ws: pass if _trio_default_non_strict_exception_groups(): @@ -566,15 +608,15 @@ async def handler(request): # 2. WebSocketServer.run # 3. trio.serve_listeners # 4. WebSocketServer._handle_connection - assert RaisesGroup( - RaisesGroup( - RaisesGroup( - RaisesGroup(ValueError)))).matches(exc.value) + assert RaisesGroup(RaisesGroup(RaisesGroup(RaisesGroup(ValueError)))).matches( + exc.value + ) async def test_user_exception_cause(nursery) -> None: async def handler(request): await request.accept() + server = await nursery.start(serve_websocket, handler, HOST, 0, None) e_context = TypeError("foo") e_primary = ValueError("bar") @@ -590,62 +632,75 @@ async def handler(request): assert e.__cause__ is e_cause assert e.__context__ is e_context + @fail_after(1) async def test_reject_handshake(nursery): async def handler(request): - body = b'My body' + body = b"My body" await request.reject(400, body=body) server = await nursery.start(serve_websocket, handler, HOST, 0, None) with pytest.raises(ConnectionRejected) as exc_info: - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): + async with open_websocket( + HOST, + server.port, + RESOURCE, + use_ssl=False, + ) as client_ws: pass exc = exc_info.value - assert exc.body == b'My body' + assert exc.body == b"My body" @fail_after(1) async def test_reject_handshake_invalid_info_status(nursery): - ''' + """ An informational status code that is not 101 should cause the client to reject the handshake. Since it is an informational response, there will not be a response body, so this test exercises a different code path. - ''' + """ + async def handler(stream): - await stream.send_all(b'HTTP/1.1 100 CONTINUE\r\n\r\n') + await stream.send_all(b"HTTP/1.1 100 CONTINUE\r\n\r\n") await stream.receive_some(max_bytes=1024) + serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST) listeners = await nursery.start(serve_fn) port = listeners[0].socket.getsockname()[1] with pytest.raises(ConnectionRejected) as exc_info: - async with open_websocket(HOST, port, RESOURCE, use_ssl=False): + async with open_websocket( + HOST, + port, + RESOURCE, + use_ssl=False, + ) as client_ws: pass exc = exc_info.value assert exc.status_code == 100 - assert repr(exc) == 'ConnectionRejected' + assert repr(exc) == "ConnectionRejected" assert exc.body is None async def test_handshake_protocol_error(echo_server): - ''' + """ If a client connects to a trio-websocket server and tries to speak HTTP instead of WebSocket, the server should reject the connection. (If the server does not catch the protocol exception, it will raise an exception up to the nursery level and fail the test.) - ''' + """ client_stream = await trio.open_tcp_stream(HOST, echo_server.port) async with client_stream: - await client_stream.send_all(b'GET / HTTP/1.1\r\n\r\n') + await client_stream.send_all(b"GET / HTTP/1.1\r\n\r\n") response = await client_stream.receive_some(1024) - assert response.startswith(b'HTTP/1.1 400') + assert response.startswith(b"HTTP/1.1 400") async def test_client_send_and_receive(echo_conn): async with echo_conn: - await echo_conn.send_message('This is a test message.') + await echo_conn.send_message("This is a test message.") received_msg = await echo_conn.get_message() - assert received_msg == 'This is a test message.' + assert received_msg == "This is a test message." async def test_client_send_invalid_type(echo_conn): @@ -656,17 +711,19 @@ async def test_client_send_invalid_type(echo_conn): async def test_client_ping(echo_conn): async with echo_conn: - await echo_conn.ping(b'A') + await echo_conn.ping(b"A") with pytest.raises(ConnectionClosed): - await echo_conn.ping(b'B') + await echo_conn.ping(b"B") async def test_client_ping_two_payloads(echo_conn): pong_count = 0 + async def ping_and_count(): nonlocal pong_count await echo_conn.ping() pong_count += 1 + async with echo_conn: async with trio.open_nursery() as nursery: nursery.start_soon(ping_and_count) @@ -679,12 +736,14 @@ async def test_client_ping_same_payload(echo_conn): # same time. One of them should succeed and the other should get an # exception. exc_count = 0 + async def ping_and_catch(): nonlocal exc_count try: - await echo_conn.ping(b'A') + await echo_conn.ping(b"A") except ValueError: exc_count += 1 + async with echo_conn: async with trio.open_nursery() as nursery: nursery.start_soon(ping_and_catch) @@ -694,9 +753,9 @@ async def ping_and_catch(): async def test_client_pong(echo_conn): async with echo_conn: - await echo_conn.pong(b'A') + await echo_conn.pong(b"A") with pytest.raises(ConnectionClosed): - await echo_conn.pong(b'B') + await echo_conn.pong(b"B") async def test_client_default_close(echo_conn): @@ -704,16 +763,18 @@ async def test_client_default_close(echo_conn): assert not echo_conn.closed assert echo_conn.closed.code == 1000 assert echo_conn.closed.reason is None - assert repr(echo_conn.closed) == 'CloseReason' + assert ( + repr(echo_conn.closed) == "CloseReason" + ) async def test_client_nondefault_close(echo_conn): async with echo_conn: assert not echo_conn.closed - await echo_conn.aclose(code=1001, reason='test reason') + await echo_conn.aclose(code=1001, reason="test reason") assert echo_conn.closed.code == 1001 - assert echo_conn.closed.reason == 'test reason' + assert echo_conn.closed.reason == "test reason" async def test_wrap_client_stream(nursery): @@ -724,10 +785,10 @@ async def test_wrap_client_stream(nursery): conn = await wrap_client_stream(nursery, stream, HOST, RESOURCE) async with conn: assert not conn.closed - await conn.send_message('Hello from client!') + await conn.send_message("Hello from client!") msg = await conn.get_message() - assert msg == 'Hello from client!' - assert conn.local.startswith('StapledStream(') + assert msg == "Hello from client!" + assert conn.local.startswith("StapledStream(") assert conn.closed @@ -738,38 +799,42 @@ async def handler(stream): async with server_ws: assert not server_ws.closed msg = await server_ws.get_message() - assert msg == 'Hello from client!' + assert msg == "Hello from client!" assert server_ws.closed + serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST) listeners = await nursery.start(serve_fn) port = listeners[0].socket.getsockname()[1] async with open_websocket(HOST, port, RESOURCE, use_ssl=False) as client: - await client.send_message('Hello from client!') + await client.send_message("Hello from client!") @fail_after(TIMEOUT_TEST_MAX_DURATION) async def test_client_open_timeout(nursery, autojump_clock): - ''' + """ The client times out waiting for the server to complete the opening handshake. - ''' + """ + async def handler(request): await trio.sleep(FORCE_TIMEOUT) - await request.accept() - pytest.fail('Should not reach this line.') + server_ws = await request.accept() + pytest.fail("Should not reach this line.") server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) with pytest.raises(ConnectionTimeout): - async with open_websocket(HOST, server.port, '/', use_ssl=False, - connect_timeout=TIMEOUT): + async with open_websocket( + HOST, server.port, "/", use_ssl=False, connect_timeout=TIMEOUT + ) as client_ws: pass @fail_after(TIMEOUT_TEST_MAX_DURATION) async def test_client_close_timeout(nursery, autojump_clock): - ''' + """ This client times out waiting for the server to complete the closing handshake. @@ -777,68 +842,83 @@ async def test_client_close_timeout(nursery, autojump_clock): queue size is 0, and the client sends it exactly 1 message. This blocks the server's reader so it won't do the closing handshake for at least ``FORCE_TIMEOUT`` seconds. - ''' + """ + async def handler(request): server_ws = await request.accept() await trio.sleep(FORCE_TIMEOUT) # The next line should raise ConnectionClosed. await server_ws.get_message() - pytest.fail('Should not reach this line.') + pytest.fail("Should not reach this line.") server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None, - message_queue_size=0)) + partial( + serve_websocket, handler, HOST, 0, ssl_context=None, message_queue_size=0 + ) + ) with pytest.raises(DisconnectionTimeout): - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, - disconnect_timeout=TIMEOUT) as client_ws: - await client_ws.send_message('test') + async with open_websocket( + HOST, server.port, RESOURCE, use_ssl=False, disconnect_timeout=TIMEOUT + ) as client_ws: + await client_ws.send_message("test") async def test_client_connect_networking_error(): - with patch('trio_websocket._impl.connect_websocket') as \ - connect_websocket_mock: + with patch("trio_websocket._impl.connect_websocket") as connect_websocket_mock: connect_websocket_mock.side_effect = OSError() with pytest.raises(HandshakeError): - async with open_websocket(HOST, 0, '/', use_ssl=False): + async with open_websocket(HOST, 0, "/", use_ssl=False) as client_ws: pass @fail_after(TIMEOUT_TEST_MAX_DURATION) async def test_server_open_timeout(autojump_clock): - ''' + """ The server times out waiting for the client to complete the opening handshake. Server timeouts don't raise exceptions, because handler tasks are launched in an internal nursery and sending exceptions wouldn't be helpful. Instead, timed out tasks silently end. - ''' + """ + async def handler(request): - pytest.fail('This handler should not be called.') + pytest.fail("This handler should not be called.") async with trio.open_nursery() as nursery: - server = await nursery.start(partial(serve_websocket, handler, HOST, 0, - ssl_context=None, handler_nursery=nursery, connect_timeout=TIMEOUT)) + server = await nursery.start( + partial( + serve_websocket, + handler, + HOST, + 0, + ssl_context=None, + handler_nursery=nursery, + connect_timeout=TIMEOUT, + ) + ) old_task_count = len(nursery.child_tasks) # This stream is not a WebSocket, so it won't send a handshake: await trio.open_tcp_stream(HOST, server.port) # Checkpoint so the server's handler task can spawn: await trio.sleep(0) - assert len(nursery.child_tasks) == old_task_count + 1, \ - "Server's reader task did not spawn" + assert ( + len(nursery.child_tasks) == old_task_count + 1 + ), "Server's reader task did not spawn" # Sleep long enough to trigger server's connect_timeout: await trio.sleep(FORCE_TIMEOUT) - assert len(nursery.child_tasks) == old_task_count, \ - "Server's reader task is still running" + assert ( + len(nursery.child_tasks) == old_task_count + ), "Server's reader task is still running" # Cancel the server task: nursery.cancel_scope.cancel() @fail_after(TIMEOUT_TEST_MAX_DURATION) async def test_server_close_timeout(autojump_clock): - ''' + """ The server times out waiting for the client to complete the closing handshake. @@ -849,33 +929,45 @@ async def test_server_close_timeout(autojump_clock): To prevent the client from doing the closing handshake, we make sure that its message queue size is 0 and the server sends it exactly 1 message. This blocks the client's reader and prevents it from doing the client handshake. - ''' + """ + async def handler(request): ws = await request.accept() # Send one message to block the client's reader task: - await ws.send_message('test') + await ws.send_message("test") async with trio.open_nursery() as outer: - server = await outer.start(partial(serve_websocket, handler, HOST, 0, - ssl_context=None, handler_nursery=outer, - disconnect_timeout=TIMEOUT)) + server = await outer.start( + partial( + serve_websocket, + handler, + HOST, + 0, + ssl_context=None, + handler_nursery=outer, + disconnect_timeout=TIMEOUT, + ) + ) old_task_count = len(outer.child_tasks) # Spawn client inside an inner nursery so that we can cancel it's reader # so that it won't do a closing handshake. async with trio.open_nursery() as inner: - await connect_websocket(inner, HOST, server.port, RESOURCE, - use_ssl=False) + ws = await connect_websocket( + inner, HOST, server.port, RESOURCE, use_ssl=False + ) # Checkpoint so the server can spawn a handler task: await trio.sleep(0) - assert len(outer.child_tasks) == old_task_count + 1, \ - "Server's reader task did not spawn" + assert ( + len(outer.child_tasks) == old_task_count + 1 + ), "Server's reader task did not spawn" # The client waits long enough to trigger the server's disconnect # timeout: await trio.sleep(FORCE_TIMEOUT) # The server should have cancelled the handler: - assert len(outer.child_tasks) == old_task_count, \ - "Server's reader task is still running" + assert ( + len(outer.child_tasks) == old_task_count + ), "Server's reader task is still running" # Cancel the client's reader task: inner.cancel_scope.cancel() @@ -888,13 +980,14 @@ async def handler(request): server_ws = await request.accept() with pytest.raises(ConnectionClosed): await server_ws.get_message() + server = await nursery.start(serve_websocket, handler, HOST, 0, None) stream = await trio.open_tcp_stream(HOST, server.port) client_ws = await wrap_client_stream(nursery, stream, HOST, RESOURCE) async with client_ws: await stream.aclose() with pytest.raises(ConnectionClosed): - await client_ws.send_message('Hello from client!') + await client_ws.send_message("Hello from client!") async def test_server_sends_after_close(nursery): @@ -904,7 +997,7 @@ async def handler(request): server_ws = await request.accept() with pytest.raises(ConnectionClosed): while True: - await server_ws.send_message('Hello from server') + await server_ws.send_message("Hello from server") done.set() server = await nursery.start(serve_websocket, handler, HOST, 0, None) @@ -913,7 +1006,7 @@ async def handler(request): async with client_ws: # pump a few messages for x in range(2): - await client_ws.send_message('Hello from client') + await client_ws.send_message("Hello from client") await stream.aclose() await done.wait() @@ -925,7 +1018,8 @@ async def handler(stream): async with server_ws: await stream.aclose() with pytest.raises(ConnectionClosed): - await server_ws.send_message('Hello from client!') + await server_ws.send_message("Hello from client!") + serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST) listeners = await nursery.start(serve_fn) port = listeners[0].socket.getsockname()[1] @@ -940,69 +1034,72 @@ async def handler(request): await trio.sleep(1) server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) # connection should close when server handler exits with trio.fail_after(2): - async with open_websocket( - HOST, server.port, '/', use_ssl=False) as connection: + async with open_websocket(HOST, server.port, "/", use_ssl=False) as connection: with pytest.raises(ConnectionClosed) as exc_info: await connection.get_message() exc = exc_info.value - assert exc.reason.name == 'NORMAL_CLOSURE' + assert exc.reason.name == "NORMAL_CLOSURE" @fail_after(DEFAULT_TEST_MAX_DURATION) async def test_read_messages_after_remote_close(nursery): - ''' + """ When the remote endpoint closes, the local endpoint can still read all of the messages sent prior to closing. Any attempt to read beyond that will raise ConnectionClosed. This test also exercises the configuration of the queue size. - ''' + """ server_closed = trio.Event() async def handler(request): server = await request.accept() async with server: - await server.send_message('1') - await server.send_message('2') + await server.send_message("1") + await server.send_message("2") server_closed.set() server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) # The client needs a message queue of size 2 so that it can buffer both # incoming messages without blocking the reader task. - async with open_websocket(HOST, server.port, '/', use_ssl=False, - message_queue_size=2) as client: + async with open_websocket( + HOST, server.port, "/", use_ssl=False, message_queue_size=2 + ) as client: await server_closed.wait() - assert await client.get_message() == '1' - assert await client.get_message() == '2' + assert await client.get_message() == "1" + assert await client.get_message() == "2" with pytest.raises(ConnectionClosed): await client.get_message() async def test_no_messages_after_local_close(nursery): - ''' + """ If the local endpoint initiates closing, then pending messages are discarded and any attempt to read a message will raise ConnectionClosed. - ''' + """ client_closed = trio.Event() async def handler(request): # The server sends some messages and then closes. server = await request.accept() async with server: - await server.send_message('1') - await server.send_message('2') + await server.send_message("1") + await server.send_message("2") await client_closed.wait() server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) - async with open_websocket(HOST, server.port, '/', use_ssl=False) as client: + async with open_websocket(HOST, server.port, "/", use_ssl=False) as client: pass with pytest.raises(ConnectionClosed): await client.get_message() @@ -1010,28 +1107,30 @@ async def handler(request): async def test_cm_exit_with_pending_messages(echo_server, autojump_clock): - ''' + """ Regression test for #74, where a context manager was not able to exit when there were pending messages in the receive queue. - ''' + """ with trio.fail_after(1): - async with open_websocket(HOST, echo_server.port, RESOURCE, - use_ssl=False) as ws: - await ws.send_message('hello') + async with open_websocket( + HOST, echo_server.port, RESOURCE, use_ssl=False + ) as ws: + await ws.send_message("hello") # allow time for the server to respond - await trio.sleep(.1) + await trio.sleep(0.1) @fail_after(DEFAULT_TEST_MAX_DURATION) async def test_max_message_size(nursery): - ''' + """ Set the client's max message size to 100 bytes. The client can send a message larger than 100 bytes, but when it receives a message larger than 100 bytes, it closes the connection with code 1009. - ''' + """ + async def handler(request): - ''' Similar to the echo_request_handler fixture except it runs in a - loop. ''' + """Similar to the echo_request_handler fixture except it runs in a + loop.""" conn = await request.accept() while True: try: @@ -1041,16 +1140,18 @@ async def handler(request): break server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, - max_message_size=100) as client: + async with open_websocket( + HOST, server.port, RESOURCE, use_ssl=False, max_message_size=100 + ) as client: # We can send and receive 100 bytes: - await client.send_message(b'A' * 100) + await client.send_message(b"A" * 100) msg = await client.get_message() assert len(msg) == 100 # We can send 101 bytes but cannot receive 101 bytes: - await client.send_message(b'B' * 101) + await client.send_message(b"B" * 101) with pytest.raises(ConnectionClosed): await client.get_message() assert client.closed @@ -1063,19 +1164,21 @@ async def test_server_close_client_disconnect_race(nursery, autojump_clock): async def handler(request: WebSocketRequest): ws = await request.accept() ws._for_testing_peer_closed_connection = trio.Event() - await ws.send_message('foo') + await ws.send_message("foo") await ws._for_testing_peer_closed_connection.wait() # with bug, this would raise ConnectionClosed from websocket internal task await trio.aclose_forcefully(ws._stream) server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) - connection = await connect_websocket(nursery, HOST, server.port, - RESOURCE, use_ssl=False) + connection = await connect_websocket( + nursery, HOST, server.port, RESOURCE, use_ssl=False + ) await connection.get_message() await connection.aclose() - await trio.sleep(.1) + await trio.sleep(0.1) async def test_remote_close_local_message_race(nursery, autojump_clock): @@ -1095,15 +1198,17 @@ async def handler(request: WebSocketRequest): await ws.aclose() server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) - client = await connect_websocket(nursery, HOST, server.port, - RESOURCE, use_ssl=False) + client = await connect_websocket( + nursery, HOST, server.port, RESOURCE, use_ssl=False + ) client._for_testing_peer_closed_connection = trio.Event() - await client.send_message('foo') + await client.send_message("foo") await client._for_testing_peer_closed_connection.wait() with pytest.raises(ConnectionClosed): - await client.send_message('bar') + await client.send_message("bar") async def test_message_after_local_close_race(nursery): @@ -1114,10 +1219,12 @@ async def handler(request: WebSocketRequest): await trio.sleep_forever() server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) - client = await connect_websocket(nursery, HOST, server.port, - RESOURCE, use_ssl=False) + client = await connect_websocket( + nursery, HOST, server.port, RESOURCE, use_ssl=False + ) orig_send = client._send close_sent = trio.Event() @@ -1132,7 +1239,7 @@ async def _send_wrapper(event): await close_sent.wait() assert client.closed with pytest.raises(ConnectionClosed): - await client.send_message('hello') + await client.send_message("hello") @fail_after(DEFAULT_TEST_MAX_DURATION) @@ -1150,9 +1257,11 @@ async def handle_connection(request): await trio.sleep_forever() server = await nursery.start( - partial(serve_websocket, handle_connection, HOST, 0, ssl_context=None)) - client = await connect_websocket(nursery, HOST, server.port, - RESOURCE, use_ssl=False) + partial(serve_websocket, handle_connection, HOST, 0, ssl_context=None) + ) + client = await connect_websocket( + nursery, HOST, server.port, RESOURCE, use_ssl=False + ) # send a CloseConnection event to server but leave client connected await client._send(CloseConnection(code=1000)) await server_stream_closed.wait() @@ -1183,7 +1292,7 @@ async def test_remote_close_rude(): async def client(): client_conn = await wrap_client_stream(nursery, client_stream, HOST, RESOURCE) assert not client_conn.closed - await client_conn.send_message('Hello from client!') + await client_conn.send_message("Hello from client!") with pytest.raises(ConnectionClosed): await client_conn.get_message() @@ -1202,7 +1311,6 @@ async def server(): # pump the messages over memory_stream_pump(server_stream.send_stream, client_stream.receive_stream) - async with trio.open_nursery() as nursery: nursery.start_soon(server) nursery.start_soon(client) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 5f3a9d4..26524e8 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -36,18 +36,20 @@ # pylint doesn't care about the version_info check, so need to ignore the warning from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin -_IS_TRIO_MULTI_ERROR = tuple(map(int, trio.__version__.split('.')[:2])) < (0, 22) +_IS_TRIO_MULTI_ERROR = tuple(map(int, trio.__version__.split(".")[:2])) < (0, 22) if _IS_TRIO_MULTI_ERROR: _TRIO_EXC_GROUP_TYPE = trio.MultiError # type: ignore[attr-defined] # pylint: disable=no-member else: - _TRIO_EXC_GROUP_TYPE = BaseExceptionGroup # pylint: disable=possibly-used-before-assignment + _TRIO_EXC_GROUP_TYPE = ( + BaseExceptionGroup # pylint: disable=possibly-used-before-assignment + ) -CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds +CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds MESSAGE_QUEUE_SIZE = 1 -MAX_MESSAGE_SIZE = 2 ** 20 # 1 MiB -RECEIVE_BYTES = 4 * 2 ** 10 # 4 KiB -logger = logging.getLogger('trio-websocket') +MAX_MESSAGE_SIZE = 2**20 # 1 MiB +RECEIVE_BYTES = 4 * 2**10 # 4 KiB +logger = logging.getLogger("trio-websocket") class TrioWebsocketInternalError(Exception): @@ -71,6 +73,7 @@ class _preserve_current_exception: https://github.com/python-trio/trio/issues/1559 https://gitter.im/python-trio/general?at=5faf2293d37a1a13d6a582cf """ + __slots__ = ("_armed",) def __init__(self): @@ -84,9 +87,15 @@ def __exit__(self, ty, value, tb): return False if _IS_TRIO_MULTI_ERROR: # pragma: no cover - filtered_exception = trio.MultiError.filter(_ignore_cancel, value) # pylint: disable=no-member - elif isinstance(value, BaseExceptionGroup): # pylint: disable=possibly-used-before-assignment - filtered_exception = value.subgroup(lambda exc: not isinstance(exc, trio.Cancelled)) + filtered_exception = trio.MultiError.filter( # pylint: disable=no-member + _ignore_cancel, value + ) + elif isinstance( + value, BaseExceptionGroup + ): # pylint: disable=possibly-used-before-assignment + filtered_exception = value.subgroup( + lambda exc: not isinstance(exc, trio.Cancelled) + ) else: filtered_exception = _ignore_cancel(value) return filtered_exception is None @@ -94,19 +103,19 @@ def __exit__(self, ty, value, tb): @asynccontextmanager async def open_websocket( - host: str, - port: int, - resource: str, - *, - use_ssl: Union[bool, ssl.SSLContext], - subprotocols: Optional[Iterable[str]] = None, - extra_headers: Optional[list[tuple[bytes,bytes]]] = None, - message_queue_size: int = MESSAGE_QUEUE_SIZE, - max_message_size: int = MAX_MESSAGE_SIZE, - connect_timeout: float = CONN_TIMEOUT, - disconnect_timeout: float = CONN_TIMEOUT - ): - ''' + host: str, + port: int, + resource: str, + *, + use_ssl: Union[bool, ssl.SSLContext], + subprotocols: Optional[Iterable[str]] = None, + extra_headers: Optional[list[tuple[bytes, bytes]]] = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, + connect_timeout: float = CONN_TIMEOUT, + disconnect_timeout: float = CONN_TIMEOUT, +): + """ Open a WebSocket client connection to a host. This async context manager connects when entering the context manager and @@ -137,7 +146,7 @@ async def open_websocket( :raises HandshakeError: for any networking error, client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`), or server rejection (:exc:`ConnectionRejected`) during handshakes. - ''' + """ # This context manager tries very very hard not to raise an exceptiongroup # in order to be as transparent as possible for the end user. @@ -160,15 +169,20 @@ async def open_websocket( # exception in the last `finally`. If we encountered exceptions in user code # or in reader task then they will be set as the `__context__`. - async def _open_connection(nursery: trio.Nursery) -> WebSocketConnection: try: with trio.fail_after(connect_timeout): - return await connect_websocket(nursery, host, port, - resource, use_ssl=use_ssl, subprotocols=subprotocols, + return await connect_websocket( + nursery, + host, + port, + resource, + use_ssl=use_ssl, + subprotocols=subprotocols, extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + ) except trio.TooSlowError: raise ConnectionTimeout from None except OSError as e: @@ -192,7 +206,7 @@ def _raise(exc: BaseException) -> NoReturn: exc.__context__ = context del exc, context - connection: WebSocketConnection|None=None + connection: WebSocketConnection | None = None close_result: outcome.Maybe[None] | None = None user_error = None @@ -225,7 +239,7 @@ def _raise(exc: BaseException) -> NoReturn: _raise(e.exceptions[0]) # contains at most 1 non-cancelled exceptions - exception_to_raise: BaseException|None = None + exception_to_raise: BaseException | None = None for sub_exc in e.exceptions: if not isinstance(sub_exc, trio.Cancelled): if exception_to_raise is not None: @@ -255,13 +269,16 @@ def _raise(exc: BaseException) -> NoReturn: # and, if not None, `user_error.__context__` if user_error is not None: exceptions = [subexc for subexc in e.exceptions if subexc is not user_error] - eg_substr = '' + eg_substr = "" # there's technically loss of info here, with __suppress_context__=True you # still have original __context__ available, just not printed. But we delete # it completely because we can't partially suppress the group - if user_error.__context__ is not None and not user_error.__suppress_context__: + if ( + user_error.__context__ is not None + and not user_error.__suppress_context__ + ): exceptions.append(user_error.__context__) - eg_substr = ' and the context for the user exception' + eg_substr = " and the context for the user exception" eg_str = ( "Both internal and user exceptions encountered. This group contains " "the internal exception(s)" + eg_substr + "." @@ -280,17 +297,24 @@ def _raise(exc: BaseException) -> NoReturn: if close_result is not None: close_result.unwrap() - # error setting up, unwrap that exception if connection is None: result.unwrap() -async def connect_websocket(nursery, host, port, resource, *, use_ssl, - subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE - ) -> WebSocketConnection: - ''' +async def connect_websocket( + nursery, + host, + port, + resource, + *, + use_ssl, + subprotocols=None, + extra_headers=None, + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, +) -> WebSocketConnection: + """ Return an open WebSocket client connection to a host. This function is used to specify a custom nursery to run connection @@ -318,7 +342,7 @@ async def connect_websocket(nursery, host, port, resource, *, use_ssl, ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). :rtype: WebSocketConnection - ''' + """ if use_ssl is True: ssl_context = ssl.create_default_context() elif use_ssl is False: @@ -326,37 +350,53 @@ async def connect_websocket(nursery, host, port, resource, *, use_ssl, elif isinstance(use_ssl, ssl.SSLContext): ssl_context = use_ssl else: - raise TypeError('`use_ssl` argument must be bool or ssl.SSLContext') - - logger.debug('Connecting to ws%s://%s:%d%s', - '' if ssl_context is None else 's', host, port, resource) + raise TypeError("`use_ssl` argument must be bool or ssl.SSLContext") + + logger.debug( + "Connecting to ws%s://%s:%d%s", + "" if ssl_context is None else "s", + host, + port, + resource, + ) stream: trio.SSLStream[trio.SocketStream] | trio.SocketStream if ssl_context is None: stream = await trio.open_tcp_stream(host, port) else: - stream = await trio.open_ssl_over_tcp_stream(host, port, - ssl_context=ssl_context, https_compatible=True) + stream = await trio.open_ssl_over_tcp_stream( + host, port, ssl_context=ssl_context, https_compatible=True + ) if port in (80, 443): host_header = host else: - host_header = f'{host}:{port}' - connection = WebSocketConnection(stream, + host_header = f"{host}:{port}" + connection = WebSocketConnection( + stream, WSConnection(ConnectionType.CLIENT), host=host_header, path=resource, - client_subprotocols=subprotocols, client_extra_headers=extra_headers, + client_subprotocols=subprotocols, + client_extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + ) nursery.start_soon(connection._reader_task) await connection._open_handshake.wait() return connection -def open_websocket_url(url, ssl_context=None, *, subprotocols=None, +def open_websocket_url( + url, + ssl_context=None, + *, + subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE, - connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT): - ''' + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, + connect_timeout=CONN_TIMEOUT, + disconnect_timeout=CONN_TIMEOUT, +): + """ Open a WebSocket client connection to a URL. This async context manager connects when entering the context manager and @@ -385,19 +425,33 @@ def open_websocket_url(url, ssl_context=None, *, subprotocols=None, :raises HandshakeError: for any networking error, client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`), or server rejection (:exc:`ConnectionRejected`) during handshakes. - ''' + """ host, port, resource, ssl_context = _url_to_host(url, ssl_context) - return open_websocket(host, port, resource, use_ssl=ssl_context, - subprotocols=subprotocols, extra_headers=extra_headers, + return open_websocket( + host, + port, + resource, + use_ssl=ssl_context, + subprotocols=subprotocols, + extra_headers=extra_headers, message_queue_size=message_queue_size, max_message_size=max_message_size, - connect_timeout=connect_timeout, disconnect_timeout=disconnect_timeout) + connect_timeout=connect_timeout, + disconnect_timeout=disconnect_timeout, + ) -async def connect_websocket_url(nursery, url, ssl_context=None, *, - subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): - ''' +async def connect_websocket_url( + nursery, + url, + ssl_context=None, + *, + subprotocols=None, + extra_headers=None, + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, +): + """ Return an open WebSocket client connection to a URL. This function is used to specify a custom nursery to run connection @@ -422,16 +476,23 @@ async def connect_websocket_url(nursery, url, ssl_context=None, *, ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). :rtype: WebSocketConnection - ''' + """ host, port, resource, ssl_context = _url_to_host(url, ssl_context) - return await connect_websocket(nursery, host, port, resource, - use_ssl=ssl_context, subprotocols=subprotocols, - extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + return await connect_websocket( + nursery, + host, + port, + resource, + use_ssl=ssl_context, + subprotocols=subprotocols, + extra_headers=extra_headers, + message_queue_size=message_queue_size, + max_message_size=max_message_size, + ) def _url_to_host(url, ssl_context): - ''' + """ Convert a WebSocket URL to a (host,port,resource) tuple. The returned ``ssl_context`` is either the same object that was passed in, @@ -441,15 +502,15 @@ def _url_to_host(url, ssl_context): :param str url: A WebSocket URL. :type ssl_context: ssl.SSLContext or None :returns: A tuple of ``(host, port, resource, ssl_context)``. - ''' + """ url = str(url) # For backward compat with isinstance(url, yarl.URL). parts = urllib.parse.urlsplit(url) - if parts.scheme not in ('ws', 'wss'): + if parts.scheme not in ("ws", "wss"): raise ValueError('WebSocket URL scheme must be "ws:" or "wss:"') if ssl_context is None: - ssl_context = parts.scheme == 'wss' - elif parts.scheme == 'ws': - raise ValueError('SSL context must be None for ws: URL scheme') + ssl_context = parts.scheme == "wss" + elif parts.scheme == "ws": + raise ValueError("SSL context must be None for ws: URL scheme") host = parts.hostname if parts.port is not None: port = parts.port @@ -460,16 +521,24 @@ def _url_to_host(url, ssl_context): # If the target URI's path component is empty, the client MUST # send "/" as the path within the origin-form of request-target. if not path_qs: - path_qs = '/' - if '?' in url: - path_qs += '?' + parts.query + path_qs = "/" + if "?" in url: + path_qs += "?" + parts.query return host, port, path_qs, ssl_context -async def wrap_client_stream(nursery, stream, host, resource, *, - subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): - ''' +async def wrap_client_stream( + nursery, + stream, + host, + resource, + *, + subprotocols=None, + extra_headers=None, + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, +): + """ Wrap an arbitrary stream in a WebSocket connection. This is a low-level function only needed in rare cases. In most cases, you @@ -493,21 +562,29 @@ async def wrap_client_stream(nursery, stream, host, resource, *, ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). :rtype: WebSocketConnection - ''' - connection = WebSocketConnection(stream, + """ + connection = WebSocketConnection( + stream, WSConnection(ConnectionType.CLIENT), - host=host, path=resource, - client_subprotocols=subprotocols, client_extra_headers=extra_headers, + host=host, + path=resource, + client_subprotocols=subprotocols, + client_extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + ) nursery.start_soon(connection._reader_task) await connection._open_handshake.wait() return connection -async def wrap_server_stream(nursery, stream, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): - ''' +async def wrap_server_stream( + nursery, + stream, + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, +): + """ Wrap an arbitrary stream in a server-side WebSocket. This is a low-level function only needed in rare cases. In most cases, you @@ -522,21 +599,32 @@ async def wrap_server_stream(nursery, stream, then the connection is closed with code 1009 (Message Too Big). :type stream: trio.abc.Stream :rtype: WebSocketRequest - ''' - connection = WebSocketConnection(stream, + """ + connection = WebSocketConnection( + stream, WSConnection(ConnectionType.SERVER), message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + ) nursery.start_soon(connection._reader_task) request = await connection._get_request() return request -async def serve_websocket(handler, host, port, ssl_context, *, - handler_nursery=None, message_queue_size=MESSAGE_QUEUE_SIZE, - max_message_size=MAX_MESSAGE_SIZE, connect_timeout=CONN_TIMEOUT, - disconnect_timeout=CONN_TIMEOUT, task_status=trio.TASK_STATUS_IGNORED): - ''' +async def serve_websocket( + handler, + host, + port, + ssl_context, + *, + handler_nursery=None, + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, + connect_timeout=CONN_TIMEOUT, + disconnect_timeout=CONN_TIMEOUT, + task_status=trio.TASK_STATUS_IGNORED, +): + """ Serve a WebSocket over TCP. This function supports the Trio nursery start protocol: ``server = await @@ -570,64 +658,79 @@ async def serve_websocket(handler, host, port, ssl_context, *, to finish the closing handshake before timing out. :param task_status: Part of Trio nursery start protocol. :returns: This function runs until cancelled. - ''' + """ if ssl_context is None: open_tcp_listeners = partial(trio.open_tcp_listeners, port, host=host) else: - open_tcp_listeners = partial(trio.open_ssl_over_tcp_listeners, port, - ssl_context, host=host, https_compatible=True) + open_tcp_listeners = partial( + trio.open_ssl_over_tcp_listeners, + port, + ssl_context, + host=host, + https_compatible=True, + ) listeners = await open_tcp_listeners() - server = WebSocketServer(handler, listeners, - handler_nursery=handler_nursery, message_queue_size=message_queue_size, - max_message_size=max_message_size, connect_timeout=connect_timeout, - disconnect_timeout=disconnect_timeout) + server = WebSocketServer( + handler, + listeners, + handler_nursery=handler_nursery, + message_queue_size=message_queue_size, + max_message_size=max_message_size, + connect_timeout=connect_timeout, + disconnect_timeout=disconnect_timeout, + ) await server.run(task_status=task_status) class HandshakeError(Exception): - ''' + """ There was an error during connection or disconnection with the websocket server. - ''' + """ + class ConnectionTimeout(HandshakeError): - '''There was a timeout when connecting to the websocket server.''' + """There was a timeout when connecting to the websocket server.""" + class DisconnectionTimeout(HandshakeError): - '''There was a timeout when disconnecting from the websocket server.''' + """There was a timeout when disconnecting from the websocket server.""" + class ConnectionClosed(Exception): - ''' + """ A WebSocket operation cannot be completed because the connection is closed or in the process of closing. - ''' + """ + def __init__(self, reason): - ''' + """ Constructor. :param reason: :type reason: CloseReason - ''' + """ super().__init__(reason) self.reason = reason def __repr__(self): - ''' Return representation. ''' - return f'{self.__class__.__name__}<{self.reason}>' + """Return representation.""" + return f"{self.__class__.__name__}<{self.reason}>" class ConnectionRejected(HandshakeError): - ''' + """ A WebSocket connection could not be established because the server rejected the connection attempt. - ''' + """ + def __init__(self, status_code, headers, body): - ''' + """ Constructor. :param reason: :type reason: CloseReason - ''' + """ super().__init__(status_code, headers, body) #: a 3 digit HTTP status code self.status_code = status_code @@ -637,144 +740,149 @@ def __init__(self, status_code, headers, body): self.body = body def __repr__(self): - ''' Return representation. ''' - return f'{self.__class__.__name__}' + """Return representation.""" + return f"{self.__class__.__name__}" class CloseReason: - ''' Contains information about why a WebSocket was closed. ''' + """Contains information about why a WebSocket was closed.""" + def __init__(self, code, reason): - ''' + """ Constructor. :param int code: :param Optional[str] reason: - ''' + """ self._code = code try: self._name = wsframeproto.CloseReason(code).name except ValueError: if 1000 <= code <= 2999: - self._name = 'RFC_RESERVED' + self._name = "RFC_RESERVED" elif 3000 <= code <= 3999: - self._name = 'IANA_RESERVED' + self._name = "IANA_RESERVED" elif 4000 <= code <= 4999: - self._name = 'PRIVATE_RESERVED' + self._name = "PRIVATE_RESERVED" else: - self._name = 'INVALID_CODE' + self._name = "INVALID_CODE" self._reason = reason @property def code(self): - ''' (Read-only) The numeric close code. ''' + """(Read-only) The numeric close code.""" return self._code @property def name(self): - ''' (Read-only) The human-readable close code. ''' + """(Read-only) The human-readable close code.""" return self._name @property def reason(self): - ''' (Read-only) An arbitrary reason string. ''' + """(Read-only) An arbitrary reason string.""" return self._reason def __repr__(self): - ''' Show close code, name, and reason. ''' - return f'{self.__class__.__name__}' \ - f'' + """Show close code, name, and reason.""" + return ( + f"{self.__class__.__name__}" + f"" + ) class Future: - ''' Represents a value that will be available in the future. ''' + """Represents a value that will be available in the future.""" + def __init__(self): - ''' Constructor. ''' + """Constructor.""" self._value = None self._value_event = trio.Event() def set_value(self, value): - ''' + """ Set a value, which will notify any waiters. :param value: - ''' + """ self._value = value self._value_event.set() async def wait_value(self): - ''' + """ Wait for this future to have a value, then return it. :returns: The value set by ``set_value()``. - ''' + """ await self._value_event.wait() return self._value class WebSocketRequest: - ''' + """ Represents a handshake presented by a client to a server. The server may modify the handshake or leave it as is. The server should call ``accept()`` to finish the handshake and obtain a connection object. - ''' + """ + def __init__(self, connection, event): - ''' + """ Constructor. :param WebSocketConnection connection: :type event: wsproto.events.Request - ''' + """ self._connection = connection self._event = event @property def headers(self): - ''' + """ HTTP headers represented as a list of (name, value) pairs. :rtype: list[tuple] - ''' + """ return self._event.extra_headers @property def path(self): - ''' + """ The requested URL path. :rtype: str - ''' + """ return self._event.target @property def proposed_subprotocols(self): - ''' + """ A tuple of protocols proposed by the client. :rtype: tuple[str] - ''' + """ return tuple(self._event.subprotocols) @property def local(self): - ''' + """ The connection's local endpoint. :rtype: Endpoint or str - ''' + """ return self._connection.local @property def remote(self): - ''' + """ The connection's remote endpoint. :rtype: Endpoint or str - ''' + """ return self._connection.remote async def accept(self, *, subprotocol=None, extra_headers=None): - ''' + """ Accept the request and return a connection object. :param subprotocol: The selected subprotocol for this connection. @@ -783,14 +891,14 @@ async def accept(self, *, subprotocol=None, extra_headers=None): send as HTTP headers. :type extra_headers: list[tuple[bytes,bytes]] or None :rtype: WebSocketConnection - ''' + """ if extra_headers is None: extra_headers = [] await self._connection._accept(self._event, subprotocol, extra_headers) return self._connection async def reject(self, status_code, *, extra_headers=None, body=None): - ''' + """ Reject the handshake. :param int status_code: The 3 digit HTTP status code. In order to be @@ -801,14 +909,14 @@ async def reject(self, status_code, *, extra_headers=None, body=None): :param body: If provided, this data will be sent in the response body, otherwise no response body will be sent. :type body: bytes or None - ''' + """ extra_headers = extra_headers or [] - body = body or b'' + body = body or b"" await self._connection._reject(status_code, extra_headers, body) def _get_stream_endpoint(stream, *, local): - ''' + """ Construct an endpoint from a stream. :param trio.Stream stream: @@ -816,7 +924,7 @@ def _get_stream_endpoint(stream, *, local): :returns: An endpoint instance or ``repr()`` for streams that cannot be represented as an endpoint. :rtype: Endpoint or str - ''' + """ socket, is_ssl = None, False if isinstance(stream, trio.SocketStream): socket = stream.socket @@ -832,22 +940,23 @@ def _get_stream_endpoint(stream, *, local): class WebSocketConnection(trio.abc.AsyncResource): - ''' A WebSocket connection. ''' + """A WebSocket connection.""" CONNECTION_ID = itertools.count() def __init__( - self, - stream: trio.SocketStream | trio.SSLStream[trio.SocketStream], - ws_connection: wsproto.WSConnection, - *, - host=None, - path=None, - client_subprotocols=None, client_extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, - max_message_size=MAX_MESSAGE_SIZE - ): - ''' + self, + stream: trio.SocketStream | trio.SSLStream[trio.SocketStream], + ws_connection: wsproto.WSConnection, + *, + host=None, + path=None, + client_subprotocols=None, + client_extra_headers=None, + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, + ): + """ Constructor. Generally speaking, users are discouraged from directly instantiating a @@ -872,7 +981,7 @@ def __init__( :param int max_message_size: The maximum message size as measured by ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). - ''' + """ # NOTE: The implementation uses _close_reason for more than an advisory # purpose. It's critical internal state, indicating when the # connection is closed or closing. @@ -886,17 +995,20 @@ def __init__( self._max_message_size = max_message_size self._reader_running = True if ws_connection.client: - self._initial_request: Optional[Request] = Request(host=host, target=path, + self._initial_request: Optional[Request] = Request( + host=host, + target=path, subprotocols=client_subprotocols, - extra_headers=client_extra_headers or []) + extra_headers=client_extra_headers or [], + ) else: self._initial_request = None self._path = path self._subprotocol: Optional[str] = None - self._handshake_headers: tuple[tuple[str,str], ...] = tuple() + self._handshake_headers: tuple[tuple[str, str], ...] = tuple() self._reject_status = 0 - self._reject_headers: tuple[tuple[str,str], ...] = tuple() - self._reject_body = b'' + self._reject_headers: tuple[tuple[str, str], ...] = tuple() + self._reject_body = b"" self._send_channel, self._recv_channel = trio.open_memory_channel[ Union[bytes, str] ](message_queue_size) @@ -916,77 +1028,77 @@ def __init__( @property def closed(self): - ''' + """ (Read-only) The reason why the connection was or is being closed, else ``None``. :rtype: Optional[CloseReason] - ''' + """ return self._close_reason @property def is_client(self): - ''' (Read-only) Is this a client instance? ''' + """(Read-only) Is this a client instance?""" return self._wsproto.client @property def is_server(self): - ''' (Read-only) Is this a server instance? ''' + """(Read-only) Is this a server instance?""" return not self._wsproto.client @property def local(self): - ''' + """ The local endpoint of the connection. :rtype: Endpoint or str - ''' + """ return _get_stream_endpoint(self._stream, local=True) @property def remote(self): - ''' + """ The remote endpoint of the connection. :rtype: Endpoint or str - ''' + """ return _get_stream_endpoint(self._stream, local=False) @property def path(self): - ''' + """ The requested URL path. For clients, this is set when the connection is instantiated. For servers, it is set after the handshake completes. :rtype: str - ''' + """ return self._path @property def subprotocol(self): - ''' + """ (Read-only) The negotiated subprotocol, or ``None`` if there is no subprotocol. This is only valid after the opening handshake is complete. :rtype: str or None - ''' + """ return self._subprotocol @property def handshake_headers(self): - ''' + """ The HTTP headers that were sent by the remote during the handshake, stored as 2-tuples containing key/value pairs. Header keys are always lower case. :rtype: tuple[tuple[str,str]] - ''' + """ return self._handshake_headers async def aclose(self, code=1000, reason=None): # pylint: disable=arguments-differ - ''' + """ Close the WebSocket connection. This sends a closing frame and suspends until the connection is closed. @@ -999,7 +1111,7 @@ async def aclose(self, code=1000, reason=None): # pylint: disable=arguments-dif :param int code: A 4-digit code number indicating the type of closure. :param str reason: An optional string describing the closure. - ''' + """ with _preserve_current_exception(): await self._aclose(code, reason) @@ -1014,8 +1126,10 @@ async def _aclose(self, code, reason): # event to peer, while setting the local close reason to normal. self._close_reason = CloseReason(1000, None) await self._send(CloseConnection(code=code, reason=reason)) - elif self._wsproto.state in (ConnectionState.CONNECTING, - ConnectionState.REJECTING): + elif self._wsproto.state in ( + ConnectionState.CONNECTING, + ConnectionState.REJECTING, + ): self._close_handshake.set() # TODO: shouldn't the receive channel be closed earlier, so that # get_message() during send of the CloseConneciton event fails? @@ -1030,7 +1144,7 @@ async def _aclose(self, code, reason): await self._close_stream() async def get_message(self): - ''' + """ Receive the next WebSocket message. If no message is available immediately, then this function blocks until @@ -1045,15 +1159,15 @@ async def get_message(self): :rtype: str or bytes :raises ConnectionClosed: if the connection is closed. - ''' + """ try: message = await self._recv_channel.receive() except (trio.ClosedResourceError, trio.EndOfChannel): raise ConnectionClosed(self._close_reason) from None return message - async def ping(self, payload: bytes|None=None): - ''' + async def ping(self, payload: bytes | None = None): + """ Send WebSocket ping to remote endpoint and wait for a correspoding pong. Each in-flight ping must include a unique payload. This function sends @@ -1071,39 +1185,39 @@ async def ping(self, payload: bytes|None=None): :raises ConnectionClosed: if connection is closed. :raises ValueError: if ``payload`` is identical to another in-flight ping. - ''' + """ if self._close_reason: raise ConnectionClosed(self._close_reason) if payload in self._pings: - raise ValueError(f'Payload value {payload!r} is already in flight.') + raise ValueError(f"Payload value {payload!r} is already in flight.") if payload is None: - payload = struct.pack('!I', random.getrandbits(32)) + payload = struct.pack("!I", random.getrandbits(32)) event = trio.Event() self._pings[payload] = event await self._send(Ping(payload=payload)) await event.wait() async def pong(self, payload=None): - ''' + """ Send an unsolicted pong. :param payload: The pong's payload. If ``None``, then no payload is sent. :type payload: bytes or None :raises ConnectionClosed: if connection is closed - ''' + """ if self._close_reason: raise ConnectionClosed(self._close_reason) await self._send(Pong(payload=payload)) async def send_message(self, message): - ''' + """ Send a WebSocket message. :param message: The message to send. :type message: str or bytes :raises ConnectionClosed: if connection is closed, or being closed - ''' + """ if self._close_reason: raise ConnectionClosed(self._close_reason) if isinstance(message, str): @@ -1111,16 +1225,16 @@ async def send_message(self, message): elif isinstance(message, bytes): event = BytesMessage(data=message) else: - raise ValueError('message must be str or bytes') + raise ValueError("message must be str or bytes") await self._send(event) def __str__(self): - ''' Connection ID and type. ''' - type_ = 'client' if self.is_client else 'server' - return f'{type_}-{self._id}' + """Connection ID and type.""" + type_ = "client" if self.is_client else "server" + return f"{type_}-{self._id}" async def _accept(self, request, subprotocol, extra_headers): - ''' + """ Accept the handshake. This method is only applicable to server-side connections. @@ -1130,15 +1244,16 @@ async def _accept(self, request, subprotocol, extra_headers): :type subprotocol: str or None :param list[tuple[bytes,bytes]] extra_headers: A list of 2-tuples containing key/value pairs to send as HTTP headers. - ''' + """ self._subprotocol = subprotocol self._path = request.target - await self._send(AcceptConnection(subprotocol=self._subprotocol, - extra_headers=extra_headers)) + await self._send( + AcceptConnection(subprotocol=self._subprotocol, extra_headers=extra_headers) + ) self._open_handshake.set() async def _reject(self, status_code, headers, body): - ''' + """ Reject the handshake. :param int status_code: The 3 digit HTTP status code. In order to be @@ -1147,25 +1262,26 @@ async def _reject(self, status_code, headers, body): :param list[tuple[bytes,bytes]] headers: A list of 2-tuples containing key/value pairs to send as HTTP headers. :param bytes body: An optional response body. - ''' + """ if body: - headers.append(('Content-length', str(len(body)).encode('ascii'))) - reject_conn = RejectConnection(status_code=status_code, headers=headers, - has_body=bool(body)) + headers.append(("Content-length", str(len(body)).encode("ascii"))) + reject_conn = RejectConnection( + status_code=status_code, headers=headers, has_body=bool(body) + ) await self._send(reject_conn) if body: reject_body = RejectData(data=body) await self._send(reject_body) - self._close_reason = CloseReason(1006, 'Rejected WebSocket handshake') + self._close_reason = CloseReason(1006, "Rejected WebSocket handshake") self._close_handshake.set() async def _abort_web_socket(self): - ''' + """ If a stream is closed outside of this class, e.g. due to network conditions or because some other code closed our stream object, then we cannot perform the close handshake. We just need to clean up internal state. - ''' + """ close_reason = wsframeproto.CloseReason.ABNORMAL_CLOSURE if self._wsproto.state == ConnectionState.OPEN: self._wsproto.send(CloseConnection(code=close_reason.value)) @@ -1177,7 +1293,7 @@ async def _abort_web_socket(self): self._close_handshake.set() async def _close_stream(self): - ''' Close the TCP connection. ''' + """Close the TCP connection.""" self._reader_running = False try: with _preserve_current_exception(): @@ -1187,85 +1303,89 @@ async def _close_stream(self): pass async def _close_web_socket(self, code, reason=None): - ''' + """ Mark the WebSocket as closed. Close the message channel so that if any tasks are suspended in get_message(), they will wake up with a ConnectionClosed exception. - ''' + """ self._close_reason = CloseReason(code, reason) exc = ConnectionClosed(self._close_reason) - logger.debug('%s websocket closed %r', self, exc) + logger.debug("%s websocket closed %r", self, exc) await self._send_channel.aclose() async def _get_request(self): - ''' + """ Return a proposal for a WebSocket handshake. This method can only be called on server connections and it may only be called one time. :rtype: WebSocketRequest - ''' + """ if not self.is_server: - raise RuntimeError('This method is only valid for server connections.') + raise RuntimeError("This method is only valid for server connections.") if self._connection_proposal is None: - raise RuntimeError('No proposal available. Did you call this method' - ' multiple times or at the wrong time?') + raise RuntimeError( + "No proposal available. Did you call this method" + " multiple times or at the wrong time?" + ) proposal = await self._connection_proposal.wait_value() self._connection_proposal = None return proposal async def _handle_request_event(self, event): - ''' + """ Handle a connection request. This method is async even though it never awaits, because the event dispatch requires an async function. :param event: - ''' + """ proposal = WebSocketRequest(self, event) self._connection_proposal.set_value(proposal) async def _handle_accept_connection_event(self, event): - ''' + """ Handle an AcceptConnection event. :param wsproto.eventsAcceptConnection event: - ''' + """ self._subprotocol = event.subprotocol self._handshake_headers = tuple(event.extra_headers) self._open_handshake.set() async def _handle_reject_connection_event(self, event): - ''' + """ Handle a RejectConnection event. :param event: - ''' + """ self._reject_status = event.status_code self._reject_headers = tuple(event.headers) if not event.has_body: - raise ConnectionRejected(self._reject_status, self._reject_headers, - body=None) + raise ConnectionRejected( + self._reject_status, self._reject_headers, body=None + ) async def _handle_reject_data_event(self, event): - ''' + """ Handle a RejectData event. :param event: - ''' + """ self._reject_body += event.data if event.body_finished: - raise ConnectionRejected(self._reject_status, self._reject_headers, - body=self._reject_body) + raise ConnectionRejected( + self._reject_status, self._reject_headers, body=self._reject_body + ) async def _handle_close_connection_event(self, event): - ''' + """ Handle a close event. :param wsproto.events.CloseConnection event: - ''' + """ if self._wsproto.state == ConnectionState.REMOTE_CLOSING: # Set _close_reason in advance, so that send_message() will raise # ConnectionClosed during the close handshake. @@ -1282,16 +1402,16 @@ async def _handle_close_connection_event(self, event): await self._close_stream() async def _handle_message_event(self, event): - ''' + """ Handle a message event. :param event: :type event: wsproto.events.BytesMessage or wsproto.events.TextMessage - ''' + """ self._message_size += len(event.data) self._message_parts.append(event.data) if self._message_size > self._max_message_size: - err = f'Exceeded maximum message size: {self._max_message_size} bytes' + err = f"Exceeded maximum message size: {self._max_message_size} bytes" self._message_size = 0 self._message_parts = [] self._close_reason = CloseReason(1009, err) @@ -1299,8 +1419,9 @@ async def _handle_message_event(self, event): await self._recv_channel.aclose() self._reader_running = False elif event.message_finished: - msg = (b'' if isinstance(event, BytesMessage) else '') \ - .join(self._message_parts) + msg = (b"" if isinstance(event, BytesMessage) else "").join( + self._message_parts + ) self._message_size = 0 self._message_parts = [] try: @@ -1312,19 +1433,19 @@ async def _handle_message_event(self, event): pass async def _handle_ping_event(self, event): - ''' + """ Handle a PingReceived event. Wsproto queues a pong frame automatically, so this handler just needs to send it. :param wsproto.events.Ping event: - ''' - logger.debug('%s ping %r', self, event.payload) + """ + logger.debug("%s ping %r", self, event.payload) await self._send(event.response()) async def _handle_pong_event(self, event): - ''' + """ Handle a PongReceived event. When a pong is received, check if we have any ping requests waiting for @@ -1336,7 +1457,7 @@ async def _handle_pong_event(self, event): complicated if some handlers were sync. :param event: - ''' + """ payload = bytes(event.payload) try: event = self._pings[payload] @@ -1346,14 +1467,14 @@ async def _handle_pong_event(self, event): return while self._pings: key, event = self._pings.popitem(0) - skipped = ' [skipped] ' if payload != key else ' ' - logger.debug('%s pong%s%r', self, skipped, key) + skipped = " [skipped] " if payload != key else " " + logger.debug("%s pong%s%r", self, skipped, key) event.set() if payload == key: break async def _reader_task(self): - ''' A background task that reads network data and generates events. ''' + """A background task that reads network data and generates events.""" handlers = { AcceptConnection: self._handle_accept_connection_event, BytesMessage: self._handle_message_event, @@ -1380,12 +1501,12 @@ async def _reader_task(self): event_type = type(event) try: handler = handlers[event_type] - logger.debug('%s received event: %s', self, - event_type) + logger.debug("%s received event: %s", self, event_type) await handler(event) except KeyError: - logger.warning('%s received unknown event type: "%s"', self, - event_type) + logger.warning( + '%s received unknown event type: "%s"', self, event_type + ) except ConnectionClosed: self._reader_running = False break @@ -1397,27 +1518,26 @@ async def _reader_task(self): await self._abort_web_socket() break if len(data) == 0: - logger.debug('%s received zero bytes (connection closed)', - self) + logger.debug("%s received zero bytes (connection closed)", self) # If TCP closed before WebSocket, then record it as an abnormal # closure. if self._wsproto.state != ConnectionState.CLOSED: await self._abort_web_socket() break - logger.debug('%s received %d bytes', self, len(data)) + logger.debug("%s received %d bytes", self, len(data)) if self._wsproto.state != ConnectionState.CLOSED: try: self._wsproto.receive_data(data) except wsproto.utilities.RemoteProtocolError as err: - logger.debug('%s remote protocol error: %s', self, err) + logger.debug("%s remote protocol error: %s", self, err) if err.event_hint: await self._send(err.event_hint) await self._close_stream() - logger.debug('%s reader task finished', self) + logger.debug("%s reader task finished", self) async def _send(self, event): - ''' + """ Send an event to the remote WebSocket. The reader task and one or more writers might try to send messages at @@ -1425,10 +1545,10 @@ async def _send(self, event): requests to send data. :param wsproto.events.Event event: - ''' + """ data = self._wsproto.send(event) async with self._stream_lock: - logger.debug('%s sending %d bytes', self, len(data)) + logger.debug("%s sending %d bytes", self, len(data)) try: await self._stream.send_all(data) except (trio.BrokenResourceError, trio.ClosedResourceError): @@ -1437,7 +1557,8 @@ async def _send(self, event): class Endpoint: - ''' Represents a connection endpoint. ''' + """Represents a connection endpoint.""" + def __init__(self, address, port, is_ssl): #: IP address :class:`ipaddress.ip_address` self.address = ip_address(address) @@ -1448,37 +1569,43 @@ def __init__(self, address, port, is_ssl): @property def url(self): - ''' Return a URL representation of a TCP endpoint, e.g. - ``ws://127.0.0.1:80``. ''' - scheme = 'wss' if self.is_ssl else 'ws' - if (self.port == 80 and not self.is_ssl) or \ - (self.port == 443 and self.is_ssl): - port_str = '' + """Return a URL representation of a TCP endpoint, e.g. + ``ws://127.0.0.1:80``.""" + scheme = "wss" if self.is_ssl else "ws" + if (self.port == 80 and not self.is_ssl) or (self.port == 443 and self.is_ssl): + port_str = "" else: - port_str = ':' + str(self.port) + port_str = ":" + str(self.port) if self.address.version == 4: - return f'{scheme}://{self.address}{port_str}' - return f'{scheme}://[{self.address}]{port_str}' + return f"{scheme}://{self.address}{port_str}" + return f"{scheme}://[{self.address}]{port_str}" def __repr__(self): - ''' Return endpoint info as string. ''' + """Return endpoint info as string.""" return f'Endpoint(address="{self.address}", port={self.port}, is_ssl={self.is_ssl})' class WebSocketServer: - ''' + """ WebSocket server. The server class handles incoming connections on one or more ``Listener`` objects. For each incoming connection, it creates a ``WebSocketConnection`` instance and starts some background tasks, - ''' + """ - def __init__(self, handler, listeners, *, handler_nursery=None, + def __init__( + self, + handler, + listeners, + *, + handler_nursery=None, message_queue_size=MESSAGE_QUEUE_SIZE, - max_message_size=MAX_MESSAGE_SIZE, connect_timeout=CONN_TIMEOUT, - disconnect_timeout=CONN_TIMEOUT): - ''' + max_message_size=MAX_MESSAGE_SIZE, + connect_timeout=CONN_TIMEOUT, + disconnect_timeout=CONN_TIMEOUT, + ): + """ Constructor. Note that if ``host`` is ``None`` and ``port`` is zero, then you may get @@ -1497,9 +1624,9 @@ def __init__(self, handler, listeners, *, handler_nursery=None, to finish connection handshake before timing out. :param float disconnect_timeout: The number of seconds to wait for a client to finish the closing handshake before timing out. - ''' + """ if len(listeners) == 0: - raise ValueError('Listeners must contain at least one item.') + raise ValueError("Listeners must contain at least one item.") self._handler = handler self._handler_nursery = handler_nursery self._listeners = listeners @@ -1521,24 +1648,27 @@ def port(self): listener must be socket-based. """ if len(self._listeners) > 1: - raise RuntimeError('Cannot get port because this server has' - ' more than 1 listeners.') + raise RuntimeError( + "Cannot get port because this server has more than 1 listeners." + ) listener = self.listeners[0] try: return listener.port except AttributeError: - raise RuntimeError(f'This socket does not have a port: {repr(listener)}') from None + raise RuntimeError( + f"This socket does not have a port: {repr(listener)}" + ) from None @property def listeners(self): - ''' + """ Return a list of listener metadata. Each TCP listener is represented as an ``Endpoint`` instance. Other listener types are represented by their ``repr()``. :returns: Listeners :rtype list[Endpoint or str]: - ''' + """ listeners = [] for listener in self._listeners: socket, is_ssl = None, False @@ -1555,7 +1685,7 @@ def listeners(self): return listeners async def run(self, *, task_status=trio.TASK_STATUS_IGNORED): - ''' + """ Start serving incoming connections requests. This method supports the Trio nursery start protocol: ``server = await @@ -1564,30 +1694,34 @@ async def run(self, *, task_status=trio.TASK_STATUS_IGNORED): :param task_status: Part of the Trio nursery start protocol. :returns: This method never returns unless cancelled. - ''' + """ async with trio.open_nursery() as nursery: - serve_listeners = partial(trio.serve_listeners, - self._handle_connection, self._listeners, - handler_nursery=self._handler_nursery) + serve_listeners = partial( + trio.serve_listeners, + self._handle_connection, + self._listeners, + handler_nursery=self._handler_nursery, + ) await nursery.start(serve_listeners) - logger.debug('Listening on %s', - ','.join([str(l) for l in self.listeners])) + logger.debug("Listening on %s", ",".join([str(l) for l in self.listeners])) task_status.started(self) await trio.sleep_forever() async def _handle_connection(self, stream): - ''' + """ Handle an incoming connection by spawning a connection background task and a handler task inside a new nursery. :param stream: :type stream: trio.abc.Stream - ''' + """ async with trio.open_nursery() as nursery: - connection = WebSocketConnection(stream, + connection = WebSocketConnection( + stream, WSConnection(ConnectionType.SERVER), message_queue_size=self._message_queue_size, - max_message_size=self._max_message_size) + max_message_size=self._max_message_size, + ) nursery.start_soon(connection._reader_task) with trio.move_on_after(self._connect_timeout) as connect_scope: request = await connection._get_request() diff --git a/trio_websocket/_version.py b/trio_websocket/_version.py index 2320701..5c47800 100644 --- a/trio_websocket/_version.py +++ b/trio_websocket/_version.py @@ -1 +1 @@ -__version__ = '0.12.0-dev' +__version__ = "0.12.0-dev"