From 1d0ac5efdb69ec074b25f9822c7d007e7648bde8 Mon Sep 17 00:00:00 2001 From: Lie Ryan Date: Sat, 10 Sep 2022 00:52:41 +1000 Subject: [PATCH 01/10] Don't allow using `base_url` unless TrustedHostMiddleware is configured See discussion: https://github.com/encode/starlette/discussions/1854 --- starlette/datastructures.py | 5 ++ starlette/exceptions.py | 4 + starlette/middleware/trustedhost.py | 12 ++- starlette/requests.py | 11 ++- tests/test_applications.py | 122 +++++++++++++++++++++++++++- 5 files changed, 150 insertions(+), 4 deletions(-) diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 42ec7a9ea..350e643bb 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -20,6 +20,11 @@ class Address(typing.NamedTuple): _CovariantValueType = typing.TypeVar("_CovariantValueType", covariant=True) +class TrustedHost(bytes): + def __repr__(self): + return f"{self.__class__.__name__}({self})" + + class URL: def __init__( self, diff --git a/starlette/exceptions.py b/starlette/exceptions.py index 87da73591..5ffaf15d7 100644 --- a/starlette/exceptions.py +++ b/starlette/exceptions.py @@ -33,6 +33,10 @@ def __repr__(self) -> str: return f"{class_name}(code={self.code!r}, reason={self.reason!r})" +class ImproperlyConfigured(Exception): + pass + + __deprecated__ = "ExceptionMiddleware" diff --git a/starlette/middleware/trustedhost.py b/starlette/middleware/trustedhost.py index e84e6876a..9aa5da494 100644 --- a/starlette/middleware/trustedhost.py +++ b/starlette/middleware/trustedhost.py @@ -1,6 +1,6 @@ import typing -from starlette.datastructures import URL, Headers +from starlette.datastructures import URL, Headers, TrustedHost from starlette.responses import PlainTextResponse, RedirectResponse, Response from starlette.types import ASGIApp, Receive, Scope, Send @@ -31,6 +31,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: "http", "websocket", ): # pragma: no cover + self._mark_host_header_as_trusted(scope) await self.app(scope, receive, send) return @@ -48,6 +49,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: found_www_redirect = True if is_valid_host: + self._mark_host_header_as_trusted(scope) await self.app(scope, receive, send) else: response: Response @@ -58,3 +60,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: else: response = PlainTextResponse("Invalid host header", status_code=400) await response(scope, receive, send) + + def _mark_host_header_as_trusted(self, scope): + if "headers" not in scope: + return + scope["headers"] = [ + (key, value if key != b"host" else TrustedHost(value)) + for key, value in scope["headers"] + ] diff --git a/starlette/requests.py b/starlette/requests.py index 726abddcc..c32919e8b 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -4,8 +4,8 @@ import anyio -from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State -from starlette.exceptions import HTTPException +from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State, TrustedHost +from starlette.exceptions import HTTPException, ImproperlyConfigured from starlette.formparsers import FormParser, MultiPartException, MultiPartParser from starlette.types import Message, Receive, Scope, Send @@ -103,6 +103,13 @@ def base_url(self) -> URL: base_url_scope["root_path"] = base_url_scope.get( "app_root_path", base_url_scope.get("root_path", "") ) + for key, value in base_url_scope["headers"]: + if key == b"host" and not isinstance(value, TrustedHost): + raise ImproperlyConfigured( + "No trusted host header configuration found, you need " + "to use TrustedHostMiddleware(allowed_hosts=[...]) " + "if you want to generate absolute URL." + ) self._base_url = URL(scope=base_url_scope) return self._base_url diff --git a/tests/test_applications.py b/tests/test_applications.py index 2cee601b0..938514cbb 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -7,7 +7,7 @@ from starlette import status from starlette.applications import Starlette from starlette.endpoints import HTTPEndpoint -from starlette.exceptions import HTTPException, WebSocketException +from starlette.exceptions import HTTPException, WebSocketException, ImproperlyConfigured from starlette.middleware import Middleware from starlette.middleware.trustedhost import TrustedHostMiddleware from starlette.responses import JSONResponse, PlainTextResponse @@ -429,3 +429,123 @@ def lifespan(app): assert not cleanup_complete assert startup_complete assert cleanup_complete + + +@pytest.mark.anyio +async def test_trusted_host_not_configured(): + async def async_url_for(request): + return PlainTextResponse(request.url_for("async_url_for")) + + app = Starlette( + routes=[ + Route("/func", endpoint=async_url_for), + ], + ) + scope = { + "type": "http", + "headers": { + (b"host", b"testserver"), + }, + "method": "GET", + "path": "/func", + } + async def receive(arg): + pass + async def send(arg): + if arg["type"] == "http.response.start": + assert arg["status"] == 500 + elif arg["type"] == "http.response.body": + assert arg["body"] == b"Internal Server Error" + + request = app(scope, receive, send) + with pytest.raises(ImproperlyConfigured): + await request + + +@pytest.mark.anyio +async def test_trusted_host_wildcard(): + async def async_url_for(request): + return PlainTextResponse(request.url_for("async_url_for")) + + app = Starlette( + routes=[ + Route("/func", endpoint=async_url_for), + ], + ) + scope = { + "type": "http", + "headers": { + (b"host", b"testserver"), + }, + "method": "GET", + "path": "/func", + } + async def receive(arg): + pass + async def send(arg): + if arg["type"] == "http.response.start": + assert arg["status"] == 200 + + app.add_middleware(TrustedHostMiddleware, allowed_hosts=["*"]) + request = app(scope, receive, send) + await request + + +@pytest.mark.anyio +async def test_trusted_host_in_allowed_hosts(): + async def async_url_for(request): + return PlainTextResponse(request.url_for("async_url_for")) + + app = Starlette( + routes=[ + Route("/func", endpoint=async_url_for), + ], + ) + scope = { + "type": "http", + "headers": { + (b"host", b"testserver"), + }, + "method": "GET", + "path": "/func", + } + async def receive(arg): + pass + async def send(arg): + if arg["type"] == "http.response.start": + assert arg["status"] == 200 + + app.add_middleware(TrustedHostMiddleware, allowed_hosts=["testserver"]) + request = app(scope, receive, send) + await request + + +@pytest.mark.anyio +async def test_trusted_host_not_in_allowed_hosts(): + async def async_url_for(request): + return PlainTextResponse(request.url_for("async_url_for")) + + app = Starlette( + routes=[ + Route("/func", endpoint=async_url_for), + ], + ) + scope = { + "type": "http", + "headers": { + (b"host", b"testserver"), + }, + "method": "GET", + "path": "/func", + } + async def receive(arg): + pass + async def send(arg): + if arg["type"] == "http.response.start": + assert arg["status"] == 400 + elif arg["type"] == "http.response.body": + assert arg["body"] == b"Invalid host header" + + app.add_middleware(TrustedHostMiddleware, allowed_hosts=["anotherserver"]) + request = app(scope, receive, send) + resp = await request From 7d35f52c074cdebbe9b2dddb51fabda4d89f92c1 Mon Sep 17 00:00:00 2001 From: Lie Ryan Date: Sat, 10 Sep 2022 00:54:01 +1000 Subject: [PATCH 02/10] Modify the TestClient to not require explicit TrustedHostMiddleware configuration --- starlette/testclient.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 455440ce5..a0c2e274c 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -16,6 +16,7 @@ from anyio.streams.stapled import StapledObjectStream from starlette._utils import is_async_callable +from starlette.datastructures import TrustedHost from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocketDisconnect @@ -210,15 +211,19 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: if "host" in request.headers: headers: typing.List[typing.Tuple[bytes, bytes]] = [] elif port == default_port: # pragma: no cover - headers = [(b"host", host.encode())] + headers = [(b"host", TrustedHost(host.encode()))] else: # pragma: no cover - headers = [(b"host", (f"{host}:{port}").encode())] + headers = [(b"host", TrustedHost((f"{host}:{port}").encode()))] # Include other request headers. headers += [ (key.lower().encode(), value.encode()) for key, value in request.headers.items() ] + headers = [ + (key, (value if key != b"host" else TrustedHost(value))) + for key, value in headers + ] scope: typing.Dict[str, typing.Any] From 51e86280c9bdb538dba36ad590dd1dc319084389 Mon Sep 17 00:00:00 2001 From: Lie Ryan Date: Mon, 12 Sep 2022 09:33:37 +1000 Subject: [PATCH 03/10] Copy the scope instead of modifying the existing one --- starlette/middleware/trustedhost.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/starlette/middleware/trustedhost.py b/starlette/middleware/trustedhost.py index 9aa5da494..373b8f6e3 100644 --- a/starlette/middleware/trustedhost.py +++ b/starlette/middleware/trustedhost.py @@ -31,7 +31,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: "http", "websocket", ): # pragma: no cover - self._mark_host_header_as_trusted(scope) + scope = self._mark_host_header_as_trusted(scope) await self.app(scope, receive, send) return @@ -49,7 +49,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: found_www_redirect = True if is_valid_host: - self._mark_host_header_as_trusted(scope) + scope = self._mark_host_header_as_trusted(scope) await self.app(scope, receive, send) else: response: Response @@ -63,8 +63,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: def _mark_host_header_as_trusted(self, scope): if "headers" not in scope: - return - scope["headers"] = [ + return scope + new_scope = scope.copy() + new_scope["headers"] = [ (key, value if key != b"host" else TrustedHost(value)) - for key, value in scope["headers"] + for key, value in new_scope["headers"] ] + return new_scope From 8e5a4a81dbeb7821a05c9c2a232a3cf638105c6c Mon Sep 17 00:00:00 2001 From: Lie Ryan Date: Mon, 12 Sep 2022 10:24:58 +1000 Subject: [PATCH 04/10] Fix linter issues --- starlette/requests.py | 10 ++++++++- tests/test_applications.py | 42 +++++++++++++++++++------------------- 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/starlette/requests.py b/starlette/requests.py index c32919e8b..b6231cc63 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -4,7 +4,15 @@ import anyio -from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State, TrustedHost +from starlette.datastructures import ( + URL, + Address, + FormData, + Headers, + QueryParams, + State, + TrustedHost, +) from starlette.exceptions import HTTPException, ImproperlyConfigured from starlette.formparsers import FormParser, MultiPartException, MultiPartParser from starlette.types import Message, Receive, Scope, Send diff --git a/tests/test_applications.py b/tests/test_applications.py index 938514cbb..815fbb1ef 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -7,7 +7,7 @@ from starlette import status from starlette.applications import Starlette from starlette.endpoints import HTTPEndpoint -from starlette.exceptions import HTTPException, WebSocketException, ImproperlyConfigured +from starlette.exceptions import HTTPException, ImproperlyConfigured, WebSocketException from starlette.middleware import Middleware from starlette.middleware.trustedhost import TrustedHostMiddleware from starlette.responses import JSONResponse, PlainTextResponse @@ -449,13 +449,13 @@ async def async_url_for(request): "method": "GET", "path": "/func", } - async def receive(arg): + async def receive(message): pass - async def send(arg): - if arg["type"] == "http.response.start": - assert arg["status"] == 500 - elif arg["type"] == "http.response.body": - assert arg["body"] == b"Internal Server Error" + async def send(message): + if message["type"] == "http.response.start": + assert message["status"] == 500 + elif message["type"] == "http.response.body": + assert message["body"] == b"Internal Server Error" request = app(scope, receive, send) with pytest.raises(ImproperlyConfigured): @@ -480,11 +480,11 @@ async def async_url_for(request): "method": "GET", "path": "/func", } - async def receive(arg): + async def receive(message): pass - async def send(arg): - if arg["type"] == "http.response.start": - assert arg["status"] == 200 + async def send(message): + if message["type"] == "http.response.start": + assert message["status"] == 200 app.add_middleware(TrustedHostMiddleware, allowed_hosts=["*"]) request = app(scope, receive, send) @@ -509,11 +509,11 @@ async def async_url_for(request): "method": "GET", "path": "/func", } - async def receive(arg): + async def receive(message): pass - async def send(arg): - if arg["type"] == "http.response.start": - assert arg["status"] == 200 + async def send(message): + if message["type"] == "http.response.start": + assert message["status"] == 200 app.add_middleware(TrustedHostMiddleware, allowed_hosts=["testserver"]) request = app(scope, receive, send) @@ -538,13 +538,13 @@ async def async_url_for(request): "method": "GET", "path": "/func", } - async def receive(arg): + async def receive(message): pass - async def send(arg): - if arg["type"] == "http.response.start": - assert arg["status"] == 400 - elif arg["type"] == "http.response.body": - assert arg["body"] == b"Invalid host header" + async def send(message): + if message["type"] == "http.response.start": + assert message["status"] == 400 + elif message["type"] == "http.response.body": + assert message["body"] == b"Invalid host header" app.add_middleware(TrustedHostMiddleware, allowed_hosts=["anotherserver"]) request = app(scope, receive, send) From a9eed46f57d33f6b530b609ccd231696a7ae88a7 Mon Sep 17 00:00:00 2001 From: Lie Ryan Date: Mon, 12 Sep 2022 10:26:48 +1000 Subject: [PATCH 05/10] Black --- tests/test_applications.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_applications.py b/tests/test_applications.py index 815fbb1ef..54d4664e6 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -449,8 +449,10 @@ async def async_url_for(request): "method": "GET", "path": "/func", } + async def receive(message): pass + async def send(message): if message["type"] == "http.response.start": assert message["status"] == 500 @@ -480,8 +482,10 @@ async def async_url_for(request): "method": "GET", "path": "/func", } + async def receive(message): pass + async def send(message): if message["type"] == "http.response.start": assert message["status"] == 200 @@ -509,8 +513,10 @@ async def async_url_for(request): "method": "GET", "path": "/func", } + async def receive(message): pass + async def send(message): if message["type"] == "http.response.start": assert message["status"] == 200 @@ -538,8 +544,10 @@ async def async_url_for(request): "method": "GET", "path": "/func", } + async def receive(message): pass + async def send(message): if message["type"] == "http.response.start": assert message["status"] == 400 From 55dfb8e5998d310ca1b116978f571f507bc923ad Mon Sep 17 00:00:00 2001 From: Lie Ryan Date: Mon, 12 Sep 2022 10:28:40 +1000 Subject: [PATCH 06/10] Fix flake8 issue --- tests/test_applications.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_applications.py b/tests/test_applications.py index 54d4664e6..e5aad465b 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -556,4 +556,4 @@ async def send(message): app.add_middleware(TrustedHostMiddleware, allowed_hosts=["anotherserver"]) request = app(scope, receive, send) - resp = await request + await request From 346ab07ed1b8c1a58e1ffe2e494600d7b5407446 Mon Sep 17 00:00:00 2001 From: Lie Ryan Date: Tue, 13 Sep 2022 15:45:32 +1000 Subject: [PATCH 07/10] Fix type errors raised by mypy --- starlette/datastructures.py | 4 ++-- starlette/middleware/trustedhost.py | 5 +++-- tests/test_applications.py | 9 +++++---- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 350e643bb..14892b690 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -21,8 +21,8 @@ class Address(typing.NamedTuple): class TrustedHost(bytes): - def __repr__(self): - return f"{self.__class__.__name__}({self})" + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self!r})" class URL: diff --git a/starlette/middleware/trustedhost.py b/starlette/middleware/trustedhost.py index 373b8f6e3..1f9a8927c 100644 --- a/starlette/middleware/trustedhost.py +++ b/starlette/middleware/trustedhost.py @@ -1,3 +1,4 @@ +import copy import typing from starlette.datastructures import URL, Headers, TrustedHost @@ -61,10 +62,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: response = PlainTextResponse("Invalid host header", status_code=400) await response(scope, receive, send) - def _mark_host_header_as_trusted(self, scope): + def _mark_host_header_as_trusted(self, scope: Scope) -> Scope: if "headers" not in scope: return scope - new_scope = scope.copy() + new_scope = copy.copy(scope) new_scope["headers"] = [ (key, value if key != b"host" else TrustedHost(value)) for key, value in new_scope["headers"] diff --git a/tests/test_applications.py b/tests/test_applications.py index e5aad465b..df14a6163 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -13,6 +13,7 @@ from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.staticfiles import StaticFiles +from starlette.types import Message from starlette.websockets import WebSocket @@ -450,7 +451,7 @@ async def async_url_for(request): "path": "/func", } - async def receive(message): + async def receive() -> Message: pass async def send(message): @@ -483,7 +484,7 @@ async def async_url_for(request): "path": "/func", } - async def receive(message): + async def receive() -> Message: pass async def send(message): @@ -514,7 +515,7 @@ async def async_url_for(request): "path": "/func", } - async def receive(message): + async def receive() -> Message: pass async def send(message): @@ -545,7 +546,7 @@ async def async_url_for(request): "path": "/func", } - async def receive(message): + async def receive() -> Message: pass async def send(message): From 21877a7cca7ff2141fec93ba442a297fb3860051 Mon Sep 17 00:00:00 2001 From: Lie Ryan Date: Tue, 13 Sep 2022 15:56:19 +1000 Subject: [PATCH 08/10] Fix repr of TrustedHost --- starlette/datastructures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 14892b690..334272a72 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -22,7 +22,7 @@ class Address(typing.NamedTuple): class TrustedHost(bytes): def __repr__(self) -> str: - return f"{self.__class__.__name__}({self!r})" + return f"{self.__class__.__name__}({super().__repr__()})" class URL: From 81e7a9df4f3e8f07d4060eef7176705f7ac1cdfd Mon Sep 17 00:00:00 2001 From: Lie Ryan Date: Tue, 13 Sep 2022 16:10:18 +1000 Subject: [PATCH 09/10] Fix coverage --- tests/test_applications.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/test_applications.py b/tests/test_applications.py index df14a6163..f2f5cdd8b 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -451,8 +451,8 @@ async def async_url_for(request): "path": "/func", } - async def receive() -> Message: - pass + async def receive() -> Message: # pragma: no cover + assert False async def send(message): if message["type"] == "http.response.start": @@ -484,8 +484,8 @@ async def async_url_for(request): "path": "/func", } - async def receive() -> Message: - pass + async def receive() -> Message: # pragma: no cover + assert False async def send(message): if message["type"] == "http.response.start": @@ -515,8 +515,8 @@ async def async_url_for(request): "path": "/func", } - async def receive() -> Message: - pass + async def receive() -> Message: # pragma: no cover + assert False async def send(message): if message["type"] == "http.response.start": @@ -529,8 +529,8 @@ async def send(message): @pytest.mark.anyio async def test_trusted_host_not_in_allowed_hosts(): - async def async_url_for(request): - return PlainTextResponse(request.url_for("async_url_for")) + async def async_url_for(request): # pragma: no cover + assert False app = Starlette( routes=[ @@ -546,8 +546,8 @@ async def async_url_for(request): "path": "/func", } - async def receive() -> Message: - pass + async def receive() -> Message: # pragma: no cover + assert False async def send(message): if message["type"] == "http.response.start": From 0f8191ff2b1962b6efe41017f121af2083cc9307 Mon Sep 17 00:00:00 2001 From: Lie Ryan Date: Tue, 13 Sep 2022 16:12:16 +1000 Subject: [PATCH 10/10] Black --- tests/test_applications.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_applications.py b/tests/test_applications.py index f2f5cdd8b..e673b557e 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -529,7 +529,7 @@ async def send(message): @pytest.mark.anyio async def test_trusted_host_not_in_allowed_hosts(): - async def async_url_for(request): # pragma: no cover + async def async_url_for(request): # pragma: no cover assert False app = Starlette(