Skip to content

Commit

Permalink
Fix BackgroundTasks with BaseHTTPMiddleware (encode#2688)
Browse files Browse the repository at this point in the history
* Streaming response early disconnect mode

* Fix BackgroundTasks with BaseHTTPMiddleware

* move comment

* initialize field

---------

Co-authored-by: Dmitry Maliuga <[email protected]>
  • Loading branch information
2 people authored and Rocky Allen committed Sep 30, 2024
1 parent b1b5de4 commit fd8fb64
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 8 deletions.
4 changes: 4 additions & 0 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def __init__(
self.status_code = status_code
self.media_type = media_type
self.init_headers(headers)
self.background = None

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.info is not None:
Expand All @@ -222,3 +223,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await send({"type": "http.response.body", "body": chunk, "more_body": True})

await send({"type": "http.response.body", "body": b"", "more_body": False})

if self.background:
await self.background()
42 changes: 34 additions & 8 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,16 +1006,29 @@ async def endpoint(request: Request) -> Response:

@pytest.mark.anyio
async def test_multiple_middlewares_stacked_client_disconnected() -> None:
"""
Tests for:
- https://github.com/encode/starlette/issues/2516
- https://github.com/encode/starlette/pull/2687
"""
ordered_events: list[str] = []
unordered_events: list[str] = []

class MyMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp, version: int, events: list[str]) -> None:
def __init__(self, app: ASGIApp, version: int) -> None:
self.version = version
self.events = events
super().__init__(app)

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
self.events.append(f"{self.version}:STARTED")
ordered_events.append(f"{self.version}:STARTED")
res = await call_next(request)
self.events.append(f"{self.version}:COMPLETED")
ordered_events.append(f"{self.version}:COMPLETED")

def background() -> None:
unordered_events.append(f"{self.version}:BACKGROUND")

assert res.background is None
res.background = BackgroundTask(background)
return res

async def sleepy(request: Request) -> Response:
Expand All @@ -1027,11 +1040,9 @@ async def sleepy(request: Request) -> Response:
raise AssertionError("Should have raised ClientDisconnect")
return Response(b"")

events: list[str] = []

app = Starlette(
routes=[Route("/", sleepy)],
middleware=[Middleware(MyMiddleware, version=_ + 1, events=events) for _ in range(10)],
middleware=[Middleware(MyMiddleware, version=_ + 1) for _ in range(10)],
)

scope = {
Expand All @@ -1051,7 +1062,7 @@ async def send(message: Message) -> None:

await app(scope, receive().__anext__, send)

assert events == [
assert ordered_events == [
"1:STARTED",
"2:STARTED",
"3:STARTED",
Expand All @@ -1074,6 +1085,21 @@ async def send(message: Message) -> None:
"1:COMPLETED",
]

assert sorted(unordered_events) == sorted(
[
"1:BACKGROUND",
"2:BACKGROUND",
"3:BACKGROUND",
"4:BACKGROUND",
"5:BACKGROUND",
"6:BACKGROUND",
"7:BACKGROUND",
"8:BACKGROUND",
"9:BACKGROUND",
"10:BACKGROUND",
]
)

assert sent == [
{
"type": "http.response.start",
Expand Down

0 comments on commit fd8fb64

Please sign in to comment.