Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement web.Runner context manager #8723

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions CHANGES/8723.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement web.Runner context manager -- by :user:`DavidRomanovizc`
240 changes: 195 additions & 45 deletions aiohttp/web.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import asyncio
import contextvars
import enum
import logging
import os
import socket
import sys
import warnings
from argparse import ArgumentParser
from asyncio import constants, events, tasks
from collections.abc import Iterable
from contextlib import suppress
from importlib import import_module
Expand Down Expand Up @@ -264,9 +267,9 @@
"WSMsgType",
# web
"run_app",
"WebRunner",
)


try:
from ssl import SSLContext
except ImportError: # pragma: no cover
Expand All @@ -277,6 +280,130 @@

HostSequence = TypingIterable[str]

if sys.version_info >= (3, 11):

class _State(enum.Enum):
CREATED = "created"
INITIALIZED = "initialized"
CLOSED = "closed"

class WebRunner(asyncio.Runner): # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably get a new test that uses WebRunner directly too. Can probably just copy and modify one of the run_app() tests. The new test could also use WebRunner.run() to verify that the task and the app are run in the same loop.

"""A context manager that controls event loop life cycle"""

def __init__(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still a lot of copied code here. As far as I can tell, only the run_app() method needs to be defined here. init/close/_lazy_init can all be removed and just use the parent methods.

self,
*,
debug: Optional[bool] = None,
loop_factory: Optional[Callable[[], asyncio.AbstractEventLoop]] = None,
):
super().__init__(debug=debug, loop_factory=loop_factory)

def close(self) -> None:
"""Shutdown and close event loop."""
if self._state is not _State.INITIALIZED:
return
loop = self._loop
try:
_cancel_tasks(tasks.all_tasks(loop), loop)
loop.run_until_complete(loop.shutdown_asyncgens())
loop.run_until_complete(
loop.shutdown_default_executor(constants.THREAD_JOIN_TIMEOUT)
)
finally:
if self._set_event_loop:
events.set_event_loop(None)
loop.close()
self._loop = None
self._state = _State.CLOSED

def run_app(
self,
app: Union[Application, Awaitable[Application]],
*,
host: Optional[Union[str, HostSequence]] = None,
port: Optional[int] = None,
path: Union[PathLike, TypingIterable[PathLike], None] = None,
sock: Optional[Union[socket.socket, TypingIterable[socket.socket]]] = None,
shutdown_timeout: float = 60.0,
keepalive_timeout: float = 75.0,
ssl_context: Optional[SSLContext] = None,
print: Optional[Callable[..., None]] = print,
backlog: int = 128,
access_log_class: Type[AbstractAccessLogger] = AccessLogger,
access_log_format: str = AccessLogger.LOG_FORMAT,
access_log: Optional[logging.Logger] = access_logger,
handle_signals: bool = True,
reuse_address: Optional[bool] = None,
reuse_port: Optional[bool] = None,
handler_cancellation: bool = False,
) -> None:
"""Run an app locally"""
self._lazy_init()

if (
self._loop.get_debug()
and access_log
and access_log.name == "aiohttp.access"
):
if access_log.level == logging.NOTSET:
access_log.setLevel(logging.DEBUG)
if not access_log.hasHandlers():
access_log.addHandler(logging.StreamHandler())

main_task = self._loop.create_task(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_run_app(
app,
host=host,
port=port,
path=path,
sock=sock,
shutdown_timeout=shutdown_timeout,
keepalive_timeout=keepalive_timeout,
ssl_context=ssl_context,
print=print,
backlog=backlog,
access_log_class=access_log_class,
access_log_format=access_log_format,
access_log=access_log,
handle_signals=handle_signals,
reuse_address=reuse_address,
reuse_port=reuse_port,
handler_cancellation=handler_cancellation,
)
)

try:
self._loop.run_until_complete(main_task)
except (GracefulExit, KeyboardInterrupt): # pragma: no cover
pass
finally:
_cancel_tasks({main_task}, self._loop)
self.close()

