diff --git a/python/lsst/daf/butler/remote_butler/_http_connection.py b/python/lsst/daf/butler/remote_butler/_http_connection.py
index 0116b630b7..c2f2c9c5aa 100644
--- a/python/lsst/daf/butler/remote_butler/_http_connection.py
+++ b/python/lsst/daf/butler/remote_butler/_http_connection.py
@@ -29,6 +29,7 @@
 
 __all__ = ("RemoteButlerHttpConnection", "parse_model")
 
+import time
 import urllib.parse
 from collections.abc import Iterator, Mapping
 from contextlib import contextmanager
@@ -210,7 +211,7 @@ def _send_request(self, request: _Request) -> httpx.Response:
             with the message as a subclass of ButlerUserError.
         """
         try:
-            response = self._client.send(request.request)
+            response = self._send_with_retries(request, stream=False)
             self._handle_http_status(response, request.request_id)
             return response
         except httpx.HTTPError as e:
@@ -219,7 +220,7 @@
     @contextmanager
     def _send_request_with_stream_response(self, request: _Request) -> Iterator[httpx.Response]:
         try:
-            response = self._client.send(request.request, stream=True)
+            response = self._send_with_retries(request, stream=True)
             try:
                 self._handle_http_status(response, request.request_id)
                 yield response
@@ -228,6 +229,17 @@
         except httpx.HTTPError as e:
             raise ButlerServerError(request.request_id) from e
 
+    def _send_with_retries(self, request: _Request, stream: bool) -> httpx.Response:
+        while True:
+            response = self._client.send(request.request, stream=stream)
+            retry = _needs_retry(response)
+            if retry.retry:
+                if stream:
+                    response.close()
+                time.sleep(retry.delay_seconds)
+            else:
+                return response
+
     def _handle_http_status(self, response: httpx.Response, request_id: str) -> None:
         if response.status_code == ERROR_STATUS_CODE:
             # Raise an exception that the server has forwarded to the
@@ -245,6 +257,32 @@ def _handle_http_status(self, response: httpx.Response, request_id: str) -> None
             response.raise_for_status()
 
 
+@dataclass(frozen=True)
+class _Retry:
+    retry: bool
+    delay_seconds: int
+
+
+def _needs_retry(response: httpx.Response) -> _Retry:
+    # Handle a 503 Service Unavailable, sent by the server if it is
+    # overloaded, or a 429, sent by the server if the client
+    # triggers a rate limit.
+    if response.status_code == 503 or response.status_code == 429:
+        # Only retry if the server has instructed us to do so by sending a
+        # Retry-After header.
+        retry_after = response.headers.get("retry-after")
+        if retry_after is not None:
+            try:
+                # The HTTP standard also allows a date string here, but the
+                # Butler server only sends integer seconds.
+                delay_seconds = int(retry_after)
+                return _Retry(True, delay_seconds)
+            except ValueError:
+                pass
+
+    return _Retry(False, 0)
+
+
 def parse_model(response: httpx.Response, model: type[_AnyPydanticModel]) -> _AnyPydanticModel:
     """Deserialize a Pydantic model from the body of an HTTP response.
 
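A note on the `_needs_retry` helper above: it deliberately ignores the HTTP-date form of `Retry-After`, because the Butler server only ever sends integer seconds. For comparison, here is a minimal standalone sketch of a parser that accepts both forms permitted by the HTTP standard; `parse_retry_after` is a hypothetical helper written for illustration, not part of the Butler API:

```python
# Hypothetical helper, for illustration only: parse a Retry-After value
# that may be either integer seconds or an HTTP-date.
import email.utils
from datetime import datetime, timezone


def parse_retry_after(value: str) -> float | None:
    """Return the delay in seconds encoded in a Retry-After header value,
    or None if the value cannot be parsed.
    """
    try:
        # Integer-seconds form, the only one the Butler server sends.
        return float(int(value))
    except ValueError:
        pass
    try:
        # HTTP-date form, e.g. "Wed, 21 Oct 2015 07:28:00 GMT".
        retry_time = email.utils.parsedate_to_datetime(value)
    except (TypeError, ValueError):
        return None
    if retry_time.tzinfo is None:
        # HTTP-dates are nominally GMT; assume UTC if no offset was parsed.
        retry_time = retry_time.replace(tzinfo=timezone.utc)
    return max((retry_time - datetime.now(timezone.utc)).total_seconds(), 0.0)
```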
diff --git a/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py b/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py
index 0a2118af19..159994446c 100644
--- a/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py
+++ b/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py
@@ -91,7 +91,7 @@ async def query_execute(
     request: QueryExecuteRequestModel, factory: Factory = Depends(factory_dependency)
 ) -> StreamingResponse:
     query = _StreamQueryDriverExecute(request, factory)
-    return execute_streaming_query(query)
+    return await execute_streaming_query(query)
 
 
 class _QueryAllDatasetsContext(NamedTuple):
@@ -136,7 +136,7 @@ async def query_all_datasets_execute(
     request: QueryAllDatasetsRequestModel, factory: Factory = Depends(factory_dependency)
 ) -> StreamingResponse:
     query = _StreamQueryAllDatasets(request, factory)
-    return execute_streaming_query(query)
+    return await execute_streaming_query(query)
 
 
 @query_router.post(
diff --git a/python/lsst/daf/butler/remote_butler/server/handlers/_query_streaming.py b/python/lsst/daf/butler/remote_butler/server/handlers/_query_streaming.py
index 04ef3a9c32..7e854ae789 100644
--- a/python/lsst/daf/butler/remote_butler/server/handlers/_query_streaming.py
+++ b/python/lsst/daf/butler/remote_butler/server/handlers/_query_streaming.py
@@ -32,6 +32,7 @@
 from contextlib import AbstractContextManager
 from typing import Protocol, TypeVar
 
+from fastapi import HTTPException
 from fastapi.concurrency import contextmanager_in_threadpool, iterate_in_threadpool
 from fastapi.responses import StreamingResponse
 from lsst.daf.butler.remote_butler.server_models import (
@@ -43,11 +44,26 @@
 from ...._exceptions import ButlerUserError
 from ..._errors import serialize_butler_user_error
 
+# Restrict the maximum number of streaming queries that can be running
+# simultaneously, to prevent the database connection pool and the thread pool
+# from being tied up indefinitely. Beyond this number, the server will return
+# an HTTP 503 Service Unavailable with a Retry-After header. We are currently
+# using the default FastAPI thread pool size of 40 (total) and have 40 maximum
+# database connections (per Butler repository).
+_MAXIMUM_CONCURRENT_STREAMING_QUERIES = 25
+# How long we ask callers to wait before trying their query again.
+# The hope is that they will bounce to a less busy replica, so we don't want
+# them to wait too long.
+_QUERY_RETRY_SECONDS = 5
+
 # Alias this function so we can mock it during unit tests.
 _timeout = asyncio.timeout
 
 _TContext = TypeVar("_TContext")
 
+# Count of active streaming queries.
+_current_streaming_queries = 0
+
 
 class StreamingQuery(Protocol[_TContext]):
     """Interface for queries that can return streaming results."""
@@ -67,7 +83,7 @@ def execute(self, context: _TContext) -> Iterator[QueryExecuteResultData]:
     """
 
 
-def execute_streaming_query(query: StreamingQuery) -> StreamingResponse:
+async def execute_streaming_query(query: StreamingQuery) -> StreamingResponse:
     """Run a query, streaming the response incrementally, one page at a time,
     as newline-separated chunks of JSON.
 
@@ -95,6 +111,23 @@
     read -- ``StreamingQuery.execute()`` cannot be interrupted while it is
     in the middle of reading a page.
     """
+    # Prevent an excessive number of streaming queries from jamming up the
+    # thread pool and database connection pool.
+    # We can't change the response code after starting the StreamingResponse,
+    # so we enforce this here.
+    #
+    # This creates a small chance that more than the expected number of
+    # streaming queries will be started, but there is no guarantee that the
+    # StreamingResponse generator function will ever be called, so we can't
+    # guarantee that we release the slot if we reserve one here.
+    if _current_streaming_queries >= _MAXIMUM_CONCURRENT_STREAMING_QUERIES:
+        await _block_retry_for_unit_test()
+        raise HTTPException(
+            status_code=503,  # service temporarily unavailable
+            detail="The Butler Server is currently overloaded with requests."
+            f" Try again in {_QUERY_RETRY_SECONDS} seconds.",
+            headers={"retry-after": str(_QUERY_RETRY_SECONDS)},
+        )
+
     output_generator = _stream_query_pages(query)
     return StreamingResponse(
         output_generator,
@@ -115,17 +148,24 @@ async def _stream_query_pages(query: StreamingQuery) -> AsyncIterator[str]:
     When it takes longer than 15 seconds to get a response from the DB, sends
     a keep-alive message to prevent clients from timing out.
     """
-    # `None` signals that there is no more data to send.
-    queue = asyncio.Queue[QueryExecuteResultData | None](1)
-    async with asyncio.TaskGroup() as tg:
-        # Run a background task to read from the DB and insert the result pages
-        # into a queue.
-        tg.create_task(_enqueue_query_pages(queue, query))
-        # Read the result pages from the queue and send them to the client,
-        # inserting a keep-alive message every 15 seconds if we are waiting a
-        # long time for the database.
-        async for message in _dequeue_query_pages_with_keepalive(queue):
-            yield message.model_dump_json() + "\n"
+    global _current_streaming_queries
+    try:
+        _current_streaming_queries += 1
+        await _block_query_for_unit_test()
+
+        # `None` signals that there is no more data to send.
+        queue = asyncio.Queue[QueryExecuteResultData | None](1)
+        async with asyncio.TaskGroup() as tg:
+            # Run a background task to read from the DB and insert the result
+            # pages into a queue.
+            tg.create_task(_enqueue_query_pages(queue, query))
+            # Read the result pages from the queue and send them to the client,
+            # inserting a keep-alive message every 15 seconds if we are waiting
+            # a long time for the database.
+            async for message in _dequeue_query_pages_with_keepalive(queue):
+                yield message.model_dump_json() + "\n"
+    finally:
+        _current_streaming_queries -= 1
 
 
 async def _enqueue_query_pages(
@@ -163,3 +203,17 @@ async def _dequeue_query_pages_with_keepalive(
             yield message
         except TimeoutError:
             yield QueryKeepAliveModel()
+
+
+async def _block_retry_for_unit_test() -> None:
+    """Will be overridden during unit tests to block the server,
+    in order to verify retry logic.
+    """
+    pass
+
+
+async def _block_query_for_unit_test() -> None:
+    """Will be overridden during unit tests to block the server,
+    in order to verify maximum concurrency logic.
+    """
+    pass
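The `_query_streaming.py` changes above combine three pieces: a module-level counter of active streams, a pre-flight check that rejects excess requests with a 503 plus a `Retry-After` header, and claim/release of the slot inside the response generator itself. The following self-contained sketch condenses that pattern; the app, endpoint, and constants are illustrative stand-ins, not the Butler handlers:

```python
# Condensed sketch of the 503 + Retry-After load-shedding pattern above.
# MAX_ACTIVE, RETRY_SECONDS, and the /stream endpoint are illustrative.
import asyncio
from collections.abc import AsyncIterator

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse

app = FastAPI()
MAX_ACTIVE = 25
RETRY_SECONDS = 5
_active = 0


@app.get("/stream")
async def stream() -> StreamingResponse:
    # The status code cannot change once streaming has started, so the
    # overload check must happen before the StreamingResponse is created.
    if _active >= MAX_ACTIVE:
        raise HTTPException(
            status_code=503,
            detail="Server busy, try again later.",
            headers={"retry-after": str(RETRY_SECONDS)},
        )
    return StreamingResponse(_generate(), media_type="application/json")


async def _generate() -> AsyncIterator[str]:
    # The slot is claimed inside the generator because there is no guarantee
    # the generator ever runs (the client may disconnect first), so reserving
    # it in the handler could leak slots. As the comment in the diff notes,
    # this allows a brief overshoot between the check and the increment.
    global _active
    _active += 1
    try:
        for page in range(3):
            await asyncio.sleep(0.1)  # stand-in for a database page read
            yield f'{{"page": {page}}}\n'
    finally:
        _active -= 1
```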
+ """ + pass diff --git a/python/lsst/daf/butler/tests/server.py b/python/lsst/daf/butler/tests/server.py index ba89d5dc2c..1146f8acfb 100644 --- a/python/lsst/daf/butler/tests/server.py +++ b/python/lsst/daf/butler/tests/server.py @@ -122,26 +122,29 @@ def create_test_server( server_butler_factory._preload_direct_butler_cache = False app.dependency_overrides[butler_factory_dependency] = lambda: server_butler_factory - client = TestClient(app) - client_without_error_propagation = TestClient(app, raise_server_exceptions=False) - - remote_butler = _make_remote_butler(client) - remote_butler_without_error_propagation = _make_remote_butler( - client_without_error_propagation - ) - - direct_butler = Butler.from_config(config_file_path, writeable=True) - assert isinstance(direct_butler, DirectButler) - hybrid_butler = HybridButler(remote_butler, direct_butler) - - yield TestServerInstance( - config_file_path=config_file_path, - client=client, - direct_butler=direct_butler, - remote_butler=remote_butler, - remote_butler_without_error_propagation=remote_butler_without_error_propagation, - hybrid_butler=hybrid_butler, - ) + # Using TestClient in a context manager ensures that it uses + # the same async event loop for all requests -- otherwise it + # starts a new one on each request. + with TestClient(app) as client: + remote_butler = _make_remote_butler(client) + + direct_butler = Butler.from_config(config_file_path, writeable=True) + assert isinstance(direct_butler, DirectButler) + hybrid_butler = HybridButler(remote_butler, direct_butler) + + client_without_error_propagation = TestClient(app, raise_server_exceptions=False) + remote_butler_without_error_propagation = _make_remote_butler( + client_without_error_propagation + ) + + yield TestServerInstance( + config_file_path=config_file_path, + client=client, + direct_butler=direct_butler, + remote_butler=remote_butler, + remote_butler_without_error_propagation=remote_butler_without_error_propagation, + hybrid_butler=hybrid_butler, + ) def _make_remote_butler(client: TestClient) -> RemoteButler: diff --git a/tests/test_server.py b/tests/test_server.py index efc8ad8b56..236faeadd1 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -25,10 +25,14 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +import asyncio import os.path import tempfile +import threading import unittest import uuid +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import DEFAULT, AsyncMock, NonCallableMock, patch from lsst.daf.butler.tests.dict_convertible_model import DictConvertibleModel @@ -50,7 +54,6 @@ create_test_server = None reason_text = str(e) -from unittest.mock import DEFAULT, NonCallableMock, patch from lsst.daf.butler import ( Butler, @@ -427,6 +430,76 @@ def test_query_keepalive(self): self.assertGreaterEqual(mock_timeout.call_count, 3) self.assertGreaterEqual(mock_keep_alive.call_count, 2) + @patch( + "lsst.daf.butler.remote_butler.server.handlers._query_streaming._MAXIMUM_CONCURRENT_STREAMING_QUERIES", + 1, + ) + @patch("lsst.daf.butler.remote_butler.server.handlers._query_streaming._QUERY_RETRY_SECONDS", 1) + def test_query_retries(self): + """Test that the server will send HTTP status 503 to put backpressure + on clients if it is overloaded, and that the client will retry if this + happens. 
+ """ + query_event = threading.Event() + retry_event = asyncio.Event() + + async def block_first_request() -> None: + # Signal the unit tests that we have reached the critical section + # in the server, where the first client has reserved the query + # slot. + query_event.set() + # Block inside the query, until the 2nd client has been forced to + # retry. + await retry_event.wait() + + async def block_second_request() -> None: + # Release the first client, so it can finish its query and prevent + # this client from being blocked on the next go-round. + retry_event.set() + + def do_query(butler: Butler) -> list[DatasetRef]: + return butler.query_datasets("bias", "imported_g") + + with ( + patch.object( + lsst.daf.butler.remote_butler.server.handlers._query_streaming, + "_block_query_for_unit_test", + new=AsyncMock(wraps=block_first_request), + ) as mock_first_client, + patch.object( + lsst.daf.butler.remote_butler.server.handlers._query_streaming, + "_block_retry_for_unit_test", + new=AsyncMock(wraps=block_second_request), + ) as mock_second_client, + ThreadPoolExecutor(max_workers=1) as exec1, + ThreadPoolExecutor(max_workers=1) as exec2, + ): + first_butler = self.butler + second_butler = self.butler.clone() + + # Run the first client up until the server starts executing its + # query. + future1 = exec1.submit(do_query, first_butler) + event_reached = query_event.wait(60) + if not event_reached: + raise TimeoutError("Server did not execute query logic as expected.") + + # Start the second client, which will trigger the retry logic and + # release the first client to finish its query. + future2 = exec2.submit(do_query, second_butler) + + result1 = future1.result(60) + result2 = future2.result(60) + self.assertEqual(len(result1), 3) + self.assertEqual(len(result2), 3) + # The original thread should have gone through this section, and + # then the 2nd thread after it retries. + self.assertEqual(mock_first_client.await_count, 2) + # We should have triggered the retry logic at least once, but it + # might occur multiple times depending how long the first client + # takes to finish. + self.assertGreaterEqual(mock_second_client.await_count, 1) + # TODO DM-46204: This can be removed once the RSP recommended image has # been upgraded to a version that contains DM-46129. def test_deprecated_collection_endpoints(self):