DM-47889: Prevent DB connection pool exhaustion in Butler server #1124

Merged 3 commits on Dec 5, 2024
47 changes: 22 additions & 25 deletions python/lsst/daf/butler/registry/databases/postgresql.py
@@ -163,37 +163,34 @@ def makeEngine(
# multiple threads simultaneously. So we need to configure
# SQLAlchemy to pool connections for multi-threaded usage.
#
-# This is not the maximum number of active connections --
-# SQLAlchemy allows some additional overflow configured via the
-# max_overflow parameter. pool_size is only the maximum number
-# saved in the pool during periods of lower concurrency.
+# This pool size was chosen to work well for services using
+# FastAPI. FastAPI uses a thread pool of 40 by default, so this
+# gives us a connection for each thread in the pool. Because Butler
+# is currently sync-only, we won't ever be executing more queries
+# than we have threads.
#
-# This specific value for pool size was chosen somewhat arbitrarily
-# -- there has not been any formal testing done to profile database
-# concurrency. The value chosen may be somewhat lower than is
-# optimal for service use cases. Some considerations:
+# Connections are only created as they are needed, so in typical
+# single-threaded Butler use only one connection will ever be
+# created. Services with low peak concurrency may never create this
+# many connections.
#
-# 1. Connections are only created as they are needed, so in typical
-# single-threaded Butler use only one connection will ever be
-# created. Services with low peak concurrency may never create
-# this many connections.
-# 2. Most services using the Butler (including Butler
-# server) are using FastAPI, which uses a thread pool of 40 by
-# default. So when running at max concurrency we may have:
-# * 10 connections checked out from the pool
-# * 10 "overflow" connections re-created each time they are
-# used.
-# * 20 threads queued up, waiting for a connection, and
-# potentially timing out if the other threads don't release
-# their connections in a timely manner.
-# 3. The main Butler databases at SLAC are run behind pgbouncer,
-# so we can support a larger number of simultaneous connections
-# than if we were connecting directly to Postgres.
+# The main Butler databases at SLAC are run behind pgbouncer, so we
+# can support a larger number of simultaneous connections than if
+# we were connecting directly to Postgres.
#
# See
# https://docs.sqlalchemy.org/en/20/core/pooling.html#sqlalchemy.pool.QueuePool.__init__
# for more information on the behavior of this parameter.
-pool_size=10,
+pool_size=40,
+# If we are experiencing heavy enough load that we overflow the
+# connection pool, it will be harmful to start creating extra
+# connections that we disconnect immediately after use.
+# Connecting from scratch is fairly expensive, which is why we have
+# a pool in the first place.
+max_overflow=0,
+# If the pool is full, this is the maximum number of seconds we
+# will wait for a connection to become available before giving up.
+pool_timeout=60,
# In combination with pool_pre_ping, prevent SQLAlchemy from
# unnecessarily reviving pooled connections that have gone stale.
# Setting this to true makes it always re-use the most recent
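For context, a minimal sketch (an editor's illustration, not code from the Butler repository) of how these pool settings combine in a plain SQLAlchemy engine. The connection URL is a placeholder, and `pool_pre_ping=True` is an assumption based on the comment above; the real engine is built from Butler registry configuration.

```python
import sqlalchemy

# Placeholder DSN; the Butler builds the real one from registry configuration.
engine = sqlalchemy.create_engine(
    "postgresql+psycopg2://butler@localhost/butler",
    pool_size=40,        # one pooled connection per FastAPI worker thread
    max_overflow=0,      # never create throwaway connections beyond the pool
    pool_timeout=60,     # seconds to wait for a free connection before failing
    pool_pre_ping=True,  # test pooled connections before handing them out
)

with engine.connect() as connection:
    connection.execute(sqlalchemy.text("SELECT 1"))
```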
46 changes: 44 additions & 2 deletions python/lsst/daf/butler/remote_butler/_http_connection.py
@@ -29,6 +29,7 @@

__all__ = ("RemoteButlerHttpConnection", "parse_model")

import time
import urllib.parse
from collections.abc import Iterator, Mapping
from contextlib import contextmanager
@@ -210,7 +211,7 @@
with the message as a subclass of ButlerUserError.
"""
try:
-response = self._client.send(request.request)
+response = self._send_with_retries(request, stream=False)
self._handle_http_status(response, request.request_id)
return response
except httpx.HTTPError as e:
@@ -219,7 +220,7 @@
@contextmanager
def _send_request_with_stream_response(self, request: _Request) -> Iterator[httpx.Response]:
try:
-response = self._client.send(request.request, stream=True)
+response = self._send_with_retries(request, stream=True)
try:
self._handle_http_status(response, request.request_id)
yield response
@@ -228,6 +229,21 @@
except httpx.HTTPError as e:
raise ButlerServerError(request.request_id) from e

def _send_with_retries(self, request: _Request, stream: bool) -> httpx.Response:
max_retry_time_seconds = 120
start_time = time.time()
while True:
response = self._client.send(request.request, stream=stream)
retry = _needs_retry(response)
time_remaining = max_retry_time_seconds - (time.time() - start_time)
if retry.retry and time_remaining > 0:
if stream:
response.close()
sleep_time = min(time_remaining, retry.delay_seconds)
time.sleep(sleep_time)
else:
return response

def _handle_http_status(self, response: httpx.Response, request_id: str) -> None:
if response.status_code == ERROR_STATUS_CODE:
# Raise an exception that the server has forwarded to the
@@ -245,6 +261,32 @@
response.raise_for_status()


@dataclass(frozen=True)
class _Retry:
retry: bool
delay_seconds: int


def _needs_retry(response: httpx.Response) -> _Retry:
# Handle a 503 Service Unavailable, sent by the server if it is
# overloaded, or a 429, sent by the server if the client
# triggers a rate limit.
if response.status_code == 503 or response.status_code == 429:
# Only retry if the server has instructed us to do so by sending a
# Retry-After header.
retry_after = response.headers.get("retry-after")
if retry_after is not None:
try:
# The HTTP standard also allows a date string here, but the
# Butler server only sends integer seconds.
delay_seconds = int(retry_after)
return _Retry(True, delay_seconds)
except ValueError:
pass

return _Retry(False, 0)


def parse_model(response: httpx.Response, model: type[_AnyPydanticModel]) -> _AnyPydanticModel:
"""Deserialize a Pydantic model from the body of an HTTP response.

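As a rough illustration of the new retry helper (an editor's sketch, not code from the PR), here is how a throttling response carrying a Retry-After header would be interpreted; the import path follows the file shown above.

```python
import httpx

from lsst.daf.butler.remote_butler._http_connection import _Retry, _needs_retry

# A 503 like the one the Butler server sends when it is overloaded.
throttled = httpx.Response(503, headers={"Retry-After": "5"})
assert _needs_retry(throttled) == _Retry(True, 5)

# Without the header (or with any other status), the response is handed back
# to the caller immediately instead of being retried.
assert _needs_retry(httpx.Response(503)) == _Retry(False, 0)
assert _needs_retry(httpx.Response(200)) == _Retry(False, 0)
```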
@@ -91,7 +91,7 @@ async def query_execute(
request: QueryExecuteRequestModel, factory: Factory = Depends(factory_dependency)
) -> StreamingResponse:
query = _StreamQueryDriverExecute(request, factory)
-return execute_streaming_query(query)
+return await execute_streaming_query(query)


class _QueryAllDatasetsContext(NamedTuple):
@@ -136,7 +136,7 @@ async def query_all_datasets_execute(
request: QueryAllDatasetsRequestModel, factory: Factory = Depends(factory_dependency)
) -> StreamingResponse:
query = _StreamQueryAllDatasets(request, factory)
-return execute_streaming_query(query)
+return await execute_streaming_query(query)


@query_router.post(
python/lsst/daf/butler/remote_butler/server/handlers/_query_streaming.py
@@ -32,6 +32,7 @@
from contextlib import AbstractContextManager
from typing import Protocol, TypeVar

from fastapi import HTTPException
from fastapi.concurrency import contextmanager_in_threadpool, iterate_in_threadpool
from fastapi.responses import StreamingResponse
from lsst.daf.butler.remote_butler.server_models import (
@@ -43,11 +44,26 @@
from ...._exceptions import ButlerUserError
from ..._errors import serialize_butler_user_error

# Restrict the maximum number of streaming queries that can be running
# simultaneously, to prevent the database connection pool and the thread pool
# from being tied up indefinitely. Beyond this number, the server will return
# an HTTP 503 Service Unavailable with a Retry-After header. We are currently
# using the default FastAPI thread pool size of 40 (total) and have 40 maximum
# database connections (per Butler repository.)
_MAXIMUM_CONCURRENT_STREAMING_QUERIES = 25
# How long we ask callers to wait before trying their query again.
# The hope is that they will bounce to a less busy replica, so we don't want
# them to wait too long.
_QUERY_RETRY_SECONDS = 5
Member:
Are we planning to do randomized exponential backoff at some point in the future? Can we tell if this is the 100th time a client has attempted a retry?

Contributor Author:
We don't really want to go exponential (at least by default), since most consumers of the Butler can't tolerate a request hanging out indefinitely. This is just supposed to buy us enough time for auto-scaling to get some more replicas booted up. When this case is triggered the scarce resource is threads and database connections, so I'm not really worried about the minimal amount of CPU and network needed to tell the caller to go away once every 5 seconds.

I added a cap on the retry time on the client side. Currently we don't have a way to track what a client is up to on the server side, but we will eventually have a redis instance (or Russ will) for more elaborate rate limiting.

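For comparison with the thread above, a hedged sketch of what randomized ("full jitter") exponential backoff with a hard time budget could look like on the client side; this is not what the PR implements, which deliberately sticks to the fixed Retry-After delay supplied by the server. All names and defaults here are illustrative.

```python
import random
import time
from collections.abc import Iterator


def backoff_delays(base: float = 1.0, cap: float = 30.0, budget: float = 120.0) -> Iterator[float]:
    """Yield randomized, exponentially growing sleep times until the total
    waiting time would exceed ``budget`` seconds.
    """
    total = 0.0
    attempt = 0
    while True:
        delay = random.uniform(0, min(cap, base * 2**attempt))
        if total + delay > budget:
            return
        total += delay
        attempt += 1
        yield delay


# Usage: sleep between attempts until the time budget runs out.
for delay in backoff_delays():
    time.sleep(delay)
    # ... re-send the request here and break out on success ...
```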

# Alias this function so we can mock it during unit tests.
_timeout = asyncio.timeout

_TContext = TypeVar("_TContext")

# Count of active streaming queries.
_current_streaming_queries = 0


class StreamingQuery(Protocol[_TContext]):
"""Interface for queries that can return streaming results."""
@@ -67,7 +83,7 @@
"""


-def execute_streaming_query(query: StreamingQuery) -> StreamingResponse:
+async def execute_streaming_query(query: StreamingQuery) -> StreamingResponse:
"""Run a query, streaming the response incrementally, one page at a time,
as newline-separated chunks of JSON.

@@ -95,6 +111,22 @@
read -- ``StreamingQuery.execute()`` cannot be interrupted while it is
in the middle of reading a page.
"""
# Prevent an excessive number of streaming queries from jamming up the
# thread pool and database connection pool. We can't change the response
# code after starting the StreamingResponse, so we enforce this here.
#
# This creates a small chance that more than the expected number of
# streaming queries will be started, but there is no guarantee that the
# StreamingResponse generator function will ever be called, so we can't
# guarantee that we release the slot if we reserve one here.
if _current_streaming_queries >= _MAXIMUM_CONCURRENT_STREAMING_QUERIES:
await _block_retry_for_unit_test()
raise HTTPException(
status_code=503, # service temporarily unavailable
detail="The Butler Server is currently overloaded with requests.",
headers={"retry-after": str(_QUERY_RETRY_SECONDS)},
)

output_generator = _stream_query_pages(query)
return StreamingResponse(
output_generator,
@@ -115,17 +147,24 @@
When it takes longer than 15 seconds to get a response from the DB,
sends a keep-alive message to prevent clients from timing out.
"""
-# `None` signals that there is no more data to send.
-queue = asyncio.Queue[QueryExecuteResultData | None](1)
-async with asyncio.TaskGroup() as tg:
-# Run a background task to read from the DB and insert the result pages
-# into a queue.
-tg.create_task(_enqueue_query_pages(queue, query))
-# Read the result pages from the queue and send them to the client,
-# inserting a keep-alive message every 15 seconds if we are waiting a
-# long time for the database.
-async for message in _dequeue_query_pages_with_keepalive(queue):
-yield message.model_dump_json() + "\n"
+global _current_streaming_queries
+try:
+_current_streaming_queries += 1
+await _block_query_for_unit_test()
+
+# `None` signals that there is no more data to send.
+queue = asyncio.Queue[QueryExecuteResultData | None](1)
+async with asyncio.TaskGroup() as tg:
+# Run a background task to read from the DB and insert the result
+# pages into a queue.
+tg.create_task(_enqueue_query_pages(queue, query))
+# Read the result pages from the queue and send them to the client,
+# inserting a keep-alive message every 15 seconds if we are waiting
+# a long time for the database.
+async for message in _dequeue_query_pages_with_keepalive(queue):
+yield message.model_dump_json() + "\n"
+finally:
+_current_streaming_queries -= 1


async def _enqueue_query_pages(
@@ -163,3 +202,17 @@
yield message
except TimeoutError:
yield QueryKeepAliveModel()


async def _block_retry_for_unit_test() -> None:
"""Will be overridden during unit tests to block the server,
in order to verify retry logic.
"""
pass


async def _block_query_for_unit_test() -> None:
"""Will be overridden during unit tests to block the server,
in order to verify maximum concurrency logic.
"""
pass
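Distilled from the handler changes above, a minimal self-contained sketch (an editor's illustration, not the Butler server's code) of the back-pressure pattern: reject work beyond a concurrency cap with a 503 that carries a Retry-After hint, and release the slot in a `finally` block.

```python
from fastapi import FastAPI, HTTPException

app = FastAPI()

_MAX_CONCURRENT = 25  # illustrative cap, mirroring the constant above
_RETRY_SECONDS = 5    # illustrative retry hint for the client
_active = 0           # count of requests currently being handled


@app.get("/work")
async def work() -> dict[str, str]:
    global _active
    if _active >= _MAX_CONCURRENT:
        # Tell the client to come back later rather than queuing it up.
        raise HTTPException(
            status_code=503,
            detail="Server is currently overloaded.",
            headers={"retry-after": str(_RETRY_SECONDS)},
        )
    _active += 1
    try:
        return {"status": "ok"}  # the real handler would do the work here
    finally:
        _active -= 1
```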
43 changes: 23 additions & 20 deletions python/lsst/daf/butler/tests/server.py
@@ -122,26 +122,29 @@ def create_test_server(
server_butler_factory._preload_direct_butler_cache = False
app.dependency_overrides[butler_factory_dependency] = lambda: server_butler_factory

-client = TestClient(app)
-client_without_error_propagation = TestClient(app, raise_server_exceptions=False)
-
-remote_butler = _make_remote_butler(client)
-remote_butler_without_error_propagation = _make_remote_butler(
-client_without_error_propagation
-)
-
-direct_butler = Butler.from_config(config_file_path, writeable=True)
-assert isinstance(direct_butler, DirectButler)
-hybrid_butler = HybridButler(remote_butler, direct_butler)
-
-yield TestServerInstance(
-config_file_path=config_file_path,
-client=client,
-direct_butler=direct_butler,
-remote_butler=remote_butler,
-remote_butler_without_error_propagation=remote_butler_without_error_propagation,
-hybrid_butler=hybrid_butler,
-)
+# Using TestClient in a context manager ensures that it uses
+# the same async event loop for all requests -- otherwise it
+# starts a new one on each request.
+with TestClient(app) as client:
+remote_butler = _make_remote_butler(client)
+
+direct_butler = Butler.from_config(config_file_path, writeable=True)
+assert isinstance(direct_butler, DirectButler)
+hybrid_butler = HybridButler(remote_butler, direct_butler)
+
+client_without_error_propagation = TestClient(app, raise_server_exceptions=False)
+remote_butler_without_error_propagation = _make_remote_butler(
+client_without_error_propagation
+)
+
+yield TestServerInstance(
+config_file_path=config_file_path,
+client=client,
+direct_butler=direct_butler,
+remote_butler=remote_butler,
+remote_butler_without_error_propagation=remote_butler_without_error_propagation,
+hybrid_butler=hybrid_butler,
+)


def _make_remote_butler(client: TestClient) -> RemoteButler:
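A minimal sketch of the TestClient lifecycle behaviour the new comment relies on; the app and route here are placeholders rather than the Butler server's.

```python
from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()  # placeholder app; the real one comes from the Butler server


@app.get("/ping")
async def ping() -> dict[str, str]:
    return {"status": "ok"}


# Entering the context manager runs startup/lifespan handlers and keeps a
# single event loop alive for every request made through `client`.
with TestClient(app) as client:
    assert client.get("/ping").json() == {"status": "ok"}
```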