From 17ae17ce0f738f906378acd62b4bc46e5b8f3d98 Mon Sep 17 00:00:00 2001 From: Giovanni Barillari Date: Thu, 22 Aug 2024 16:17:08 +0200 Subject: [PATCH 1/4] Add support for ASGI `pathsend` extension --- docs/middleware.md | 1 + starlette/middleware/gzip.py | 5 +++- starlette/responses.py | 2 ++ tests/middleware/test_gzip.py | 53 ++++++++++++++++++++++++++++++++++- tests/test_responses.py | 32 +++++++++++++++++++++ 5 files changed, 91 insertions(+), 2 deletions(-) diff --git a/docs/middleware.md b/docs/middleware.md index 9e5601819..23f0eeb84 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -264,6 +264,7 @@ around explicitly, rather than mutating the middleware instance. Currently, the `BaseHTTPMiddleware` has some known limitations: - Using `BaseHTTPMiddleware` will prevent changes to [`contextlib.ContextVar`](https://docs.python.org/3/library/contextvars.html#contextvars.ContextVar)s from propagating upwards. That is, if you set a value for a `ContextVar` in your endpoint and try to read it from a middleware you will find that the value is not the same value you set in your endpoint (see [this test](https://github.com/encode/starlette/blob/621abc747a6604825190b93467918a0ec6456a24/tests/middleware/test_base.py#L192-L223) for an example of this behavior). +- Using `BaseHTTPMiddleware` will prevent [ASGI pathsend extension](https://asgi.readthedocs.io/en/latest/extensions.html#path-send) to work properly. Thus, if you run your Starlette application with a server implementing this extension, routes returning [FileResponse](responses.md#fileresponse) should avoid the usage of this middleware. To overcome these limitations, use [pure ASGI middleware](#pure-asgi-middleware), as shown below. 
diff --git a/starlette/middleware/gzip.py b/starlette/middleware/gzip.py index 127b91e7a..97e12a554 100644 --- a/starlette/middleware/gzip.py +++ b/starlette/middleware/gzip.py @@ -87,7 +87,6 @@ async def send_with_gzip(self, message: Message) -> None: await self.send(self.initial_message) await self.send(message) - elif message_type == "http.response.body": # Remaining body in streaming GZip response. body = message.get("body", b"") @@ -102,6 +101,10 @@ async def send_with_gzip(self, message: Message) -> None: self.gzip_buffer.truncate() await self.send(message) + elif message_type == "http.response.pathsend": + # Don't apply GZip to pathsend responses + await self.send(self.initial_message) + await self.send(message) async def unattached_send(message: Message) -> typing.NoReturn: diff --git a/starlette/responses.py b/starlette/responses.py index 790aa7ebc..8f5574512 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -368,6 +368,8 @@ async def _handle_simple(self, send: Send, send_header_only: bool) -> None: await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers}) if send_header_only: await send({"type": "http.response.body", "body": b"", "more_body": False}) + elif "http.response.pathsend" in scope.get("extensions", {}): + await send({"type": "http.response.pathsend", "path": str(self.path)}) else: async with await anyio.open_file(self.path, mode="rb") as file: more_body = True diff --git a/tests/middleware/test_gzip.py b/tests/middleware/test_gzip.py index b20a7cb84..4dd79af28 100644 --- a/tests/middleware/test_gzip.py +++ b/tests/middleware/test_gzip.py @@ -1,9 +1,21 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.gzip import GZipMiddleware from starlette.requests import Request -from starlette.responses import ContentStream, PlainTextResponse, 
StreamingResponse +from starlette.responses import ( + ContentStream, + FileResponse, + PlainTextResponse, + StreamingResponse, +) from starlette.routing import Route +from starlette.types import Message from tests.types import TestClientFactory @@ -104,3 +116,42 @@ async def generator(bytes: bytes, count: int) -> ContentStream: assert response.text == "x" * 4000 assert response.headers["Content-Encoding"] == "text" assert "Content-Length" not in response.headers + + +@pytest.mark.anyio +async def test_gzip_ignored_for_pathsend_responses(tmpdir: Path) -> None: + path = tmpdir / "example.txt" + with path.open("w") as file: + file.write("") + + events: list[Message] = [] + + async def endpoint_with_pathsend(request: Request) -> FileResponse: + _ = await request.body() + return FileResponse(path) + + app = Starlette( + routes=[Route("/", endpoint=endpoint_with_pathsend)], + middleware=[Middleware(GZipMiddleware)], + ) + + scope = { + "type": "http", + "version": "3", + "method": "GET", + "path": "/", + "headers": [(b"accept-encoding", b"gzip, text")], + "extensions": {"http.response.pathsend": {}}, + } + + async def receive() -> Message: + return {"type": "http.request", "body": b"", "more_body": False} + + async def send(message: Message) -> None: + events.append(message) + + await app(scope, receive, send) + + assert len(events) == 2 + assert events[0]["type"] == "http.response.start" + assert events[1]["type"] == "http.response.pathsend" diff --git a/tests/test_responses.py b/tests/test_responses.py index be0701a5c..c48cb9410 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -336,6 +336,38 @@ def test_file_response_with_method_warns(tmp_path: Path) -> None: FileResponse(path=tmp_path, filename="example.png", method="GET") +@pytest.mark.anyio +async def test_file_response_with_pathsend(tmpdir: Path) -> None: + path = tmpdir / "xyz" + content = b"" * 1000 + with open(path, "wb") as file: + file.write(content) + + app = FileResponse(path=path, 
filename="example.png") + + async def receive() -> Message: # type: ignore[empty-body] + ... # pragma: no cover + + async def send(message: Message) -> None: + if message["type"] == "http.response.start": + assert message["status"] == status.HTTP_200_OK + headers = Headers(raw=message["headers"]) + assert headers["content-type"] == "image/png" + assert "content-length" in headers + assert "content-disposition" in headers + assert "last-modified" in headers + assert "etag" in headers + elif message["type"] == "http.response.pathsend": + assert message["path"] == str(path) + + # Since the TestClient doesn't support `pathsend`, we need to test this directly. + await app( + {"type": "http", "method": "get", "extensions": {"http.response.pathsend": {}}}, + receive, + send, + ) + + def test_set_cookie(test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch) -> None: # Mock time used as a reference for `Expires` by stdlib `SimpleCookie`. mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc) From 1a27cf5bcd3f9946e3f5bbd14cf1368e0ae8df33 Mon Sep 17 00:00:00 2001 From: Giovanni Barillari Date: Mon, 2 Sep 2024 18:26:54 +0200 Subject: [PATCH 2/4] Add support for ASGI `pathsend` extension in `BaseHTTPMiddleware` --- docs/middleware.md | 1 - starlette/middleware/base.py | 18 ++++++++++-- tests/middleware/test_base.py | 52 ++++++++++++++++++++++++++++++++++- 3 files changed, 66 insertions(+), 5 deletions(-) diff --git a/docs/middleware.md b/docs/middleware.md index 23f0eeb84..9e5601819 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -264,7 +264,6 @@ around explicitly, rather than mutating the middleware instance. Currently, the `BaseHTTPMiddleware` has some known limitations: - Using `BaseHTTPMiddleware` will prevent changes to [`contextlib.ContextVar`](https://docs.python.org/3/library/contextvars.html#contextvars.ContextVar)s from propagating upwards. 
That is, if you set a value for a `ContextVar` in your endpoint and try to read it from a middleware you will find that the value is not the same value you set in your endpoint (see [this test](https://github.com/encode/starlette/blob/621abc747a6604825190b93467918a0ec6456a24/tests/middleware/test_base.py#L192-L223) for an example of this behavior). -- Using `BaseHTTPMiddleware` will prevent [ASGI pathsend extension](https://asgi.readthedocs.io/en/latest/extensions.html#path-send) to work properly. Thus, if you run your Starlette application with a server implementing this extension, routes returning [FileResponse](responses.md#fileresponse) should avoid the usage of this middleware. To overcome these limitations, use [pure ASGI middleware](#pure-asgi-middleware), as shown below. diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index f51b13f73..22c720217 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -7,11 +7,13 @@ from starlette._utils import collapse_excgroups from starlette.requests import ClientDisconnect, Request -from starlette.responses import AsyncContentStream, Response +from starlette.responses import Response from starlette.types import ASGIApp, Message, Receive, Scope, Send RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]] +BodyStreamGenerator = typing.AsyncGenerator[typing.Union[bytes, typing.MutableMapping[str, typing.Any]], None] +AsyncContentStream = typing.AsyncIterable[typing.Union[str, bytes, memoryview, typing.MutableMapping[str, typing.Any]]] T = typing.TypeVar("T") @@ -165,9 +167,12 @@ async def coro() -> None: assert message["type"] == "http.response.start" - async def body_stream() -> typing.AsyncGenerator[bytes, None]: + async def body_stream() -> BodyStreamGenerator: async with recv_stream: async for message in recv_stream: + if message["type"] == 
"http.response.pathsend": + yield message + break assert message["type"] == "http.response.body" body = message.get("body", b"") if body: @@ -219,10 +224,17 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: } ) + should_close_body = True async for chunk in self.body_iterator: + if isinstance(chunk, dict): + # We got an ASGI message which is not response body (eg: pathsend) + should_close_body = False + await send(chunk) + continue await send({"type": "http.response.body", "body": chunk, "more_body": True}) - await send({"type": "http.response.body", "body": b"", "more_body": False}) + if should_close_body: + await send({"type": "http.response.body", "body": b"", "more_body": False}) if self.background: await self.background() diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 041cc7ce2..714a5ef05 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -2,6 +2,7 @@ import contextvars from contextlib import AsyncExitStack +from pathlib import Path from typing import Any, AsyncGenerator, AsyncIterator, Generator import anyio @@ -13,7 +14,7 @@ from starlette.middleware import Middleware, _MiddlewareClass from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import ClientDisconnect, Request -from starlette.responses import PlainTextResponse, Response, StreamingResponse +from starlette.responses import FileResponse, PlainTextResponse, Response, StreamingResponse from starlette.routing import Route, WebSocketRoute from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send @@ -1153,3 +1154,52 @@ async def send(message: Message) -> None: {"type": "http.response.body", "body": b"good!", "more_body": True}, {"type": "http.response.body", "body": b"", "more_body": False}, ] + + +@pytest.mark.anyio +async def test_asgi_pathsend_events(tmpdir: Path) -> None: + path = tmpdir / "example.txt" + 
with path.open("w") as file: + file.write("") + + request_body_sent = False + response_complete = anyio.Event() + events: list[Message] = [] + + async def endpoint_with_pathsend(_: Request) -> FileResponse: + return FileResponse(path) + + async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> Response: + return await call_next(request) + + app = Starlette( + middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)], + routes=[Route("/", endpoint_with_pathsend)], + ) + + scope = { + "type": "http", + "version": "3", + "method": "GET", + "path": "/", + "extensions": {"http.response.pathsend": {}}, + } + + async def receive() -> Message: + nonlocal request_body_sent + if not request_body_sent: + request_body_sent = True + return {"type": "http.request", "body": b"", "more_body": False} + await response_complete.wait() + return {"type": "http.disconnect"} + + async def send(message: Message) -> None: + events.append(message) + if message["type"] == "http.response.pathsend": + response_complete.set() + + await app(scope, receive, send) + + assert len(events) == 2 + assert events[0]["type"] == "http.response.start" + assert events[1]["type"] == "http.response.pathsend" From c781d67a3ebe2909896f98a1c146d2f83f69b978 Mon Sep 17 00:00:00 2001 From: Giovanni Barillari Date: Mon, 2 Sep 2024 18:41:16 +0200 Subject: [PATCH 3/4] Make coverage happy again --- tests/middleware/test_base.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 714a5ef05..c3091064a 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1162,7 +1162,6 @@ async def test_asgi_pathsend_events(tmpdir: Path) -> None: with path.open("w") as file: file.write("") - request_body_sent = False response_complete = anyio.Event() events: list[Message] = [] @@ -1186,12 +1185,7 @@ async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> R } async def 
receive() -> Message: - nonlocal request_body_sent - if not request_body_sent: - request_body_sent = True - return {"type": "http.request", "body": b"", "more_body": False} - await response_complete.wait() - return {"type": "http.disconnect"} + raise NotImplementedError("Should not be called!") # pragma: no cover async def send(message: Message) -> None: events.append(message) From 6b658e8fb24e54fd004c4f42813a932d8eb35abd Mon Sep 17 00:00:00 2001 From: Giovanni Barillari Date: Thu, 3 Oct 2024 11:57:12 +0200 Subject: [PATCH 4/4] Rebase changes --- starlette/responses.py | 4 ++-- tests/middleware/test_base.py | 1 + tests/test_responses.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/starlette/responses.py b/starlette/responses.py index 8f5574512..cc1ca3e28 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -345,7 +345,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: http_if_range = headers.get("if-range") if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range, stat_result)): - await self._handle_simple(send, send_header_only) + await self._handle_simple(scope, send, send_header_only) else: try: ranges = self._parse_range_header(http_range, stat_result.st_size) @@ -364,7 +364,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if self.background is not None: await self.background() - async def _handle_simple(self, send: Send, send_header_only: bool) -> None: + async def _handle_simple(self, scope: Scope, send: Send, send_header_only: bool) -> None: await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers}) if send_header_only: await send({"type": "http.response.body", "body": b"", "more_body": False}) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index c3091064a..40ae857a2 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ 
-1181,6 +1181,7 @@ async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> R "version": "3", "method": "GET", "path": "/", + "headers": [], "extensions": {"http.response.pathsend": {}}, } diff --git a/tests/test_responses.py b/tests/test_responses.py index c48cb9410..300a8d90f 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -362,7 +362,7 @@ async def send(message: Message) -> None: # Since the TestClient doesn't support `pathsend`, we need to test this directly. await app( - {"type": "http", "method": "get", "extensions": {"http.response.pathsend": {}}}, + {"type": "http", "method": "get", "headers": [], "extensions": {"http.response.pathsend": {}}}, receive, send, )