Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for ASGI pathsend extension #2671

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@

from starlette._utils import collapse_excgroups
from starlette.requests import ClientDisconnect, Request
from starlette.responses import AsyncContentStream, Response
from starlette.responses import Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]]
BodyStreamGenerator = typing.AsyncGenerator[typing.Union[bytes, typing.MutableMapping[str, typing.Any]], None]
AsyncContentStream = typing.AsyncIterable[typing.Union[str, bytes, memoryview, typing.MutableMapping[str, typing.Any]]]
T = typing.TypeVar("T")


Expand Down Expand Up @@ -165,9 +167,12 @@ async def coro() -> None:

assert message["type"] == "http.response.start"

async def body_stream() -> typing.AsyncGenerator[bytes, None]:
async def body_stream() -> BodyStreamGenerator:
async with recv_stream:
async for message in recv_stream:
if message["type"] == "http.response.pathsend":
yield message
break
assert message["type"] == "http.response.body"
body = message.get("body", b"")
if body:
Expand Down Expand Up @@ -219,10 +224,17 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
}
)

should_close_body = True
async for chunk in self.body_iterator:
if isinstance(chunk, dict):
# We got an ASGI message which is not response body (eg: pathsend)
should_close_body = False
await send(chunk)
continue
await send({"type": "http.response.body", "body": chunk, "more_body": True})

await send({"type": "http.response.body", "body": b"", "more_body": False})
if should_close_body:
await send({"type": "http.response.body", "body": b"", "more_body": False})

if self.background:
await self.background()
5 changes: 4 additions & 1 deletion starlette/middleware/gzip.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ async def send_with_gzip(self, message: Message) -> None:

await self.send(self.initial_message)
await self.send(message)

elif message_type == "http.response.body":
# Remaining body in streaming GZip response.
body = message.get("body", b"")
Expand All @@ -102,6 +101,10 @@ async def send_with_gzip(self, message: Message) -> None:
self.gzip_buffer.truncate()

await self.send(message)
elif message_type == "http.response.pathsend":
# Don't apply GZip to pathsend responses
await self.send(self.initial_message)
await self.send(message)


async def unattached_send(message: Message) -> typing.NoReturn:
Expand Down
6 changes: 4 additions & 2 deletions starlette/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
http_if_range = headers.get("if-range")

if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range, stat_result)):
await self._handle_simple(send, send_header_only)
await self._handle_simple(scope, send, send_header_only)
else:
try:
ranges = self._parse_range_header(http_range, stat_result.st_size)
Expand All @@ -364,10 +364,12 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.background is not None:
await self.background()

async def _handle_simple(self, send: Send, send_header_only: bool) -> None:
async def _handle_simple(self, scope: Scope, send: Send, send_header_only: bool) -> None:
await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers})
if send_header_only:
await send({"type": "http.response.body", "body": b"", "more_body": False})
elif "http.response.pathsend" in scope["extensions"]:
await send({"type": "http.response.pathsend", "path": str(self.path)})
else:
async with await anyio.open_file(self.path, mode="rb") as file:
more_body = True
Expand Down
47 changes: 46 additions & 1 deletion tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import contextvars
from contextlib import AsyncExitStack
from pathlib import Path
from typing import Any, AsyncGenerator, AsyncIterator, Generator

import anyio
Expand All @@ -13,7 +14,7 @@
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import ClientDisconnect, Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
from starlette.responses import FileResponse, PlainTextResponse, Response, StreamingResponse
from starlette.routing import Route, WebSocketRoute
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Message, Receive, Scope, Send
Expand Down Expand Up @@ -1153,3 +1154,47 @@ async def send(message: Message) -> None:
{"type": "http.response.body", "body": b"good!", "more_body": True},
{"type": "http.response.body", "body": b"", "more_body": False},
]


@pytest.mark.anyio
async def test_asgi_pathsend_events(tmpdir: Path) -> None:
path = tmpdir / "example.txt"
with path.open("w") as file:
file.write("<file content>")

response_complete = anyio.Event()
events: list[Message] = []

async def endpoint_with_pathsend(_: Request) -> FileResponse:
return FileResponse(path)

async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> Response:
return await call_next(request)

app = Starlette(
middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)],
routes=[Route("/", endpoint_with_pathsend)],
)

scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/",
"headers": [],
"extensions": {"http.response.pathsend": {}},
}

async def receive() -> Message:
raise NotImplementedError("Should not be called!") # pragma: no cover

async def send(message: Message) -> None:
events.append(message)
if message["type"] == "http.response.pathsend":
response_complete.set()

await app(scope, receive, send)

assert len(events) == 2
assert events[0]["type"] == "http.response.start"
assert events[1]["type"] == "http.response.pathsend"
53 changes: 52 additions & 1 deletion tests/middleware/test_gzip.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
from __future__ import annotations

from pathlib import Path

import pytest

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.requests import Request
from starlette.responses import ContentStream, PlainTextResponse, StreamingResponse
from starlette.responses import (
ContentStream,
FileResponse,
PlainTextResponse,
StreamingResponse,
)
from starlette.routing import Route
from starlette.types import Message
from tests.types import TestClientFactory


Expand Down Expand Up @@ -104,3 +116,42 @@ async def generator(bytes: bytes, count: int) -> ContentStream:
assert response.text == "x" * 4000
assert response.headers["Content-Encoding"] == "text"
assert "Content-Length" not in response.headers


@pytest.mark.anyio
async def test_gzip_ignored_for_pathsend_responses(tmpdir: Path) -> None:
path = tmpdir / "example.txt"
with path.open("w") as file:
file.write("<file content>")

events: list[Message] = []

async def endpoint_with_pathsend(request: Request) -> FileResponse:
_ = await request.body()
return FileResponse(path)

app = Starlette(
routes=[Route("/", endpoint=endpoint_with_pathsend)],
middleware=[Middleware(GZipMiddleware)],
)

scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/",
"headers": [(b"accept-encoding", b"gzip, text")],
"extensions": {"http.response.pathsend": {}},
}

async def receive() -> Message:
return {"type": "http.request", "body": b"", "more_body": False}

async def send(message: Message) -> None:
events.append(message)

await app(scope, receive, send)

assert len(events) == 2
assert events[0]["type"] == "http.response.start"
assert events[1]["type"] == "http.response.pathsend"
32 changes: 32 additions & 0 deletions tests/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,38 @@ def test_file_response_with_method_warns(tmp_path: Path) -> None:
FileResponse(path=tmp_path, filename="example.png", method="GET")


@pytest.mark.anyio
async def test_file_response_with_pathsend(tmpdir: Path) -> None:
path = tmpdir / "xyz"
content = b"<file content>" * 1000
with open(path, "wb") as file:
file.write(content)

app = FileResponse(path=path, filename="example.png")

async def receive() -> Message: # type: ignore[empty-body]
... # pragma: no cover

async def send(message: Message) -> None:
if message["type"] == "http.response.start":
assert message["status"] == status.HTTP_200_OK
headers = Headers(raw=message["headers"])
assert headers["content-type"] == "image/png"
assert "content-length" in headers
assert "content-disposition" in headers
assert "last-modified" in headers
assert "etag" in headers
elif message["type"] == "http.response.pathsend":
assert message["path"] == str(path)

# Since the TestClient doesn't support `pathsend`, we need to test this directly.
await app(
{"type": "http", "method": "get", "headers": [], "extensions": {"http.response.pathsend": {}}},
receive,
send,
)


def test_set_cookie(test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch) -> None:
# Mock time used as a reference for `Expires` by stdlib `SimpleCookie`.
mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc)
Expand Down