Limit the number of in-flight queries on the server
The server now tracks the number of streaming queries in progress and rejects query requests with an HTTP 503 once 25 are already running. The client now retries when it receives a 503 or 429 response that carries a Retry-After header.

This will allow the server to avoid exhausting its thread pool and database connection pool with long-running queries.
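
In outline, the client side of the handshake works as sketched below. This is a hypothetical standalone helper for illustration only; the actual implementation is RemoteButlerHttpConnection._send_with_retries in the diff that follows.

import time

import httpx


def send_with_retries(client: httpx.Client, request: httpx.Request) -> httpx.Response:
    while True:
        response = client.send(request)
        # Retry only when the server signals overload (503) or rate limiting
        # (429) *and* supplies a Retry-After header saying how long to wait.
        if response.status_code in (503, 429):
            retry_after = response.headers.get("retry-after")
            if retry_after is not None:
                try:
                    delay = int(retry_after)
                except ValueError:
                    # Retry-After may also be an HTTP date; the Butler server
                    # only sends integer seconds, so treat anything else as
                    # non-retryable.
                    return response
                time.sleep(delay)
                continue
        return response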
dhirving committed Dec 3, 2024
1 parent 919ad6a commit 3f41ce9
Showing 5 changed files with 205 additions and 37 deletions.
42 changes: 40 additions & 2 deletions python/lsst/daf/butler/remote_butler/_http_connection.py
@@ -29,6 +29,7 @@

__all__ = ("RemoteButlerHttpConnection", "parse_model")

import time
import urllib.parse
from collections.abc import Iterator, Mapping
from contextlib import contextmanager
@@ -210,7 +211,7 @@ def _send_request(self, request: _Request) -> httpx.Response:
with the message as a subclass of ButlerUserError.
"""
try:
-            response = self._client.send(request.request)
+            response = self._send_with_retries(request, stream=False)
self._handle_http_status(response, request.request_id)
return response
except httpx.HTTPError as e:
@@ -219,7 +220,7 @@ def _send_request(self, request: _Request) -> httpx.Response:
@contextmanager
def _send_request_with_stream_response(self, request: _Request) -> Iterator[httpx.Response]:
try:
-            response = self._client.send(request.request, stream=True)
+            response = self._send_with_retries(request, stream=True)
try:
self._handle_http_status(response, request.request_id)
yield response
@@ -228,6 +229,17 @@ def _send_request_with_stream_response(self, request: _Request) -> Iterator[httpx.Response]:
except httpx.HTTPError as e:
raise ButlerServerError(request.request_id) from e

def _send_with_retries(self, request: _Request, stream: bool) -> httpx.Response:
while True:
response = self._client.send(request.request, stream=stream)
retry = _needs_retry(response)
if retry.retry:
if stream:
response.close()
time.sleep(retry.delay_seconds)
else:
return response

def _handle_http_status(self, response: httpx.Response, request_id: str) -> None:
if response.status_code == ERROR_STATUS_CODE:
# Raise an exception that the server has forwarded to the
@@ -245,6 +257,32 @@ def _handle_http_status(self, response: httpx.Response, request_id: str) -> None:
response.raise_for_status()


@dataclass(frozen=True)
class _Retry:
retry: bool
delay_seconds: int


def _needs_retry(response: httpx.Response) -> _Retry:
# Handle a 503 Service Unavailable, sent by the server if it is
# overloaded, or a 429, sent by the server if the client
# triggers a rate limit.
if response.status_code == 503 or response.status_code == 429:
# Only retry if the server has instructed us to do so by sending a
# Retry-After header.
retry_after = response.headers.get("retry-after")
if retry_after is not None:
try:
# The HTTP standard also allows a date string here, but the
# Butler server only sends integer seconds.
delay_seconds = int(retry_after)
return _Retry(True, delay_seconds)
except ValueError:
pass

return _Retry(False, 0)


def parse_model(response: httpx.Response, model: type[_AnyPydanticModel]) -> _AnyPydanticModel:
"""Deserialize a Pydantic model from the body of an HTTP response.
@@ -91,7 +91,7 @@ async def query_execute(
request: QueryExecuteRequestModel, factory: Factory = Depends(factory_dependency)
) -> StreamingResponse:
query = _StreamQueryDriverExecute(request, factory)
-    return execute_streaming_query(query)
+    return await execute_streaming_query(query)


class _QueryAllDatasetsContext(NamedTuple):
@@ -136,7 +136,7 @@ async def query_all_datasets_execute(
request: QueryAllDatasetsRequestModel, factory: Factory = Depends(factory_dependency)
) -> StreamingResponse:
query = _StreamQueryAllDatasets(request, factory)
-    return execute_streaming_query(query)
+    return await execute_streaming_query(query)


@query_router.post(
python/lsst/daf/butler/remote_butler/server/handlers/_query_streaming.py
@@ -32,6 +32,7 @@
from contextlib import AbstractContextManager
from typing import Protocol, TypeVar

from fastapi import HTTPException
from fastapi.concurrency import contextmanager_in_threadpool, iterate_in_threadpool
from fastapi.responses import StreamingResponse
from lsst.daf.butler.remote_butler.server_models import (
@@ -43,11 +44,26 @@
from ...._exceptions import ButlerUserError
from ..._errors import serialize_butler_user_error

# Restrict the maximum number of streaming queries that can be running
# simultaneously, to prevent the database connection pool and the thread pool
# from being tied up indefinitely. Beyond this number, the server will return
# an HTTP 503 Service Unavailable with a Retry-After header. We are currently
# using the default FastAPI thread pool size of 40 (total) and have 40 maximum
# database connections (per Butler repository.)
_MAXIMUM_CONCURRENT_STREAMING_QUERIES = 25
# How long we ask callers to wait before trying their query again.
# The hope is that they will bounce to a less busy replica, so we don't want
# them to wait too long.
_QUERY_RETRY_SECONDS = 5

# Alias this function so we can mock it during unit tests.
_timeout = asyncio.timeout

_TContext = TypeVar("_TContext")

# Count of active streaming queries.
_current_streaming_queries = 0


class StreamingQuery(Protocol[_TContext]):
"""Interface for queries that can return streaming results."""
@@ -67,7 +83,7 @@ def execute(self, context: _TContext) -> Iterator[QueryExecuteResultData]:
"""


-def execute_streaming_query(query: StreamingQuery) -> StreamingResponse:
+async def execute_streaming_query(query: StreamingQuery) -> StreamingResponse:
"""Run a query, streaming the response incrementally, one page at a time,
as newline-separated chunks of JSON.
@@ -95,6 +111,23 @@ def execute_streaming_query(query: StreamingQuery) -> StreamingResponse:
read -- ``StreamingQuery.execute()`` cannot be interrupted while it is
in the middle of reading a page.
"""
# Prevent an excessive number of streaming queries from jamming up the
# thread pool and database connection pool. We can't change the response
# code after starting the StreamingResponse, so we enforce this here.
#
# This creates a small chance that more than the expected number of
# streaming queries will be started, but there is no guarantee that the
# StreamingResponse generator function will ever be called, so we can't
# guarantee that we release the slot if we reserve one here.
if _current_streaming_queries >= _MAXIMUM_CONCURRENT_STREAMING_QUERIES:
await _block_retry_for_unit_test()
raise HTTPException(
status_code=503, # service temporarily unavailable
detail="The Butler Server is currently overloaded with requests."
f" Try again in {_QUERY_RETRY_SECONDS} seconds.",
headers={"retry-after": str(_QUERY_RETRY_SECONDS)},
)

output_generator = _stream_query_pages(query)
return StreamingResponse(
output_generator,
@@ -115,17 +148,24 @@ async def _stream_query_pages(query: StreamingQuery) -> AsyncIterator[str]:
When it takes longer than 15 seconds to get a response from the DB,
sends a keep-alive message to prevent clients from timing out.
"""
-    # `None` signals that there is no more data to send.
-    queue = asyncio.Queue[QueryExecuteResultData | None](1)
-    async with asyncio.TaskGroup() as tg:
-        # Run a background task to read from the DB and insert the result pages
-        # into a queue.
-        tg.create_task(_enqueue_query_pages(queue, query))
-        # Read the result pages from the queue and send them to the client,
-        # inserting a keep-alive message every 15 seconds if we are waiting a
-        # long time for the database.
-        async for message in _dequeue_query_pages_with_keepalive(queue):
-            yield message.model_dump_json() + "\n"
+    global _current_streaming_queries
+    try:
+        _current_streaming_queries += 1
+        await _block_query_for_unit_test()
+
+        # `None` signals that there is no more data to send.
+        queue = asyncio.Queue[QueryExecuteResultData | None](1)
+        async with asyncio.TaskGroup() as tg:
+            # Run a background task to read from the DB and insert the result
+            # pages into a queue.
+            tg.create_task(_enqueue_query_pages(queue, query))
+            # Read the result pages from the queue and send them to the client,
+            # inserting a keep-alive message every 15 seconds if we are waiting
+            # a long time for the database.
+            async for message in _dequeue_query_pages_with_keepalive(queue):
+                yield message.model_dump_json() + "\n"
+    finally:
+        _current_streaming_queries -= 1


async def _enqueue_query_pages(
@@ -163,3 +203,17 @@ async def _dequeue_query_pages_with_keepalive(
yield message
except TimeoutError:
yield QueryKeepAliveModel()


async def _block_retry_for_unit_test() -> None:
"""Will be overridden during unit tests to block the server,
in order to verify retry logic.
"""
pass


async def _block_query_for_unit_test() -> None:
"""Will be overridden during unit tests to block the server,
in order to verify maximum concurrency logic.
"""
pass
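
For reference, the same backpressure pattern distilled into a minimal, self-contained toy app. This is hypothetical illustration code, not part of this commit; names like _MAX_CONCURRENT and _pages are placeholders mirroring the constants and generator above. As the comments in the diff note, the counter is adjusted inside the generator because there is no guarantee the StreamingResponse generator is ever run, so reserving a slot before returning the response could leak it.

import asyncio
from collections.abc import AsyncIterator

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse

app = FastAPI()

_MAX_CONCURRENT = 25  # mirrors _MAXIMUM_CONCURRENT_STREAMING_QUERIES
_RETRY_SECONDS = 5  # mirrors _QUERY_RETRY_SECONDS
_active = 0  # count of in-flight streaming responses


@app.get("/stream")
async def stream() -> StreamingResponse:
    # Reject before streaming begins -- the status code cannot be changed
    # once the StreamingResponse has started sending.
    if _active >= _MAX_CONCURRENT:
        raise HTTPException(
            status_code=503,
            detail="Server overloaded; try again later.",
            headers={"retry-after": str(_RETRY_SECONDS)},
        )
    return StreamingResponse(_pages(), media_type="application/json")


async def _pages() -> AsyncIterator[str]:
    # Adjust the counter inside the generator itself, since the framework
    # may never invoke the generator for a response that was created.
    global _active
    _active += 1
    try:
        for i in range(3):
            await asyncio.sleep(1)  # stand-in for reading a page from the DB
            yield f'{{"page": {i}}}\n'
    finally:
        _active -= 1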
43 changes: 23 additions & 20 deletions python/lsst/daf/butler/tests/server.py
@@ -122,26 +122,29 @@ def create_test_server(
server_butler_factory._preload_direct_butler_cache = False
app.dependency_overrides[butler_factory_dependency] = lambda: server_butler_factory

-        client = TestClient(app)
-        client_without_error_propagation = TestClient(app, raise_server_exceptions=False)
-
-        remote_butler = _make_remote_butler(client)
-        remote_butler_without_error_propagation = _make_remote_butler(
-            client_without_error_propagation
-        )
-
-        direct_butler = Butler.from_config(config_file_path, writeable=True)
-        assert isinstance(direct_butler, DirectButler)
-        hybrid_butler = HybridButler(remote_butler, direct_butler)
-
-        yield TestServerInstance(
-            config_file_path=config_file_path,
-            client=client,
-            direct_butler=direct_butler,
-            remote_butler=remote_butler,
-            remote_butler_without_error_propagation=remote_butler_without_error_propagation,
-            hybrid_butler=hybrid_butler,
-        )
+        # Using TestClient in a context manager ensures that it uses
+        # the same async event loop for all requests -- otherwise it
+        # starts a new one on each request.
+        with TestClient(app) as client:
+            remote_butler = _make_remote_butler(client)
+
+            direct_butler = Butler.from_config(config_file_path, writeable=True)
+            assert isinstance(direct_butler, DirectButler)
+            hybrid_butler = HybridButler(remote_butler, direct_butler)
+
+            client_without_error_propagation = TestClient(app, raise_server_exceptions=False)
+            remote_butler_without_error_propagation = _make_remote_butler(
+                client_without_error_propagation
+            )
+
+            yield TestServerInstance(
+                config_file_path=config_file_path,
+                client=client,
+                direct_butler=direct_butler,
+                remote_butler=remote_butler,
+                remote_butler_without_error_propagation=remote_butler_without_error_propagation,
+                hybrid_butler=hybrid_butler,
+            )


def _make_remote_butler(client: TestClient) -> RemoteButler:
75 changes: 74 additions & 1 deletion tests/test_server.py
@@ -25,10 +25,14 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import asyncio
import os.path
import tempfile
import threading
import unittest
import uuid
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import DEFAULT, AsyncMock, NonCallableMock, patch

from lsst.daf.butler.tests.dict_convertible_model import DictConvertibleModel

@@ -50,7 +54,6 @@
create_test_server = None
reason_text = str(e)

-from unittest.mock import DEFAULT, NonCallableMock, patch

from lsst.daf.butler import (
Butler,
@@ -427,6 +430,76 @@ def test_query_keepalive(self):
self.assertGreaterEqual(mock_timeout.call_count, 3)
self.assertGreaterEqual(mock_keep_alive.call_count, 2)

@patch(
"lsst.daf.butler.remote_butler.server.handlers._query_streaming._MAXIMUM_CONCURRENT_STREAMING_QUERIES",

1,
)
@patch("lsst.daf.butler.remote_butler.server.handlers._query_streaming._QUERY_RETRY_SECONDS", 1)
def test_query_retries(self):
"""Test that the server will send HTTP status 503 to put backpressure
on clients if it is overloaded, and that the client will retry if this
happens.
"""
query_event = threading.Event()
retry_event = asyncio.Event()

async def block_first_request() -> None:
# Signal the unit tests that we have reached the critical section
# in the server, where the first client has reserved the query
# slot.
query_event.set()
# Block inside the query, until the 2nd client has been forced to
# retry.
await retry_event.wait()

async def block_second_request() -> None:
# Release the first client, so it can finish its query and prevent
# this client from being blocked on the next go-round.
retry_event.set()

def do_query(butler: Butler) -> list[DatasetRef]:
return butler.query_datasets("bias", "imported_g")

with (
patch.object(
lsst.daf.butler.remote_butler.server.handlers._query_streaming,
"_block_query_for_unit_test",
new=AsyncMock(wraps=block_first_request),
) as mock_first_client,
patch.object(
lsst.daf.butler.remote_butler.server.handlers._query_streaming,
"_block_retry_for_unit_test",
new=AsyncMock(wraps=block_second_request),
) as mock_second_client,
ThreadPoolExecutor(max_workers=1) as exec1,
ThreadPoolExecutor(max_workers=1) as exec2,
):
first_butler = self.butler
second_butler = self.butler.clone()

# Run the first client up until the server starts executing its
# query.
future1 = exec1.submit(do_query, first_butler)
event_reached = query_event.wait(60)
if not event_reached:
raise TimeoutError("Server did not execute query logic as expected.")

# Start the second client, which will trigger the retry logic and
# release the first client to finish its query.
future2 = exec2.submit(do_query, second_butler)

result1 = future1.result(60)
result2 = future2.result(60)
self.assertEqual(len(result1), 3)
self.assertEqual(len(result2), 3)
# The original thread should have gone through this section, and
# then the 2nd thread after it retries.
self.assertEqual(mock_first_client.await_count, 2)
# We should have triggered the retry logic at least once, but it
# might occur multiple times depending how long the first client
# takes to finish.
self.assertGreaterEqual(mock_second_client.await_count, 1)

# TODO DM-46204: This can be removed once the RSP recommended image has
# been upgraded to a version that contains DM-46129.
def test_deprecated_collection_endpoints(self):
Expand Down
