diff --git a/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py b/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py
index e61d5a8dcb..2d70d49fc3 100644
--- a/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py
+++ b/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py
@@ -29,13 +29,11 @@
 
 __all__ = ("query_router",)
 
-import asyncio
-from collections.abc import AsyncIterator, Iterator
-from contextlib import AbstractContextManager, contextmanager
-from typing import NamedTuple, Protocol, TypeVar
+from collections.abc import Iterator
+from contextlib import contextmanager
+from typing import NamedTuple
 
 from fastapi import APIRouter, Depends
-from fastapi.concurrency import contextmanager_in_threadpool, iterate_in_threadpool
 from fastapi.responses import StreamingResponse
 from lsst.daf.butler import DataCoordinate, DimensionGroup
 from lsst.daf.butler.remote_butler.server_models import (
@@ -43,47 +41,21 @@
     QueryAnyResponseModel,
     QueryCountRequestModel,
     QueryCountResponseModel,
-    QueryErrorResultModel,
     QueryExecuteRequestModel,
     QueryExecuteResultData,
     QueryExplainRequestModel,
     QueryExplainResponseModel,
     QueryInputs,
-    QueryKeepAliveModel,
 )
 
-from ...._exceptions import ButlerUserError
 from ....queries.driver import QueryDriver, QueryTree
-from ..._errors import serialize_butler_user_error
 from .._dependencies import factory_dependency
 from .._factory import Factory
 from ._query_serialization import convert_query_page
+from ._query_streaming import StreamingQuery, execute_streaming_query
 
 query_router = APIRouter()
 
-# Alias this function so we can mock it during unit tests.
-_timeout = asyncio.timeout
-
-_TContext = TypeVar("_TContext")
-
-
-class StreamingQuery(Protocol[_TContext]):
-    """Interface for queries that can return streaming results."""
-
-    def setup(self) -> AbstractContextManager[_TContext]:
-        """Context manager that sets up any resources used to execute the
-        query.
-        """
-
-    def execute(self, context: _TContext) -> Iterator[QueryExecuteResultData]:
-        """Execute the database query and and return pages of results.
-
-        Parameters
-        ----------
-        context : generic
-            Value returned by the call to ``setup()``.
-        """
-
 
 class _StreamQueryDriverExecute(StreamingQuery):
     """Wrapper to call `QueryDriver.execute` from async stream handler."""
@@ -109,82 +81,7 @@ async def query_execute(
     request: QueryExecuteRequestModel, factory: Factory = Depends(factory_dependency)
 ) -> StreamingResponse:
     query = _StreamQueryDriverExecute(request, factory)
-    return _execute_streaming_query(query)
-
-
-def _execute_streaming_query(query: StreamingQuery) -> StreamingResponse:
-    # We write the response incrementally, one page at a time, as
-    # newline-separated chunks of JSON. This allows clients to start
-    # reading results earlier and prevents the server from exhausting
-    # all its memory buffering rows from large queries.
-    output_generator = _stream_query_pages(query)
-    return StreamingResponse(
-        output_generator,
-        media_type="application/jsonlines",
-        headers={
-            # Instruct the Kubernetes ingress to not buffer the response,
-            # so that keep-alives reach the client promptly.
-            "X-Accel-Buffering": "no"
-        },
-    )
-
-
-async def _stream_query_pages(query: StreamingQuery) -> AsyncIterator[str]:
-    """Stream the query output with one page object per line, as
-    newline-delimited JSON records in the "JSON Lines" format
-    (https://jsonlines.org/).
-
-    When it takes longer than 15 seconds to get a response from the DB,
-    sends a keep-alive message to prevent clients from timing out.
-    """
-    # `None` signals that there is no more data to send.
-    queue = asyncio.Queue[QueryExecuteResultData | None](1)
-    async with asyncio.TaskGroup() as tg:
-        # Run a background task to read from the DB and insert the result pages
-        # into a queue.
-        tg.create_task(_enqueue_query_pages(queue, query))
-        # Read the result pages from the queue and send them to the client,
-        # inserting a keep-alive message every 15 seconds if we are waiting a
-        # long time for the database.
-        async for message in _dequeue_query_pages_with_keepalive(queue):
-            yield message.model_dump_json() + "\n"
-
-
-async def _enqueue_query_pages(
-    queue: asyncio.Queue[QueryExecuteResultData | None], query: StreamingQuery
-) -> None:
-    """Set up a QueryDriver to run the query, and copy the results into a
-    queue. Send `None` to the queue when there is no more data to read.
-    """
-    try:
-        async with contextmanager_in_threadpool(query.setup()) as ctx:
-            async for page in iterate_in_threadpool(query.execute(ctx)):
-                await queue.put(page)
-    except ButlerUserError as e:
-        # If a user-facing error occurs, serialize it and send it to the
-        # client.
-        await queue.put(QueryErrorResultModel(error=serialize_butler_user_error(e)))
-
-    # Signal that there is no more data to read.
-    await queue.put(None)
-
-
-async def _dequeue_query_pages_with_keepalive(
-    queue: asyncio.Queue[QueryExecuteResultData | None],
-) -> AsyncIterator[QueryExecuteResultData]:
-    """Read and return messages from the given queue until the end-of-stream
-    message `None` is reached. If the producer is taking a long time, returns
-    a keep-alive message every 15 seconds while we are waiting.
-    """
-    while True:
-        try:
-            async with _timeout(15):
-                message = await queue.get()
-                if message is None:
-                    return
-                yield message
-        except TimeoutError:
-            yield QueryKeepAliveModel()
+    return execute_streaming_query(query)
 
 
 @query_router.post(
diff --git a/python/lsst/daf/butler/remote_butler/server/handlers/_query_streaming.py b/python/lsst/daf/butler/remote_butler/server/handlers/_query_streaming.py
new file mode 100644
index 0000000000..04ef3a9c32
--- /dev/null
+++ b/python/lsst/daf/butler/remote_butler/server/handlers/_query_streaming.py
@@ -0,0 +1,165 @@
+# This file is part of daf_butler.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+import asyncio
+from collections.abc import AsyncIterator, Iterator
+from contextlib import AbstractContextManager
+from typing import Protocol, TypeVar
+
+from fastapi.concurrency import contextmanager_in_threadpool, iterate_in_threadpool
+from fastapi.responses import StreamingResponse
+from lsst.daf.butler.remote_butler.server_models import (
+    QueryErrorResultModel,
+    QueryExecuteResultData,
+    QueryKeepAliveModel,
+)
+
+from ...._exceptions import ButlerUserError
+from ..._errors import serialize_butler_user_error
+
+# Alias this function so we can mock it during unit tests.
+_timeout = asyncio.timeout
+
+_TContext = TypeVar("_TContext")
+
+
+class StreamingQuery(Protocol[_TContext]):
+    """Interface for queries that can return streaming results."""
+
+    def setup(self) -> AbstractContextManager[_TContext]:
+        """Context manager that sets up any resources used to execute the
+        query.
+        """
+
+    def execute(self, context: _TContext) -> Iterator[QueryExecuteResultData]:
+        """Execute the database query and return pages of results.
+
+        Parameters
+        ----------
+        context : generic
+            Value returned by the call to ``setup()``.
+        """
+
+
+def execute_streaming_query(query: StreamingQuery) -> StreamingResponse:
+    """Run a query, streaming the response incrementally, one page at a time,
+    as newline-separated chunks of JSON.
+
+    Parameters
+    ----------
+    query : ``StreamingQuery``
+        Callers should define a class implementing the ``StreamingQuery``
+        protocol to specify the inner logic that will be called during
+        query execution.
+
+    Returns
+    -------
+    response : `fastapi.StreamingResponse`
+        FastAPI streaming response that can be returned by a route handler.
+
+    Notes
+    -----
+    - Streaming the response allows clients to start reading results earlier
+      and prevents the server from exhausting all its memory buffering rows
+      from large queries.
+    - If the query is taking a long time to execute, we insert a keepalive
+      message in the JSON stream every 15 seconds.
+    - If the caller closes the HTTP connection, async cancellation is
+      triggered by FastAPI. The query will be cancelled after the next page is
+      read -- ``StreamingQuery.execute()`` cannot be interrupted while it is
+      in the middle of reading a page.
+    """
+    output_generator = _stream_query_pages(query)
+    return StreamingResponse(
+        output_generator,
+        media_type="application/jsonlines",
+        headers={
+            # Instruct the Kubernetes ingress to not buffer the response,
+            # so that keep-alives reach the client promptly.
+            "X-Accel-Buffering": "no"
+        },
+    )
+
+
+async def _stream_query_pages(query: StreamingQuery) -> AsyncIterator[str]:
+    """Stream the query output with one page object per line, as
+    newline-delimited JSON records in the "JSON Lines" format
+    (https://jsonlines.org/).
+
+    When it takes longer than 15 seconds to get a response from the DB,
+    sends a keep-alive message to prevent clients from timing out.
+    """
+    # `None` signals that there is no more data to send.
+    queue = asyncio.Queue[QueryExecuteResultData | None](1)
+    async with asyncio.TaskGroup() as tg:
+        # Run a background task to read from the DB and insert the result pages
+        # into a queue.
+        tg.create_task(_enqueue_query_pages(queue, query))
+        # Read the result pages from the queue and send them to the client,
+        # inserting a keep-alive message every 15 seconds if we are waiting a
+        # long time for the database.
+        async for message in _dequeue_query_pages_with_keepalive(queue):
+            yield message.model_dump_json() + "\n"
+
+
+async def _enqueue_query_pages(
+    queue: asyncio.Queue[QueryExecuteResultData | None], query: StreamingQuery
+) -> None:
+    """Set up a QueryDriver to run the query, and copy the results into a
+    queue. Send `None` to the queue when there is no more data to read.
+    """
+    try:
+        async with contextmanager_in_threadpool(query.setup()) as ctx:
+            async for page in iterate_in_threadpool(query.execute(ctx)):
+                await queue.put(page)
+    except ButlerUserError as e:
+        # If a user-facing error occurs, serialize it and send it to the
+        # client.
+        await queue.put(QueryErrorResultModel(error=serialize_butler_user_error(e)))
+
+    # Signal that there is no more data to read.
+    await queue.put(None)
+
+
+async def _dequeue_query_pages_with_keepalive(
+    queue: asyncio.Queue[QueryExecuteResultData | None],
+) -> AsyncIterator[QueryExecuteResultData]:
+    """Read and return messages from the given queue until the end-of-stream
+    message `None` is reached. If the producer is taking a long time, returns
+    a keep-alive message every 15 seconds while we are waiting.
+    """
+    while True:
+        try:
+            async with _timeout(15):
+                message = await queue.get()
+                if message is None:
+                    return
+                yield message
+        except TimeoutError:
+            yield QueryKeepAliveModel()
diff --git a/tests/test_server.py b/tests/test_server.py
index 96fd7de6b1..17a18edf53 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -35,7 +35,7 @@
 try:
     # Failing to import any of these should disable the tests.
     import lsst.daf.butler.remote_butler._query_driver
-    import lsst.daf.butler.remote_butler.server.handlers._external_query
+    import lsst.daf.butler.remote_butler.server.handlers._query_streaming
     import safir.dependencies.logger
     from fastapi.testclient import TestClient
     from lsst.daf.butler.remote_butler import RemoteButler
@@ -413,7 +413,7 @@ def test_query_keepalive(self):
         # Normally it takes 15 seconds for a timeout -- mock it to trigger
         # immediately instead.
         with patch.object(
-            lsst.daf.butler.remote_butler.server.handlers._external_query, "_timeout"
+            lsst.daf.butler.remote_butler.server.handlers._query_streaming, "_timeout"
         ) as mock_timeout:
             # Hook into QueryDriver to track the number of keep-alives we have
             # seen.
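
Note for reviewers: a minimal sketch of how a handler plugs into the relocated helpers, assuming the `_query_streaming` module shown above. `_ExampleQuery`, its `object` context value, and `example_handler` are hypothetical stand-ins invented for illustration; the real implementation is the `_StreamQueryDriverExecute` wrapper that remains in `_external_query.py`.

```python
from collections.abc import Iterator
from contextlib import AbstractContextManager, contextmanager

from fastapi.responses import StreamingResponse
from lsst.daf.butler.remote_butler.server.handlers._query_streaming import (
    StreamingQuery,
    execute_streaming_query,
)
from lsst.daf.butler.remote_butler.server_models import QueryExecuteResultData


class _ExampleQuery(StreamingQuery):
    """Hypothetical StreamingQuery implementation, for illustration only."""

    @contextmanager
    def _resources(self) -> Iterator[object]:
        # setup() is entered via contextmanager_in_threadpool(), so blocking
        # resource acquisition (e.g. opening a DB connection) is safe here,
        # and cleanup runs even if the client disconnects mid-stream.
        yield object()

    def setup(self) -> AbstractContextManager[object]:
        return self._resources()

    def execute(self, context: object) -> Iterator[QueryExecuteResultData]:
        # execute() is driven by iterate_in_threadpool(); each yielded page
        # becomes one line of the JSON Lines response body.
        yield from ()  # a real implementation yields result pages here


def example_handler() -> StreamingResponse:
    # In a real FastAPI route (cf. query_execute above), this return value is
    # sent to the client incrementally, one page per line.
    return execute_streaming_query(_ExampleQuery())
```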
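
A companion sketch of the client side of the stream, to show why the keep-alive records and the `X-Accel-Buffering: no` header exist. The URL, the empty request body, and the `"keep-alive"` discriminator value are assumptions for illustration only; the real wire models are defined in `server_models`, and `RemoteButler`'s query driver does the actual parsing.

```python
import json

import httpx

# Hypothetical endpoint; the real path is whatever query_router mounts.
URL = "https://butler.example.org/api/butler/query/execute"

with httpx.stream("POST", URL, json={}, timeout=30.0) as response:
    # Pages arrive as newline-delimited JSON ("application/jsonlines"),
    # flushed promptly because the ingress is told not to buffer.
    for line in response.iter_lines():
        if not line:
            continue
        record = json.loads(line)
        # Keep-alive records exist only to stop this read from hitting the
        # client timeout while the DB is slow; skip them. ("keep-alive" is an
        # assumed tag -- see QueryKeepAliveModel in server_models for the
        # real discriminator.)
        if record.get("type") == "keep-alive":
            continue
        print(record)  # a result page or a serialized user error
```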