Move query streaming logic to its own file
After the refactor in the previous commit, this is somewhat independent of the query routes.
dhirving committed Nov 6, 2024
1 parent 5b82702 commit 07afc52
Showing 3 changed files with 172 additions and 110 deletions.
113 changes: 5 additions & 108 deletions python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py
@@ -29,61 +29,33 @@

__all__ = ("query_router",)

import asyncio
from collections.abc import AsyncIterator, Iterator
from contextlib import AbstractContextManager, contextmanager
from typing import NamedTuple, Protocol, TypeVar
from collections.abc import Iterator
from contextlib import contextmanager
from typing import NamedTuple

from fastapi import APIRouter, Depends
from fastapi.concurrency import contextmanager_in_threadpool, iterate_in_threadpool
from fastapi.responses import StreamingResponse
from lsst.daf.butler import DataCoordinate, DimensionGroup
from lsst.daf.butler.remote_butler.server_models import (
    QueryAnyRequestModel,
    QueryAnyResponseModel,
    QueryCountRequestModel,
    QueryCountResponseModel,
    QueryErrorResultModel,
    QueryExecuteRequestModel,
    QueryExecuteResultData,
    QueryExplainRequestModel,
    QueryExplainResponseModel,
    QueryInputs,
    QueryKeepAliveModel,
)

from ...._exceptions import ButlerUserError
from ....queries.driver import QueryDriver, QueryTree
from ..._errors import serialize_butler_user_error
from .._dependencies import factory_dependency
from .._factory import Factory
from ._query_serialization import convert_query_page
from ._query_streaming import StreamingQuery, execute_streaming_query

query_router = APIRouter()

# Alias this function so we can mock it during unit tests.
_timeout = asyncio.timeout

_TContext = TypeVar("_TContext")


class StreamingQuery(Protocol[_TContext]):
    """Interface for queries that can return streaming results."""

    def setup(self) -> AbstractContextManager[_TContext]:
        """Context manager that sets up any resources used to execute the
        query.
        """

    def execute(self, context: _TContext) -> Iterator[QueryExecuteResultData]:
        """Execute the database query and return pages of results.

        Parameters
        ----------
        context : generic
            Value returned by the call to ``setup()``.
        """


class _StreamQueryDriverExecute(StreamingQuery):
    """Wrapper to call `QueryDriver.execute` from async stream handler."""
@@ -109,82 +81,7 @@ async def query_execute(
    request: QueryExecuteRequestModel, factory: Factory = Depends(factory_dependency)
) -> StreamingResponse:
    query = _StreamQueryDriverExecute(request, factory)
    return _execute_streaming_query(query)


def _execute_streaming_query(query: StreamingQuery) -> StreamingResponse:
    # We write the response incrementally, one page at a time, as
    # newline-separated chunks of JSON. This allows clients to start
    # reading results earlier and prevents the server from exhausting
    # all its memory buffering rows from large queries.
    output_generator = _stream_query_pages(query)
    return StreamingResponse(
        output_generator,
        media_type="application/jsonlines",
        headers={
            # Instruct the Kubernetes ingress to not buffer the response,
            # so that keep-alives reach the client promptly.
            "X-Accel-Buffering": "no"
        },
    )


async def _stream_query_pages(query: StreamingQuery) -> AsyncIterator[str]:
    """Stream the query output with one page object per line, as
    newline-delimited JSON records in the "JSON Lines" format
    (https://jsonlines.org/).

    When it takes longer than 15 seconds to get a response from the DB,
    sends a keep-alive message to prevent clients from timing out.
    """
    # `None` signals that there is no more data to send.
    queue = asyncio.Queue[QueryExecuteResultData | None](1)
    async with asyncio.TaskGroup() as tg:
        # Run a background task to read from the DB and insert the result pages
        # into a queue.
        tg.create_task(_enqueue_query_pages(queue, query))
        # Read the result pages from the queue and send them to the client,
        # inserting a keep-alive message every 15 seconds if we are waiting a
        # long time for the database.
        async for message in _dequeue_query_pages_with_keepalive(queue):
            yield message.model_dump_json() + "\n"


async def _enqueue_query_pages(
    queue: asyncio.Queue[QueryExecuteResultData | None], query: StreamingQuery
) -> None:
    """Set up a QueryDriver to run the query, and copy the results into a
    queue. Send `None` to the queue when there is no more data to read.
    """
    try:
        async with contextmanager_in_threadpool(query.setup()) as ctx:
            async for page in iterate_in_threadpool(query.execute(ctx)):
                await queue.put(page)
    except ButlerUserError as e:
        # If a user-facing error occurs, serialize it and send it to the
        # client.
        await queue.put(QueryErrorResultModel(error=serialize_butler_user_error(e)))

    # Signal that there is no more data to read.
    await queue.put(None)


async def _dequeue_query_pages_with_keepalive(
    queue: asyncio.Queue[QueryExecuteResultData | None],
) -> AsyncIterator[QueryExecuteResultData]:
    """Read and return messages from the given queue until the end-of-stream
    message `None` is reached. If the producer is taking a long time, returns
    a keep-alive message every 15 seconds while we are waiting.
    """
    while True:
        try:
            async with _timeout(15):
                message = await queue.get()
                if message is None:
                    return
                yield message
        except TimeoutError:
            yield QueryKeepAliveModel()
    return execute_streaming_query(query)


@query_router.post(
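The point of the extraction is that the streaming machinery no longer depends on the query routes: any handler can stream results by supplying an object that satisfies the StreamingQuery protocol. Below is a minimal sketch of a conforming class; the _StreamToyPages name and the integer pages are invented for illustration, whereas a real implementation yields QueryExecuteResultData models the way _StreamQueryDriverExecute does.

from collections.abc import Iterator
from contextlib import AbstractContextManager, contextmanager


class _StreamToyPages:
    def setup(self) -> AbstractContextManager[list[int]]:
        @contextmanager
        def _open() -> Iterator[list[int]]:
            rows = [1, 2, 3]  # stands in for acquiring a database connection
            try:
                yield rows
            finally:
                rows.clear()  # stands in for releasing the connection

        return _open()

    def execute(self, context: list[int]) -> Iterator[int]:
        # execute_streaming_query drains this iterator in a worker thread and
        # forwards each page to the client as one line of JSON.
        yield from context


# A route handler would then end with:
#     return execute_streaming_query(_StreamToyPages())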
165 changes: 165 additions & 0 deletions python/lsst/daf/butler/remote_butler/server/handlers/_query_streaming.py
@@ -0,0 +1,165 @@
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator, Iterator
from contextlib import AbstractContextManager
from typing import Protocol, TypeVar

from fastapi.concurrency import contextmanager_in_threadpool, iterate_in_threadpool
from fastapi.responses import StreamingResponse
from lsst.daf.butler.remote_butler.server_models import (
    QueryErrorResultModel,
    QueryExecuteResultData,
    QueryKeepAliveModel,
)

from ...._exceptions import ButlerUserError
from ..._errors import serialize_butler_user_error

# Alias this function so we can mock it during unit tests.
_timeout = asyncio.timeout

_TContext = TypeVar("_TContext")


class StreamingQuery(Protocol[_TContext]):
    """Interface for queries that can return streaming results."""

    def setup(self) -> AbstractContextManager[_TContext]:
        """Context manager that sets up any resources used to execute the
        query.
        """

    def execute(self, context: _TContext) -> Iterator[QueryExecuteResultData]:
        """Execute the database query and return pages of results.

        Parameters
        ----------
        context : generic
            Value returned by the call to ``setup()``.
        """


def execute_streaming_query(query: StreamingQuery) -> StreamingResponse:
    """Run a query, streaming the response incrementally, one page at a time,
    as newline-separated chunks of JSON.

    Parameters
    ----------
    query : ``StreamingQuery``
        Callers should define a class implementing the ``StreamingQuery``
        protocol to specify the inner logic that will be called during
        query execution.

    Returns
    -------
    response : `fastapi.StreamingResponse`
        FastAPI streaming response that can be returned by a route handler.

    Notes
    -----
    - Streaming the response allows clients to start reading results earlier
      and prevents the server from exhausting all its memory buffering rows
      from large queries.
    - If the query is taking a long time to execute, we insert a keepalive
      message in the JSON stream every 15 seconds.
    - If the caller closes the HTTP connection, async cancellation is
      triggered by FastAPI. The query will be cancelled after the next page
      is read -- ``StreamingQuery.execute()`` cannot be interrupted while it
      is in the middle of reading a page.
    """
    output_generator = _stream_query_pages(query)
    return StreamingResponse(
        output_generator,
        media_type="application/jsonlines",
        headers={
            # Instruct the Kubernetes ingress to not buffer the response,
            # so that keep-alives reach the client promptly.
            "X-Accel-Buffering": "no"
        },
    )


async def _stream_query_pages(query: StreamingQuery) -> AsyncIterator[str]:
    """Stream the query output with one page object per line, as
    newline-delimited JSON records in the "JSON Lines" format
    (https://jsonlines.org/).

    When it takes longer than 15 seconds to get a response from the DB,
    sends a keep-alive message to prevent clients from timing out.
    """
    # `None` signals that there is no more data to send.
    queue = asyncio.Queue[QueryExecuteResultData | None](1)
    async with asyncio.TaskGroup() as tg:
        # Run a background task to read from the DB and insert the result pages
        # into a queue.
        tg.create_task(_enqueue_query_pages(queue, query))
        # Read the result pages from the queue and send them to the client,
        # inserting a keep-alive message every 15 seconds if we are waiting a
        # long time for the database.
        async for message in _dequeue_query_pages_with_keepalive(queue):
            yield message.model_dump_json() + "\n"


async def _enqueue_query_pages(
    queue: asyncio.Queue[QueryExecuteResultData | None], query: StreamingQuery
) -> None:
    """Set up a QueryDriver to run the query, and copy the results into a
    queue. Send `None` to the queue when there is no more data to read.
    """
    try:
        async with contextmanager_in_threadpool(query.setup()) as ctx:
            async for page in iterate_in_threadpool(query.execute(ctx)):
                await queue.put(page)
    except ButlerUserError as e:
        # If a user-facing error occurs, serialize it and send it to the
        # client.
        await queue.put(QueryErrorResultModel(error=serialize_butler_user_error(e)))

    # Signal that there is no more data to read.
    await queue.put(None)


async def _dequeue_query_pages_with_keepalive(
    queue: asyncio.Queue[QueryExecuteResultData | None],
) -> AsyncIterator[QueryExecuteResultData]:
    """Read and return messages from the given queue until the end-of-stream
    message `None` is reached. If the producer is taking a long time, returns
    a keep-alive message every 15 seconds while we are waiting.
    """
    while True:
        try:
            async with _timeout(15):
                message = await queue.get()
                if message is None:
                    return
                yield message
        except TimeoutError:
            yield QueryKeepAliveModel()
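Since the response is plain JSON Lines, no special client machinery is needed to consume it. The sketch below is a hypothetical consumer using the third-party requests library; the URL and payload are placeholders, and real callers go through RemoteButler rather than hitting the route directly.

import json

import requests  # third-party HTTP client, used here only for illustration

# Placeholder endpoint and request body; the real route is registered on the
# server's query router and the body is a QueryExecuteRequestModel.
url = "https://example.org/api/butler/query/execute"
payload: dict = {}

with requests.post(url, json=payload, stream=True) as response:
    response.raise_for_status()
    # One JSON document per line: result pages, a serialized user error, or
    # periodic keep-alive records while the database is slow.
    for line in response.iter_lines():
        if not line:
            continue
        print(json.loads(line))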
4 changes: 2 additions & 2 deletions tests/test_server.py
@@ -35,7 +35,7 @@
try:
    # Failing to import any of these should disable the tests.
    import lsst.daf.butler.remote_butler._query_driver
    import lsst.daf.butler.remote_butler.server.handlers._external_query
    import lsst.daf.butler.remote_butler.server.handlers._query_streaming
    import safir.dependencies.logger
    from fastapi.testclient import TestClient
    from lsst.daf.butler.remote_butler import RemoteButler
@@ -413,7 +413,7 @@ def test_query_keepalive(self):
        # Normally it takes 15 seconds for a timeout -- mock it to trigger
        # immediately instead.
        with patch.object(
            lsst.daf.butler.remote_butler.server.handlers._external_query, "_timeout"
            lsst.daf.butler.remote_butler.server.handlers._query_streaming, "_timeout"
        ) as mock_timeout:
            # Hook into QueryDriver to track the number of keep-alives we have
            # seen.
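The module-level `_timeout` alias that this test patches exists purely as a mocking seam: because `_dequeue_query_pages_with_keepalive` looks the name up as a module global, a test can shrink the 15-second keep-alive window without touching asyncio itself. Here is a self-contained sketch of that mechanism, with an illustrative 0.01-second window and a toy producer standing in for the real test fixtures.

import asyncio
from unittest.mock import patch

from lsst.daf.butler.remote_butler.server.handlers import _query_streaming


async def demo() -> None:
    queue: asyncio.Queue = asyncio.Queue(1)

    async def slow_producer() -> None:
        await asyncio.sleep(0.05)  # stands in for a slow database page
        await queue.put(None)  # end-of-stream sentinel

    producer = asyncio.ensure_future(slow_producer())
    # Shrink the 15-second keep-alive window so the timeout branch fires
    # almost immediately.
    with patch.object(_query_streaming, "_timeout", lambda _seconds: asyncio.timeout(0.01)):
        keepalives = 0
        async for _message in _query_streaming._dequeue_query_pages_with_keepalive(queue):
            keepalives += 1  # only keep-alive records arrive in this demo
    await producer
    print(f"saw {keepalives} keep-alive messages")


asyncio.run(demo())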