def _lazy_init(self) -> None:
if self._state is _State.CLOSED:
raise RuntimeError("Runner is closed")
if self._state is _State.INITIALIZED:
return
if self._loop_factory is None:
self._loop = events.new_event_loop()
if not self._set_event_loop:
# Call set_event_loop only once to avoid calling
# attach_loop multiple times on child watchers
events.set_event_loop(self._loop)
self._set_event_loop = True
else:
try:
self._loop = self._loop_factory()
except RuntimeError:
self._loop = events.new_event_loop()
events.set_event_loop(self._loop)
self._set_event_loop = True
if self._debug is not None:
self._loop.set_debug(self._debug)
self._context = contextvars.copy_context()
self._state = _State.INITIALIZED


async def _run_app(
app: Union[Application, Awaitable[Application]],
Expand Down Expand Up @@ -463,54 +590,77 @@ def run_app(
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
"""Run an app locally"""
if loop is None:
loop = asyncio.new_event_loop()
loop.set_debug(debug)

# Configure if and only if in debugging mode and using the default logger
if loop.get_debug() and access_log and access_log.name == "aiohttp.access":
if access_log.level == logging.NOTSET:
access_log.setLevel(logging.DEBUG)
if not access_log.hasHandlers():
access_log.addHandler(logging.StreamHandler())

main_task = loop.create_task(
_run_app(
app,
host=host,
port=port,
path=path,
sock=sock,
shutdown_timeout=shutdown_timeout,
keepalive_timeout=keepalive_timeout,
ssl_context=ssl_context,
print=print,
backlog=backlog,
access_log_class=access_log_class,
access_log_format=access_log_format,
access_log=access_log,
handle_signals=handle_signals,
reuse_address=reuse_address,
reuse_port=reuse_port,
handler_cancellation=handler_cancellation,
if sys.version_info >= (3, 11):
loop_factory = None if loop is None else lambda: loop
with WebRunner(debug=debug, loop_factory=loop_factory) as runner:
runner.run_app(
app,
host=host,
port=port,
path=path,
sock=sock,
shutdown_timeout=shutdown_timeout,
keepalive_timeout=keepalive_timeout,
ssl_context=ssl_context,
print=print,
backlog=backlog,
access_log_class=access_log_class,
access_log_format=access_log_format,
access_log=access_log,
handle_signals=handle_signals,
reuse_address=reuse_address,
reuse_port=reuse_port,
handler_cancellation=handler_cancellation,
)
else:
if loop is None:
loop = asyncio.new_event_loop()
loop.set_debug(debug)

# Configure if and only if in debugging mode and using the default logger
if loop.get_debug() and access_log and access_log.name == "aiohttp.access":
if access_log.level == logging.NOTSET:
access_log.setLevel(logging.DEBUG)
if not access_log.hasHandlers():
access_log.addHandler(logging.StreamHandler())

main_task = loop.create_task(
_run_app(
app,
host=host,
port=port,
path=path,
sock=sock,
shutdown_timeout=shutdown_timeout,
keepalive_timeout=keepalive_timeout,
ssl_context=ssl_context,
print=print,
backlog=backlog,
access_log_class=access_log_class,
access_log_format=access_log_format,
access_log=access_log,
handle_signals=handle_signals,
reuse_address=reuse_address,
reuse_port=reuse_port,
handler_cancellation=handler_cancellation,
)
)
)

try:
asyncio.set_event_loop(loop)
loop.run_until_complete(main_task)
except (GracefulExit, KeyboardInterrupt): # pragma: no cover
pass
finally:
try:
main_task.cancel()
with suppress(asyncio.CancelledError):
loop.run_until_complete(main_task)
asyncio.set_event_loop(loop)
loop.run_until_complete(main_task)
except (GracefulExit, KeyboardInterrupt): # pragma: no cover
pass
finally:
_cancel_tasks(asyncio.all_tasks(loop), loop)
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()
asyncio.set_event_loop(None)
try:
main_task.cancel()
with suppress(asyncio.CancelledError):
loop.run_until_complete(main_task)
finally:
_cancel_tasks(asyncio.all_tasks(loop), loop)
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()
asyncio.set_event_loop(None)


def main(argv: List[str]) -> None:
Expand Down
Loading