Skip to content

Commit

Permalink
Add a test showing the need to flush gzip_file.
Browse files Browse the repository at this point in the history
  • Loading branch information
vin committed Nov 16, 2024
1 parent f1b1ef8 commit 2450b70
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
5 changes: 3 additions & 2 deletions starlette/middleware/gzip.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ async def send_with_gzip(self, message: Message) -> None:
more_body = message.get("more_body", False)

self.gzip_file.write(body)
self.gzip_file.flush()
if not more_body:
if more_body:
self.gzip_file.flush()
else:
self.gzip_file.close()

message["body"] = self.gzip_buffer.getvalue()
Expand Down
26 changes: 25 additions & 1 deletion tests/middleware/test_gzip.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.requests import Request
from starlette.responses import ContentStream, PlainTextResponse, StreamingResponse
from starlette.routing import Route
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from tests.types import TestClientFactory


Expand Down Expand Up @@ -61,6 +64,24 @@ def homepage(request: Request) -> PlainTextResponse:


def test_gzip_streaming_response(test_client_factory: TestClientFactory) -> None:
class VerifyingMiddleware:
def __init__(self, app: ASGIApp) -> None:
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
self.send = send
self.received_chunks: list[bytes] = []
await self.app(scope, receive, self.sent_from_gzip)
assert all(chunk != b"" for chunk in self.received_chunks)
assert len(self.received_chunks) == 11

async def sent_from_gzip(self, message: Message) -> None:
message_type = message["type"]
if message_type == "http.response.body":
body = message.get("body", b"")
self.received_chunks.append(body)
await self.send(message)

def homepage(request: Request) -> StreamingResponse:
async def generator(bytes: bytes, count: int) -> ContentStream:
for index in range(count):
Expand All @@ -71,7 +92,10 @@ async def generator(bytes: bytes, count: int) -> ContentStream:

app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(GZipMiddleware)],
middleware=[
Middleware(VerifyingMiddleware),
Middleware(GZipMiddleware),
],
)

client = test_client_factory(app)
Expand Down

0 comments on commit 2450b70

Please sign in to comment.