diff --git a/starlette/testclient.py b/starlette/testclient.py index 1d5e90dc8..c146eed3d 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -10,6 +10,7 @@ from urllib.parse import unquote, urljoin, urlsplit import requests +from urllib3.util.timeout import Timeout from starlette.types import Message, Receive, Scope, Send from starlette.websockets import WebSocketDisconnect @@ -96,7 +97,11 @@ def __init__( self.root_path = root_path def send( - self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any + self, + request: requests.PreparedRequest, + *args: typing.Any, + timeout: Timeout = None, + **kwargs: typing.Any, ) -> requests.Response: scheme, netloc, path, query, fragment = ( str(item) for item in urlsplit(request.url) @@ -237,7 +242,15 @@ async def send(message: Message) -> None: asyncio.set_event_loop(loop) try: - loop.run_until_complete(self.app(scope, receive, send)) + if isinstance(timeout, tuple): + err = ( + "Invalid timeout {}. testclient only supports float (not tuple)" + "at this time ".format(timeout) + ) + raise ValueError(err) + loop.run_until_complete( + asyncio.wait_for(self.app(scope, receive, send), timeout) + ) except BaseException as exc: if self.raise_server_exceptions: raise exc from None diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 00f4e0125..d702ad195 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -1,4 +1,5 @@ import asyncio +import time import pytest @@ -16,6 +17,19 @@ def mock_service_endpoint(request): return JSONResponse({"mock": "example"}) +@mock_service.route("/slow_response") +def slow_response(request): + time.sleep(0.01) + return JSONResponse({"mock": "slow example"}) + + +@mock_service.route("/async_slow_response") +async def async_slow_response(request): + # time.sleep(0.01) + await asyncio.sleep(0.01) + return JSONResponse({"mock": "slow example"}) + + app = Starlette() @@ -132,3 +146,34 @@ async def asgi(receive, send): with client.websocket_connect("/") as websocket: data = websocket.receive_json() assert data == {"message": "test"} + + +@pytest.mark.parametrize("endpoint", ["/slow_response", "/async_slow_response"]) +def test_timeout(endpoint): + client = TestClient(mock_service, raise_server_exceptions=True) + + with pytest.raises(ValueError): + client.get(endpoint, timeout=(1, 1)) + + with pytest.raises(asyncio.TimeoutError): + client.get(endpoint, timeout=0.001) + + response = client.get(endpoint, timeout=1) + assert response.json() == {"mock": "slow example"} + + response = client.get(endpoint) + assert response.json() == {"mock": "slow example"} + + client = TestClient(mock_service, raise_server_exceptions=False) + + response = client.get(endpoint, timeout=(1, 1)) + assert response.status_code == 500 + + response = client.get(endpoint, timeout=0.001) + assert response.status_code == 500 + + response = client.get(endpoint, timeout=1) + assert response.json() == {"mock": "slow example"} + + response = client.get(endpoint) + assert response.json() == {"mock": "slow example"}