Skip to content

Commit

Permalink
Add Mount(..., middleware=[...]) (#1649)
Browse files Browse the repository at this point in the history
Co-authored-by: Marcelo Trylesinski <[email protected]>
Co-authored-by: Florimond Manca <[email protected]>
Co-authored-by: Aber <[email protected]>
  • Loading branch information
4 people authored Sep 21, 2022
1 parent bc61505 commit ef34ece
Show file tree
Hide file tree
Showing 3 changed files with 218 additions and 3 deletions.
35 changes: 35 additions & 0 deletions docs/middleware.md
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,41 @@ to use the `middleware=<List of Middleware instances>` style, as it will:
* Ensure that everything remains wrapped in a single outermost `ServerErrorMiddleware`.
* Preserves the top-level `app` instance.

## Applying middleware to `Mount`s

Middleware can also be added to `Mount`, which allows you to apply middleware to a single route, a group of routes or any mounted ASGI application:

```python
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.routing import Mount, Route


routes = [
Mount(
"/",
routes=[
Route(
"/example",
endpoint=...,
)
],
middleware=[Middleware(GZipMiddleware)]
)
]

app = Starlette(routes=routes)
```

Note that middleware used in this way is *not* wrapped in exception handling middleware like the middleware applied to the `Starlette` application is.
This is often not a problem because it only applies to middleware that inspect or modify the `Response`, and even then you probably don't want to apply this logic to error responses.
If you do want to apply the middleware logic to error responses only on some routes you have a couple of options:

* Add an `ExceptionMiddleware` onto the `Mount`
* Add a `try/except` block to your middleware and return an error response from there
* Split up marking and processing into two middlewares, one that gets put on `Mount` which marks the response as needing processing (for example by setting `scope["log-response"] = True`) and another applied to the `Starlette` application that does the heavy lifting.

## Third party middleware

#### [asgi-auth-github](https://github.com/simonw/asgi-auth-github)
Expand Down
13 changes: 10 additions & 3 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from starlette.convertors import CONVERTOR_TYPES, Convertor
from starlette.datastructures import URL, Headers, URLPath
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse, RedirectResponse
from starlette.types import ASGIApp, Receive, Scope, Send
Expand Down Expand Up @@ -348,24 +349,30 @@ def __init__(
app: typing.Optional[ASGIApp] = None,
routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
name: typing.Optional[str] = None,
*,
middleware: typing.Optional[typing.Sequence[Middleware]] = None,
) -> None:
assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
assert (
app is not None or routes is not None
), "Either 'app=...', or 'routes=' must be specified"
self.path = path.rstrip("/")
if app is not None:
self.app: ASGIApp = app
self._base_app: ASGIApp = app
else:
self.app = Router(routes=routes)
self._base_app = Router(routes=routes)
self.app = self._base_app
if middleware is not None:
for cls, options in reversed(middleware):
self.app = cls(app=self.app, **options)
self.name = name
self.path_regex, self.path_format, self.param_convertors = compile_path(
self.path + "/{path:path}"
)

@property
def routes(self) -> typing.List[BaseRoute]:
return getattr(self.app, "routes", [])
return getattr(self._base_app, "routes", [])

def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
if scope["type"] in ("http", "websocket"):
Expand Down
173 changes: 173 additions & 0 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@
import pytest

from starlette.applications import Starlette
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse, Response
from starlette.routing import Host, Mount, NoMatchFound, Route, Router, WebSocketRoute
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketDisconnect


Expand Down Expand Up @@ -768,6 +773,115 @@ def test_route_name(endpoint: typing.Callable, expected_name: str):
assert Route(path="/", endpoint=endpoint).name == expected_name


class AddHeadersMiddleware:
def __init__(self, app: ASGIApp) -> None:
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
scope["add_headers_middleware"] = True

async def modified_send(msg: Message) -> None:
if msg["type"] == "http.response.start":
msg["headers"].append((b"X-Test", b"Set by middleware"))
await send(msg)

await self.app(scope, receive, modified_send)


def assert_middleware_header_route(request: Request) -> Response:
assert request.scope["add_headers_middleware"] is True
return Response()


mounted_routes_with_middleware = Starlette(
routes=[
Mount(
"/http",
routes=[
Route(
"/",
endpoint=assert_middleware_header_route,
methods=["GET"],
name="route",
),
],
middleware=[Middleware(AddHeadersMiddleware)],
),
Route("/home", homepage),
]
)


mounted_app_with_middleware = Starlette(
routes=[
Mount(
"/http",
app=Route(
"/",
endpoint=assert_middleware_header_route,
methods=["GET"],
name="route",
),
middleware=[Middleware(AddHeadersMiddleware)],
),
Route("/home", homepage),
]
)


@pytest.mark.parametrize(
"app",
[
mounted_routes_with_middleware,
mounted_app_with_middleware,
],
)
def test_mount_middleware(
test_client_factory: typing.Callable[..., TestClient],
app: Starlette,
) -> None:
test_client = test_client_factory(app)

response = test_client.get("/home")
assert response.status_code == 200
assert "X-Test" not in response.headers

response = test_client.get("/http")
assert response.status_code == 200
assert response.headers["X-Test"] == "Set by middleware"


def test_mount_routes_with_middleware_url_path_for() -> None:
"""Checks that url_path_for still works with mounted routes with Middleware"""
assert mounted_routes_with_middleware.url_path_for("route") == "/http/"


def test_mount_asgi_app_with_middleware_url_path_for() -> None:
"""Mounted ASGI apps do not work with url path for,
middleware does not change this
"""
with pytest.raises(NoMatchFound):
mounted_app_with_middleware.url_path_for("route")


def test_add_route_to_app_after_mount(
test_client_factory: typing.Callable[..., TestClient],
) -> None:
"""Checks that Mount will pick up routes
added to the underlying app after it is mounted
"""
inner_app = Router()
app = Mount("/http", app=inner_app)
inner_app.add_route(
"/inner",
endpoint=homepage,
methods=["GET"],
)
client = test_client_factory(app)
response = client.get("/http/inner")
assert response.status_code == 200


def test_exception_on_mounted_apps(test_client_factory):
def exc(request):
raise Exception("Exc")
Expand All @@ -779,3 +893,62 @@ def exc(request):
with pytest.raises(Exception) as ctx:
client.get("/sub/")
assert str(ctx.value) == "Exc"


def test_mounted_middleware_does_not_catch_exception(
test_client_factory: typing.Callable[..., TestClient],
) -> None:
# https://github.com/encode/starlette/pull/1649#discussion_r960236107
def exc(request: Request) -> Response:
raise HTTPException(status_code=403, detail="auth")

class NamedMiddleware:
def __init__(self, app: ASGIApp, name: str) -> None:
self.app = app
self.name = name

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
async def modified_send(msg: Message) -> None:
if msg["type"] == "http.response.start":
msg["headers"].append((f"X-{self.name}".encode(), b"true"))
await send(msg)

await self.app(scope, receive, modified_send)

app = Starlette(
routes=[
Mount(
"/mount",
routes=[
Route("/err", exc),
Route("/home", homepage),
],
middleware=[Middleware(NamedMiddleware, name="Mounted")],
),
Route("/err", exc),
Route("/home", homepage),
],
middleware=[Middleware(NamedMiddleware, name="Outer")],
)

client = test_client_factory(app)

resp = client.get("/home")
assert resp.status_code == 200, resp.content
assert "X-Outer" in resp.headers

resp = client.get("/err")
assert resp.status_code == 403, resp.content
assert "X-Outer" in resp.headers

resp = client.get("/mount/home")
assert resp.status_code == 200, resp.content
assert "X-Mounted" in resp.headers

# this is the "surprising" behavior bit
# the middleware on the mount never runs because there
# is nothing to catch the HTTPException
# since Mount middlweare is not wrapped by ExceptionMiddleware
resp = client.get("/mount/err")
assert resp.status_code == 403, resp.content
assert "X-Mounted" not in resp.headers

0 comments on commit ef34ece

Please sign in to comment.