diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 42ec7a9ea..334272a72 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -20,6 +20,11 @@ class Address(typing.NamedTuple): _CovariantValueType = typing.TypeVar("_CovariantValueType", covariant=True) +class TrustedHost(bytes): + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + class URL: def __init__( self, diff --git a/starlette/exceptions.py b/starlette/exceptions.py index 87da73591..5ffaf15d7 100644 --- a/starlette/exceptions.py +++ b/starlette/exceptions.py @@ -33,6 +33,10 @@ def __repr__(self) -> str: return f"{class_name}(code={self.code!r}, reason={self.reason!r})" +class ImproperlyConfigured(Exception): + pass + + __deprecated__ = "ExceptionMiddleware" diff --git a/starlette/middleware/trustedhost.py b/starlette/middleware/trustedhost.py index e84e6876a..1f9a8927c 100644 --- a/starlette/middleware/trustedhost.py +++ b/starlette/middleware/trustedhost.py @@ -1,6 +1,7 @@ +import copy import typing -from starlette.datastructures import URL, Headers +from starlette.datastructures import URL, Headers, TrustedHost from starlette.responses import PlainTextResponse, RedirectResponse, Response from starlette.types import ASGIApp, Receive, Scope, Send @@ -31,6 +32,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: "http", "websocket", ): # pragma: no cover + scope = self._mark_host_header_as_trusted(scope) await self.app(scope, receive, send) return @@ -48,6 +50,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: found_www_redirect = True if is_valid_host: + scope = self._mark_host_header_as_trusted(scope) await self.app(scope, receive, send) else: response: Response @@ -58,3 +61,13 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: else: response = PlainTextResponse("Invalid host header", status_code=400) await response(scope, receive, send) + + def _mark_host_header_as_trusted(self, scope: Scope) -> Scope: + if "headers" not in scope: + return scope + new_scope = copy.copy(scope) + new_scope["headers"] = [ + (key, value if key != b"host" else TrustedHost(value)) + for key, value in new_scope["headers"] + ] + return new_scope diff --git a/starlette/requests.py b/starlette/requests.py index 726abddcc..b6231cc63 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -4,8 +4,16 @@ import anyio -from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State -from starlette.exceptions import HTTPException +from starlette.datastructures import ( + URL, + Address, + FormData, + Headers, + QueryParams, + State, + TrustedHost, +) +from starlette.exceptions import HTTPException, ImproperlyConfigured from starlette.formparsers import FormParser, MultiPartException, MultiPartParser from starlette.types import Message, Receive, Scope, Send @@ -103,6 +111,13 @@ def base_url(self) -> URL: base_url_scope["root_path"] = base_url_scope.get( "app_root_path", base_url_scope.get("root_path", "") ) + for key, value in base_url_scope["headers"]: + if key == b"host" and not isinstance(value, TrustedHost): + raise ImproperlyConfigured( + "No trusted host header configuration found, you need " + "to use TrustedHostMiddleware(allowed_hosts=[...]) " + "if you want to generate absolute URL." + ) self._base_url = URL(scope=base_url_scope) return self._base_url diff --git a/starlette/testclient.py b/starlette/testclient.py index 455440ce5..a0c2e274c 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -16,6 +16,7 @@ from anyio.streams.stapled import StapledObjectStream from starlette._utils import is_async_callable +from starlette.datastructures import TrustedHost from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocketDisconnect @@ -210,15 +211,19 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: if "host" in request.headers: headers: typing.List[typing.Tuple[bytes, bytes]] = [] elif port == default_port: # pragma: no cover - headers = [(b"host", host.encode())] + headers = [(b"host", TrustedHost(host.encode()))] else: # pragma: no cover - headers = [(b"host", (f"{host}:{port}").encode())] + headers = [(b"host", TrustedHost((f"{host}:{port}").encode()))] # Include other request headers. headers += [ (key.lower().encode(), value.encode()) for key, value in request.headers.items() ] + headers = [ + (key, (value if key != b"host" else TrustedHost(value))) + for key, value in headers + ] scope: typing.Dict[str, typing.Any] diff --git a/tests/test_applications.py b/tests/test_applications.py index 2cee601b0..e673b557e 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -7,12 +7,13 @@ from starlette import status from starlette.applications import Starlette from starlette.endpoints import HTTPEndpoint -from starlette.exceptions import HTTPException, WebSocketException +from starlette.exceptions import HTTPException, ImproperlyConfigured, WebSocketException from starlette.middleware import Middleware from starlette.middleware.trustedhost import TrustedHostMiddleware from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.staticfiles import StaticFiles +from starlette.types import Message from starlette.websockets import WebSocket @@ -429,3 +430,131 @@ def lifespan(app): assert not cleanup_complete assert startup_complete assert cleanup_complete + + +@pytest.mark.anyio +async def test_trusted_host_not_configured(): + async def async_url_for(request): + return PlainTextResponse(request.url_for("async_url_for")) + + app = Starlette( + routes=[ + Route("/func", endpoint=async_url_for), + ], + ) + scope = { + "type": "http", + "headers": { + (b"host", b"testserver"), + }, + "method": "GET", + "path": "/func", + } + + async def receive() -> Message: # pragma: no cover + assert False + + async def send(message): + if message["type"] == "http.response.start": + assert message["status"] == 500 + elif message["type"] == "http.response.body": + assert message["body"] == b"Internal Server Error" + + request = app(scope, receive, send) + with pytest.raises(ImproperlyConfigured): + await request + + +@pytest.mark.anyio +async def test_trusted_host_wildcard(): + async def async_url_for(request): + return PlainTextResponse(request.url_for("async_url_for")) + + app = Starlette( + routes=[ + Route("/func", endpoint=async_url_for), + ], + ) + scope = { + "type": "http", + "headers": { + (b"host", b"testserver"), + }, + "method": "GET", + "path": "/func", + } + + async def receive() -> Message: # pragma: no cover + assert False + + async def send(message): + if message["type"] == "http.response.start": + assert message["status"] == 200 + + app.add_middleware(TrustedHostMiddleware, allowed_hosts=["*"]) + request = app(scope, receive, send) + await request + + +@pytest.mark.anyio +async def test_trusted_host_in_allowed_hosts(): + async def async_url_for(request): + return PlainTextResponse(request.url_for("async_url_for")) + + app = Starlette( + routes=[ + Route("/func", endpoint=async_url_for), + ], + ) + scope = { + "type": "http", + "headers": { + (b"host", b"testserver"), + }, + "method": "GET", + "path": "/func", + } + + async def receive() -> Message: # pragma: no cover + assert False + + async def send(message): + if message["type"] == "http.response.start": + assert message["status"] == 200 + + app.add_middleware(TrustedHostMiddleware, allowed_hosts=["testserver"]) + request = app(scope, receive, send) + await request + + +@pytest.mark.anyio +async def test_trusted_host_not_in_allowed_hosts(): + async def async_url_for(request): # pragma: no cover + assert False + + app = Starlette( + routes=[ + Route("/func", endpoint=async_url_for), + ], + ) + scope = { + "type": "http", + "headers": { + (b"host", b"testserver"), + }, + "method": "GET", + "path": "/func", + } + + async def receive() -> Message: # pragma: no cover + assert False + + async def send(message): + if message["type"] == "http.response.start": + assert message["status"] == 400 + elif message["type"] == "http.response.body": + assert message["body"] == b"Invalid host header" + + app.add_middleware(TrustedHostMiddleware, allowed_hosts=["anotherserver"]) + request = app(scope, receive, send) + await request