-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement web.Runner
context manager
#8723
base: master
Are you sure you want to change the base?
Changes from all commits
57afbfc
283570f
77ec5eb
50c1a6b
3df86f3
a917375
7db4f09
1684a34
722631c
1f733e6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Implement web.Runner context manager -- by :user:`DavidRomanovizc` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,13 @@ | ||
import asyncio | ||
import contextvars | ||
import enum | ||
import logging | ||
import os | ||
import socket | ||
import sys | ||
import warnings | ||
from argparse import ArgumentParser | ||
from asyncio import constants, events, tasks | ||
from collections.abc import Iterable | ||
from contextlib import suppress | ||
from importlib import import_module | ||
|
@@ -264,9 +267,9 @@ | |
"WSMsgType", | ||
# web | ||
"run_app", | ||
"WebRunner", | ||
) | ||
|
||
|
||
try: | ||
from ssl import SSLContext | ||
except ImportError: # pragma: no cover | ||
|
@@ -277,6 +280,130 @@ | |
|
||
HostSequence = TypingIterable[str] | ||
|
||
if sys.version_info >= (3, 11): | ||
|
||
class _State(enum.Enum): | ||
CREATED = "created" | ||
INITIALIZED = "initialized" | ||
CLOSED = "closed" | ||
|
||
class WebRunner(asyncio.Runner): # type: ignore | ||
"""A context manager that controls event loop life cycle""" | ||
|
||
def __init__( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Still a lot of copied code here. As far as I can tell, only the run_app() method needs to be defined here. init/close/_lazy_init can all be removed and just use the parent methods. |
||
self, | ||
*, | ||
debug: Optional[bool] = None, | ||
loop_factory: Optional[Callable[[], asyncio.AbstractEventLoop]] = None, | ||
): | ||
super().__init__(debug=debug, loop_factory=loop_factory) | ||
|
||
def close(self) -> None: | ||
"""Shutdown and close event loop.""" | ||
if self._state is not _State.INITIALIZED: | ||
return | ||
loop = self._loop | ||
try: | ||
_cancel_tasks(tasks.all_tasks(loop), loop) | ||
loop.run_until_complete(loop.shutdown_asyncgens()) | ||
loop.run_until_complete( | ||
loop.shutdown_default_executor(constants.THREAD_JOIN_TIMEOUT) | ||
) | ||
finally: | ||
if self._set_event_loop: | ||
events.set_event_loop(None) | ||
loop.close() | ||
self._loop = None | ||
self._state = _State.CLOSED | ||
|
||
def run_app( | ||
self, | ||
app: Union[Application, Awaitable[Application]], | ||
*, | ||
host: Optional[Union[str, HostSequence]] = None, | ||
port: Optional[int] = None, | ||
path: Union[PathLike, TypingIterable[PathLike], None] = None, | ||
sock: Optional[Union[socket.socket, TypingIterable[socket.socket]]] = None, | ||
shutdown_timeout: float = 60.0, | ||
keepalive_timeout: float = 75.0, | ||
ssl_context: Optional[SSLContext] = None, | ||
print: Optional[Callable[..., None]] = print, | ||
backlog: int = 128, | ||
access_log_class: Type[AbstractAccessLogger] = AccessLogger, | ||
access_log_format: str = AccessLogger.LOG_FORMAT, | ||
access_log: Optional[logging.Logger] = access_logger, | ||
handle_signals: bool = True, | ||
reuse_address: Optional[bool] = None, | ||
reuse_port: Optional[bool] = None, | ||
handler_cancellation: bool = False, | ||
) -> None: | ||
"""Run an app locally""" | ||
self._lazy_init() | ||
|
||
if ( | ||
self._loop.get_debug() | ||
and access_log | ||
and access_log.name == "aiohttp.access" | ||
): | ||
if access_log.level == logging.NOTSET: | ||
access_log.setLevel(logging.DEBUG) | ||
if not access_log.hasHandlers(): | ||
access_log.addHandler(logging.StreamHandler()) | ||
|
||
main_task = self._loop.create_task( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should probably add the context parameter for compatibility: |
||
_run_app( | ||
app, | ||
host=host, | ||
port=port, | ||
path=path, | ||
sock=sock, | ||
shutdown_timeout=shutdown_timeout, | ||
keepalive_timeout=keepalive_timeout, | ||
ssl_context=ssl_context, | ||
print=print, | ||
backlog=backlog, | ||
access_log_class=access_log_class, | ||
access_log_format=access_log_format, | ||
access_log=access_log, | ||
handle_signals=handle_signals, | ||
reuse_address=reuse_address, | ||
reuse_port=reuse_port, | ||
handler_cancellation=handler_cancellation, | ||
) | ||
) | ||
|
||
try: | ||
self._loop.run_until_complete(main_task) | ||
except (GracefulExit, KeyboardInterrupt): # pragma: no cover | ||
pass | ||
finally: | ||
_cancel_tasks({main_task}, self._loop) | ||
self.close() | ||
|
||
def _lazy_init(self) -> None: | ||
if self._state is _State.CLOSED: | ||
raise RuntimeError("Runner is closed") | ||
if self._state is _State.INITIALIZED: | ||
return | ||
if self._loop_factory is None: | ||
self._loop = events.new_event_loop() | ||
if not self._set_event_loop: | ||
# Call set_event_loop only once to avoid calling | ||
# attach_loop multiple times on child watchers | ||
events.set_event_loop(self._loop) | ||
self._set_event_loop = True | ||
else: | ||
try: | ||
self._loop = self._loop_factory() | ||
except RuntimeError: | ||
self._loop = events.new_event_loop() | ||
events.set_event_loop(self._loop) | ||
self._set_event_loop = True | ||
if self._debug is not None: | ||
self._loop.set_debug(self._debug) | ||
self._context = contextvars.copy_context() | ||
self._state = _State.INITIALIZED | ||
|
||
|
||
async def _run_app( | ||
app: Union[Application, Awaitable[Application]], | ||
|
@@ -463,54 +590,77 @@ def run_app( | |
loop: Optional[asyncio.AbstractEventLoop] = None, | ||
) -> None: | ||
"""Run an app locally""" | ||
if loop is None: | ||
loop = asyncio.new_event_loop() | ||
loop.set_debug(debug) | ||
|
||
# Configure if and only if in debugging mode and using the default logger | ||
if loop.get_debug() and access_log and access_log.name == "aiohttp.access": | ||
if access_log.level == logging.NOTSET: | ||
access_log.setLevel(logging.DEBUG) | ||
if not access_log.hasHandlers(): | ||
access_log.addHandler(logging.StreamHandler()) | ||
|
||
main_task = loop.create_task( | ||
_run_app( | ||
app, | ||
host=host, | ||
port=port, | ||
path=path, | ||
sock=sock, | ||
shutdown_timeout=shutdown_timeout, | ||
keepalive_timeout=keepalive_timeout, | ||
ssl_context=ssl_context, | ||
print=print, | ||
backlog=backlog, | ||
access_log_class=access_log_class, | ||
access_log_format=access_log_format, | ||
access_log=access_log, | ||
handle_signals=handle_signals, | ||
reuse_address=reuse_address, | ||
reuse_port=reuse_port, | ||
handler_cancellation=handler_cancellation, | ||
if sys.version_info >= (3, 11): | ||
loop_factory = None if loop is None else lambda: loop | ||
with WebRunner(debug=debug, loop_factory=loop_factory) as runner: | ||
runner.run_app( | ||
app, | ||
host=host, | ||
port=port, | ||
path=path, | ||
sock=sock, | ||
shutdown_timeout=shutdown_timeout, | ||
keepalive_timeout=keepalive_timeout, | ||
ssl_context=ssl_context, | ||
print=print, | ||
backlog=backlog, | ||
access_log_class=access_log_class, | ||
access_log_format=access_log_format, | ||
access_log=access_log, | ||
handle_signals=handle_signals, | ||
reuse_address=reuse_address, | ||
reuse_port=reuse_port, | ||
handler_cancellation=handler_cancellation, | ||
) | ||
else: | ||
if loop is None: | ||
loop = asyncio.new_event_loop() | ||
loop.set_debug(debug) | ||
|
||
# Configure if and only if in debugging mode and using the default logger | ||
if loop.get_debug() and access_log and access_log.name == "aiohttp.access": | ||
if access_log.level == logging.NOTSET: | ||
access_log.setLevel(logging.DEBUG) | ||
if not access_log.hasHandlers(): | ||
access_log.addHandler(logging.StreamHandler()) | ||
|
||
main_task = loop.create_task( | ||
_run_app( | ||
app, | ||
host=host, | ||
port=port, | ||
path=path, | ||
sock=sock, | ||
shutdown_timeout=shutdown_timeout, | ||
keepalive_timeout=keepalive_timeout, | ||
ssl_context=ssl_context, | ||
print=print, | ||
backlog=backlog, | ||
access_log_class=access_log_class, | ||
access_log_format=access_log_format, | ||
access_log=access_log, | ||
handle_signals=handle_signals, | ||
reuse_address=reuse_address, | ||
reuse_port=reuse_port, | ||
handler_cancellation=handler_cancellation, | ||
) | ||
) | ||
) | ||
|
||
try: | ||
asyncio.set_event_loop(loop) | ||
loop.run_until_complete(main_task) | ||
except (GracefulExit, KeyboardInterrupt): # pragma: no cover | ||
pass | ||
finally: | ||
try: | ||
main_task.cancel() | ||
with suppress(asyncio.CancelledError): | ||
loop.run_until_complete(main_task) | ||
asyncio.set_event_loop(loop) | ||
loop.run_until_complete(main_task) | ||
except (GracefulExit, KeyboardInterrupt): # pragma: no cover | ||
pass | ||
finally: | ||
_cancel_tasks(asyncio.all_tasks(loop), loop) | ||
loop.run_until_complete(loop.shutdown_asyncgens()) | ||
loop.close() | ||
asyncio.set_event_loop(None) | ||
try: | ||
main_task.cancel() | ||
with suppress(asyncio.CancelledError): | ||
loop.run_until_complete(main_task) | ||
finally: | ||
_cancel_tasks(asyncio.all_tasks(loop), loop) | ||
loop.run_until_complete(loop.shutdown_asyncgens()) | ||
loop.close() | ||
asyncio.set_event_loop(None) | ||
|
||
|
||
def main(argv: List[str]) -> None: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should probably get a new test that uses WebRunner directly too. Can probably just copy and modify one of the run_app() tests. The new test could also use WebRunner.run() to verify that the task and the app are run in the same loop.