diff --git a/CHANGES/8723.feature.rst b/CHANGES/8723.feature.rst new file mode 100644 index 00000000000..59fc945e45a --- /dev/null +++ b/CHANGES/8723.feature.rst @@ -0,0 +1 @@ +Implement web.Runner context manager -- by :user:`DavidRomanovizc` diff --git a/aiohttp/web.py b/aiohttp/web.py index 39b9b6bfde5..6a996d1135d 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -1,10 +1,13 @@ import asyncio +import contextvars +import enum import logging import os import socket import sys import warnings from argparse import ArgumentParser +from asyncio import constants, events, tasks from collections.abc import Iterable from contextlib import suppress from importlib import import_module @@ -264,9 +267,9 @@ "WSMsgType", # web "run_app", + "WebRunner", ) - try: from ssl import SSLContext except ImportError: # pragma: no cover @@ -277,6 +280,130 @@ HostSequence = TypingIterable[str] +if sys.version_info >= (3, 11): + + class _State(enum.Enum): + CREATED = "created" + INITIALIZED = "initialized" + CLOSED = "closed" + + class WebRunner(asyncio.Runner): # type: ignore + """A context manager that controls event loop life cycle""" + + def __init__( + self, + *, + debug: Optional[bool] = None, + loop_factory: Optional[Callable[[], asyncio.AbstractEventLoop]] = None, + ): + super().__init__(debug=debug, loop_factory=loop_factory) + + def close(self) -> None: + """Shutdown and close event loop.""" + if self._state is not _State.INITIALIZED: + return + loop = self._loop + try: + _cancel_tasks(tasks.all_tasks(loop), loop) + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.run_until_complete( + loop.shutdown_default_executor(constants.THREAD_JOIN_TIMEOUT) + ) + finally: + if self._set_event_loop: + events.set_event_loop(None) + loop.close() + self._loop = None + self._state = _State.CLOSED + + def run_app( + self, + app: Union[Application, Awaitable[Application]], + *, + host: Optional[Union[str, HostSequence]] = None, + port: Optional[int] = None, + path: Union[PathLike, TypingIterable[PathLike], None] = None, + sock: Optional[Union[socket.socket, TypingIterable[socket.socket]]] = None, + shutdown_timeout: float = 60.0, + keepalive_timeout: float = 75.0, + ssl_context: Optional[SSLContext] = None, + print: Optional[Callable[..., None]] = print, + backlog: int = 128, + access_log_class: Type[AbstractAccessLogger] = AccessLogger, + access_log_format: str = AccessLogger.LOG_FORMAT, + access_log: Optional[logging.Logger] = access_logger, + handle_signals: bool = True, + reuse_address: Optional[bool] = None, + reuse_port: Optional[bool] = None, + handler_cancellation: bool = False, + ) -> None: + """Run an app locally""" + self._lazy_init() + + if ( + self._loop.get_debug() + and access_log + and access_log.name == "aiohttp.access" + ): + if access_log.level == logging.NOTSET: + access_log.setLevel(logging.DEBUG) + if not access_log.hasHandlers(): + access_log.addHandler(logging.StreamHandler()) + + main_task = self._loop.create_task( + _run_app( + app, + host=host, + port=port, + path=path, + sock=sock, + shutdown_timeout=shutdown_timeout, + keepalive_timeout=keepalive_timeout, + ssl_context=ssl_context, + print=print, + backlog=backlog, + access_log_class=access_log_class, + access_log_format=access_log_format, + access_log=access_log, + handle_signals=handle_signals, + reuse_address=reuse_address, + reuse_port=reuse_port, + handler_cancellation=handler_cancellation, + ) + ) + + try: + self._loop.run_until_complete(main_task) + except (GracefulExit, KeyboardInterrupt): # pragma: no cover + pass + finally: + _cancel_tasks({main_task}, self._loop) + self.close() + + def _lazy_init(self) -> None: + if self._state is _State.CLOSED: + raise RuntimeError("Runner is closed") + if self._state is _State.INITIALIZED: + return + if self._loop_factory is None: + self._loop = events.new_event_loop() + if not self._set_event_loop: + # Call set_event_loop only once to avoid calling + # attach_loop multiple times on child watchers + events.set_event_loop(self._loop) + self._set_event_loop = True + else: + try: + self._loop = self._loop_factory() + except RuntimeError: + self._loop = events.new_event_loop() + events.set_event_loop(self._loop) + self._set_event_loop = True + if self._debug is not None: + self._loop.set_debug(self._debug) + self._context = contextvars.copy_context() + self._state = _State.INITIALIZED + async def _run_app( app: Union[Application, Awaitable[Application]], @@ -463,54 +590,77 @@ def run_app( loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: """Run an app locally""" - if loop is None: - loop = asyncio.new_event_loop() - loop.set_debug(debug) - - # Configure if and only if in debugging mode and using the default logger - if loop.get_debug() and access_log and access_log.name == "aiohttp.access": - if access_log.level == logging.NOTSET: - access_log.setLevel(logging.DEBUG) - if not access_log.hasHandlers(): - access_log.addHandler(logging.StreamHandler()) - - main_task = loop.create_task( - _run_app( - app, - host=host, - port=port, - path=path, - sock=sock, - shutdown_timeout=shutdown_timeout, - keepalive_timeout=keepalive_timeout, - ssl_context=ssl_context, - print=print, - backlog=backlog, - access_log_class=access_log_class, - access_log_format=access_log_format, - access_log=access_log, - handle_signals=handle_signals, - reuse_address=reuse_address, - reuse_port=reuse_port, - handler_cancellation=handler_cancellation, + if sys.version_info >= (3, 11): + loop_factory = None if loop is None else lambda: loop + with WebRunner(debug=debug, loop_factory=loop_factory) as runner: + runner.run_app( + app, + host=host, + port=port, + path=path, + sock=sock, + shutdown_timeout=shutdown_timeout, + keepalive_timeout=keepalive_timeout, + ssl_context=ssl_context, + print=print, + backlog=backlog, + access_log_class=access_log_class, + access_log_format=access_log_format, + access_log=access_log, + handle_signals=handle_signals, + reuse_address=reuse_address, + reuse_port=reuse_port, + handler_cancellation=handler_cancellation, + ) + else: + if loop is None: + loop = asyncio.new_event_loop() + loop.set_debug(debug) + + # Configure if and only if in debugging mode and using the default logger + if loop.get_debug() and access_log and access_log.name == "aiohttp.access": + if access_log.level == logging.NOTSET: + access_log.setLevel(logging.DEBUG) + if not access_log.hasHandlers(): + access_log.addHandler(logging.StreamHandler()) + + main_task = loop.create_task( + _run_app( + app, + host=host, + port=port, + path=path, + sock=sock, + shutdown_timeout=shutdown_timeout, + keepalive_timeout=keepalive_timeout, + ssl_context=ssl_context, + print=print, + backlog=backlog, + access_log_class=access_log_class, + access_log_format=access_log_format, + access_log=access_log, + handle_signals=handle_signals, + reuse_address=reuse_address, + reuse_port=reuse_port, + handler_cancellation=handler_cancellation, + ) ) - ) - try: - asyncio.set_event_loop(loop) - loop.run_until_complete(main_task) - except (GracefulExit, KeyboardInterrupt): # pragma: no cover - pass - finally: try: - main_task.cancel() - with suppress(asyncio.CancelledError): - loop.run_until_complete(main_task) + asyncio.set_event_loop(loop) + loop.run_until_complete(main_task) + except (GracefulExit, KeyboardInterrupt): # pragma: no cover + pass finally: - _cancel_tasks(asyncio.all_tasks(loop), loop) - loop.run_until_complete(loop.shutdown_asyncgens()) - loop.close() - asyncio.set_event_loop(None) + try: + main_task.cancel() + with suppress(asyncio.CancelledError): + loop.run_until_complete(main_task) + finally: + _cancel_tasks(asyncio.all_tasks(loop), loop) + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.close() + asyncio.set_event_loop(None) def main(argv: List[str]) -> None: