diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 2ac6f7f7f..f51b13f73 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -206,6 +206,7 @@ def __init__( self.status_code = status_code self.media_type = media_type self.init_headers(headers) + self.background = None async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if self.info is not None: @@ -222,3 +223,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await send({"type": "http.response.body", "body": chunk, "more_body": True}) await send({"type": "http.response.body", "body": b"", "more_body": False}) + + if self.background: + await self.background() diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 225038650..15080e5c5 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1006,16 +1006,29 @@ async def endpoint(request: Request) -> Response: @pytest.mark.anyio async def test_multiple_middlewares_stacked_client_disconnected() -> None: + """ + Tests for: + - https://github.com/encode/starlette/issues/2516 + - https://github.com/encode/starlette/pull/2687 + """ + ordered_events: list[str] = [] + unordered_events: list[str] = [] + class MyMiddleware(BaseHTTPMiddleware): - def __init__(self, app: ASGIApp, version: int, events: list[str]) -> None: + def __init__(self, app: ASGIApp, version: int) -> None: self.version = version - self.events = events super().__init__(app) async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: - self.events.append(f"{self.version}:STARTED") + ordered_events.append(f"{self.version}:STARTED") res = await call_next(request) - self.events.append(f"{self.version}:COMPLETED") + ordered_events.append(f"{self.version}:COMPLETED") + + def background() -> None: + unordered_events.append(f"{self.version}:BACKGROUND") + + assert res.background is None + res.background = BackgroundTask(background) return res async def sleepy(request: Request) -> Response: @@ -1027,11 +1040,9 @@ async def sleepy(request: Request) -> Response: raise AssertionError("Should have raised ClientDisconnect") return Response(b"") - events: list[str] = [] - app = Starlette( routes=[Route("/", sleepy)], - middleware=[Middleware(MyMiddleware, version=_ + 1, events=events) for _ in range(10)], + middleware=[Middleware(MyMiddleware, version=_ + 1) for _ in range(10)], ) scope = { @@ -1051,7 +1062,7 @@ async def send(message: Message) -> None: await app(scope, receive().__anext__, send) - assert events == [ + assert ordered_events == [ "1:STARTED", "2:STARTED", "3:STARTED", @@ -1074,6 +1085,21 @@ async def send(message: Message) -> None: "1:COMPLETED", ] + assert sorted(unordered_events) == sorted( + [ + "1:BACKGROUND", + "2:BACKGROUND", + "3:BACKGROUND", + "4:BACKGROUND", + "5:BACKGROUND", + "6:BACKGROUND", + "7:BACKGROUND", + "8:BACKGROUND", + "9:BACKGROUND", + "10:BACKGROUND", + ] + ) + assert sent == [ { "type": "http.response.start",