Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validating host header #1858

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
5 changes: 5 additions & 0 deletions starlette/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ class Address(typing.NamedTuple):
_CovariantValueType = typing.TypeVar("_CovariantValueType", covariant=True)


class TrustedHost(bytes):
def __repr__(self) -> str:
return f"{self.__class__.__name__}({super().__repr__()})"


class URL:
def __init__(
self,
Expand Down
4 changes: 4 additions & 0 deletions starlette/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ def __repr__(self) -> str:
return f"{class_name}(code={self.code!r}, reason={self.reason!r})"


class ImproperlyConfigured(Exception):
pass


__deprecated__ = "ExceptionMiddleware"


Expand Down
15 changes: 14 additions & 1 deletion starlette/middleware/trustedhost.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import typing

from starlette.datastructures import URL, Headers
from starlette.datastructures import URL, Headers, TrustedHost
from starlette.responses import PlainTextResponse, RedirectResponse, Response
from starlette.types import ASGIApp, Receive, Scope, Send

Expand Down Expand Up @@ -31,6 +32,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"http",
"websocket",
): # pragma: no cover
scope = self._mark_host_header_as_trusted(scope)
await self.app(scope, receive, send)
return

Expand All @@ -48,6 +50,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
found_www_redirect = True

if is_valid_host:
scope = self._mark_host_header_as_trusted(scope)
await self.app(scope, receive, send)
else:
response: Response
Expand All @@ -58,3 +61,13 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
else:
response = PlainTextResponse("Invalid host header", status_code=400)
await response(scope, receive, send)

def _mark_host_header_as_trusted(self, scope: Scope) -> Scope:
if "headers" not in scope:
return scope
new_scope = copy.copy(scope)
new_scope["headers"] = [
(key, value if key != b"host" else TrustedHost(value))
for key, value in new_scope["headers"]
]
return new_scope
19 changes: 17 additions & 2 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,16 @@

import anyio

from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
from starlette.exceptions import HTTPException
from starlette.datastructures import (
URL,
Address,
FormData,
Headers,
QueryParams,
State,
TrustedHost,
)
from starlette.exceptions import HTTPException, ImproperlyConfigured
from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
from starlette.types import Message, Receive, Scope, Send

Expand Down Expand Up @@ -103,6 +111,13 @@ def base_url(self) -> URL:
base_url_scope["root_path"] = base_url_scope.get(
"app_root_path", base_url_scope.get("root_path", "")
)
for key, value in base_url_scope["headers"]:
if key == b"host" and not isinstance(value, TrustedHost):
raise ImproperlyConfigured(
"No trusted host header configuration found, you need "
"to use TrustedHostMiddleware(allowed_hosts=[...]) "
"if you want to generate absolute URL."
)
self._base_url = URL(scope=base_url_scope)
return self._base_url

Expand Down
9 changes: 7 additions & 2 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from anyio.streams.stapled import StapledObjectStream

from starlette._utils import is_async_callable
from starlette.datastructures import TrustedHost
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect

Expand Down Expand Up @@ -210,15 +211,19 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
if "host" in request.headers:
headers: typing.List[typing.Tuple[bytes, bytes]] = []
elif port == default_port: # pragma: no cover
headers = [(b"host", host.encode())]
headers = [(b"host", TrustedHost(host.encode()))]
else: # pragma: no cover
headers = [(b"host", (f"{host}:{port}").encode())]
headers = [(b"host", TrustedHost((f"{host}:{port}").encode()))]

# Include other request headers.
headers += [
(key.lower().encode(), value.encode())
for key, value in request.headers.items()
]
headers = [
(key, (value if key != b"host" else TrustedHost(value)))
for key, value in headers
]

scope: typing.Dict[str, typing.Any]

Expand Down
131 changes: 130 additions & 1 deletion tests/test_applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from starlette import status
from starlette.applications import Starlette
from starlette.endpoints import HTTPEndpoint
from starlette.exceptions import HTTPException, WebSocketException
from starlette.exceptions import HTTPException, ImproperlyConfigured, WebSocketException
from starlette.middleware import Middleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
from starlette.types import Message
from starlette.websockets import WebSocket


Expand Down Expand Up @@ -429,3 +430,131 @@ def lifespan(app):
assert not cleanup_complete
assert startup_complete
assert cleanup_complete


@pytest.mark.anyio
async def test_trusted_host_not_configured():
async def async_url_for(request):
return PlainTextResponse(request.url_for("async_url_for"))

app = Starlette(
routes=[
Route("/func", endpoint=async_url_for),
],
)
scope = {
"type": "http",
"headers": {
(b"host", b"testserver"),
},
"method": "GET",
"path": "/func",
}

async def receive() -> Message: # pragma: no cover
assert False

async def send(message):
if message["type"] == "http.response.start":
assert message["status"] == 500
elif message["type"] == "http.response.body":
assert message["body"] == b"Internal Server Error"

request = app(scope, receive, send)
with pytest.raises(ImproperlyConfigured):
await request


@pytest.mark.anyio
async def test_trusted_host_wildcard():
async def async_url_for(request):
return PlainTextResponse(request.url_for("async_url_for"))

app = Starlette(
routes=[
Route("/func", endpoint=async_url_for),
],
)
scope = {
"type": "http",
"headers": {
(b"host", b"testserver"),
},
"method": "GET",
"path": "/func",
}

async def receive() -> Message: # pragma: no cover
assert False

async def send(message):
if message["type"] == "http.response.start":
assert message["status"] == 200

app.add_middleware(TrustedHostMiddleware, allowed_hosts=["*"])
request = app(scope, receive, send)
await request


@pytest.mark.anyio
async def test_trusted_host_in_allowed_hosts():
async def async_url_for(request):
return PlainTextResponse(request.url_for("async_url_for"))

app = Starlette(
routes=[
Route("/func", endpoint=async_url_for),
],
)
scope = {
"type": "http",
"headers": {
(b"host", b"testserver"),
},
"method": "GET",
"path": "/func",
}

async def receive() -> Message: # pragma: no cover
assert False

async def send(message):
if message["type"] == "http.response.start":
assert message["status"] == 200

app.add_middleware(TrustedHostMiddleware, allowed_hosts=["testserver"])
request = app(scope, receive, send)
await request


@pytest.mark.anyio
async def test_trusted_host_not_in_allowed_hosts():
async def async_url_for(request): # pragma: no cover
assert False

app = Starlette(
routes=[
Route("/func", endpoint=async_url_for),
],
)
scope = {
"type": "http",
"headers": {
(b"host", b"testserver"),
},
"method": "GET",
"path": "/func",
}

async def receive() -> Message: # pragma: no cover
assert False

async def send(message):
if message["type"] == "http.response.start":
assert message["status"] == 400
elif message["type"] == "http.response.body":
assert message["body"] == b"Invalid host header"

app.add_middleware(TrustedHostMiddleware, allowed_hosts=["anotherserver"])
request = app(scope, receive, send)
await request