Add server-side implementation of query_all_datasets
query_all_datasets can potentially involve hundreds or thousands of separate dataset queries. We don't want clients slamming the server with that many HTTP requests, so add a server-side endpoint that can handle these queries in a single request.
dhirving committed Nov 12, 2024
1 parent c8fd5f7 commit 32c647e
Showing 11 changed files with 199 additions and 47 deletions.
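To make the motivation concrete, here is a rough sketch (not code from this commit) of what the change means for a client resolving a wildcard over many dataset types. The repository path, collection name, and dataset-type glob below are hypothetical, and the "before" loop is only an approximation of what a client would otherwise have to do.

```python
from lsst.daf.butler import Butler

butler = Butler.from_config("/repo/main")   # hypothetical repository path
collections = ["HSC/runs/RC2"]              # hypothetical collection

# Before: one registry query (and, for RemoteButler, one HTTP round trip)
# per matching dataset type.
refs = []
for dataset_type in butler.registry.queryDatasetTypes("deepCoadd*"):
    refs.extend(butler.query_datasets(dataset_type, collections=collections, explain=False))

# After: the private helper added in this commit fans out over dataset types
# on the server and streams all results back in a single request.
refs = butler._query_all_datasets("deepCoadd*", collections=collections)
```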
44 changes: 27 additions & 17 deletions python/lsst/daf/butler/_butler.py
@@ -30,7 +30,7 @@
__all__ = ["Butler"]

from abc import abstractmethod
from collections.abc import Collection, Iterable, Mapping, Sequence
from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence
from contextlib import AbstractContextManager
from types import EllipsisType
from typing import TYPE_CHECKING, Any, TextIO
@@ -47,6 +47,7 @@
from ._config import Config, ConfigSubset
from ._exceptions import EmptyQueryResultError, InvalidQueryError
from ._limited_butler import LimitedButler
from ._query_all_datasets import QueryAllDatasetsParameters
from .datastore import Datastore
from .dimensions import DataCoordinate, DimensionConfig
from .registry import RegistryConfig, _RegistryFactory
@@ -1946,33 +1947,36 @@ def _query_all_datasets(
include dimension records (`DataCoordinate.hasRecords` will be
`False`).
"""
from ._query_all_datasets import QueryAllDatasetsParameters, query_all_datasets

if collections is None:
collections = list(self.collections.defaults)
else:
collections = list(ensure_iterable(collections))

if bind is None:
bind = {}
if data_id is None:
data_id = {}

warn_limit = False
if limit is not None and limit < 0:
# Add one to the limit so we can detect if we have exceeded it.
limit = abs(limit) + 1
warn_limit = True

result = []
with self.query() as query:
args = QueryAllDatasetsParameters(
collections=collections,
name=name,
find_first=find_first,
data_id=data_id,
where=where,
limit=limit,
bind=bind,
kwargs=kwargs,
)
for page in query_all_datasets(self, query, args):
result.extend(page.data)
args = QueryAllDatasetsParameters(
collections=collections,
name=list(ensure_iterable(name)),
find_first=find_first,
data_id=data_id,
where=where,
limit=limit,
bind=bind,
kwargs=kwargs,
)
with self._query_all_datasets_by_page(args) as pages:
result = []
for page in pages:
result.extend(page)

if warn_limit and limit is not None and len(result) >= limit:
# Remove the extra dataset we added for the limit check.
@@ -1981,6 +1985,12 @@ def _query_all_datasets(

return result

@abstractmethod
def _query_all_datasets_by_page(
self, args: QueryAllDatasetsParameters
) -> AbstractContextManager[Iterator[list[DatasetRef]]]:
raise NotImplementedError()

def clone(
self,
*,
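The limit handling in `_query_all_datasets` above uses a small trick worth spelling out: a negative `limit` means "cap the results and warn if the cap was hit", so the query asks for one extra row as a sentinel. A standalone sketch of the same idea, with a hypothetical `run_query` standing in for the real butler query call:

```python
import warnings


def capped_query(requested_limit: int | None) -> list:
    """Illustrative only: mirror the negative-limit convention used above."""
    warn_limit = False
    limit = requested_limit
    if limit is not None and limit < 0:
        # Ask for one extra row so we can tell whether results were truncated.
        limit = abs(limit) + 1
        warn_limit = True

    rows = run_query(limit=limit)  # hypothetical stand-in for the butler query

    if warn_limit and limit is not None and len(rows) >= limit:
        rows = rows[: limit - 1]  # drop the sentinel row
        warnings.warn(f"Results truncated to {limit - 1} rows; more were available.")
    return rows
```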
24 changes: 12 additions & 12 deletions python/lsst/daf/butler/_query_all_datasets.py
@@ -29,18 +29,21 @@

import dataclasses
import logging
from collections.abc import Iterable, Iterator, Mapping
from typing import Any, NamedTuple
from collections.abc import Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, NamedTuple

from lsst.utils.iteration import ensure_iterable

from ._butler import Butler
from ._dataset_ref import DatasetRef
from ._exceptions import InvalidQueryError, MissingDatasetTypeError
from .dimensions import DataId, DataIdValue
from .queries import Query
from .utils import has_globs

if TYPE_CHECKING:
from ._butler import Butler


_LOG = logging.getLogger(__name__)


@@ -57,12 +60,12 @@ class QueryAllDatasetsParameters:
the same meaning as that function unless noted below.
"""

collections: list[str]
name: str | Iterable[str]
collections: Sequence[str]
name: Sequence[str]
find_first: bool
data_id: DataId | None
data_id: DataId
where: str
bind: Mapping[str, Any] | None
bind: Mapping[str, Any]
limit: int | None
"""
Upper limit on the number of returned records. `None` can be used
@@ -107,9 +110,6 @@ def query_all_datasets(
"""
if args.find_first and has_globs(args.collections):
raise InvalidQueryError("Can not use wildcards in collections when find_first=True")
data_id = args.data_id
if data_id is None:
data_id = {}

dataset_type_query = list(ensure_iterable(args.name))
dataset_type_collections = _filter_collections_and_dataset_types(
@@ -121,7 +121,7 @@
_LOG.debug("Querying dataset type %s", dt)
results = (
query.datasets(dt, filtered_collections, find_first=args.find_first)
.where(data_id, args.where, args.kwargs, bind=args.bind)
.where(args.data_id, args.where, args.kwargs, bind=args.bind)
.limit(limit)
)

@@ -137,7 +137,7 @@


def _filter_collections_and_dataset_types(
butler: Butler, collections: list[str], dataset_type_query: list[str]
butler: Butler, collections: Sequence[str], dataset_type_query: Sequence[str]
) -> Mapping[str, list[str]]:
"""For each dataset type matching the query, filter down the given
collections to only those that might actually contain datasets of the given
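The truncated docstring above describes the key optimization in `_filter_collections_and_dataset_types`: only run a query for a (dataset type, collection) pair when the collection can actually contain that dataset type. A minimal sketch of the shape of that filtering, assuming a hypothetical `dataset_types_in_collection` helper (the real code consults registry collection summaries):

```python
from collections.abc import Mapping, Sequence


def filter_collections(
    butler, collections: Sequence[str], dataset_types: Sequence[str]
) -> Mapping[str, list[str]]:
    """Illustrative only: map each dataset type to the collections worth querying."""
    filtered: dict[str, list[str]] = {name: [] for name in dataset_types}
    for collection in collections:
        # Hypothetical helper; the real implementation uses collection summaries.
        present = dataset_types_in_collection(butler, collection)
        for name in dataset_types:
            if name in present:
                filtered[name].append(collection)
    return filtered
```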
9 changes: 9 additions & 0 deletions python/lsst/daf/butler/direct_butler/_direct_butler.py
@@ -65,6 +65,7 @@
from .._exceptions import DatasetNotFoundError, DimensionValueError, EmptyQueryResultError, ValidationError
from .._file_dataset import FileDataset
from .._limited_butler import LimitedButler
from .._query_all_datasets import QueryAllDatasetsParameters, query_all_datasets
from .._registry_shim import RegistryShim
from .._storage_class import StorageClass, StorageClassFactory
from .._timespan import Timespan
@@ -2319,6 +2320,14 @@ def _query_driver(
"""
return self._registry._query_driver(default_collections, default_data_id)

@contextlib.contextmanager
def _query_all_datasets_by_page(
self, args: QueryAllDatasetsParameters
) -> Iterator[Iterator[list[DatasetRef]]]:
with self.query() as query:
pages = query_all_datasets(self, query, args)
yield iter(page.data for page in pages)

def _preload_cache(self) -> None:
"""Immediately load caches that are used for common operations."""
self._registry.preload_cache()
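A short usage sketch for the paged interface implemented above (and declared abstractly in `_butler.py`): the generator of pages is only valid while the context manager is open, so callers must consume it inside the `with` block. `args` and `process` are placeholders.

```python
# Illustrative consumption pattern for the private paged API.
with butler._query_all_datasets_by_page(args) as pages:
    for page in pages:        # each page is a list[DatasetRef]
        process(page)         # hypothetical per-page handler
```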
24 changes: 24 additions & 0 deletions python/lsst/daf/butler/remote_butler/_remote_butler.py
@@ -57,17 +57,20 @@
from .._dataset_type import DatasetType
from .._deferredDatasetHandle import DeferredDatasetHandle
from .._exceptions import DatasetNotFoundError
from .._query_all_datasets import QueryAllDatasetsParameters
from .._storage_class import StorageClass, StorageClassFactory
from .._utilities.locked_object import LockedObject
from ..datastore import DatasetRefURIs, DatastoreConfig
from ..datastore.cache_manager import AbstractDatastoreCacheManager, DatastoreCacheManager
from ..dimensions import DataIdValue, DimensionConfig, DimensionUniverse, SerializedDataId
from ..queries import Query
from ..queries.tree import make_column_literal
from ..registry import CollectionArgType, NoDefaultCollectionError, Registry, RegistryDefaults
from ._collection_args import convert_collection_arg_to_glob_string_list
from ._defaults import DefaultsHolder
from ._http_connection import RemoteButlerHttpConnection, parse_model, quote_path_variable
from ._query_driver import RemoteQueryDriver
from ._query_results import convert_dataset_ref_results, read_query_results
from ._ref_utils import apply_storage_class_override, normalize_dataset_type_name, simplify_dataId
from ._registry import RemoteButlerRegistry
from ._remote_butler_collections import RemoteButlerCollections
@@ -79,6 +82,7 @@
GetFileByDataIdRequestModel,
GetFileResponseModel,
GetUniverseResponseModel,
QueryAllDatasetsRequestModel,
)

if TYPE_CHECKING:
@@ -628,6 +632,26 @@ def query(self) -> Iterator[Query]:
query = Query(driver)
yield query

@contextmanager
def _query_all_datasets_by_page(
self, args: QueryAllDatasetsParameters
) -> Iterator[Iterator[list[DatasetRef]]]:
universe = self.dimensions

request = QueryAllDatasetsRequestModel(
collections=self._normalize_collections(args.collections),
name=[normalize_dataset_type_name(name) for name in args.name],
find_first=args.find_first,
data_id=simplify_dataId(args.data_id, args.kwargs),
default_data_id=self._serialize_default_data_id(),
where=args.where,
bind={k: make_column_literal(v) for k, v in args.bind.items()},
limit=args.limit,
)
with self._connection.post_with_stream_response("query/all_datasets", request) as response:
pages = read_query_results(response)
yield (convert_dataset_ref_results(page, universe) for page in pages)

def pruneDatasets(
self,
refs: Iterable[DatasetRef],
15 changes: 4 additions & 11 deletions python/lsst/daf/butler/remote_butler/server/handlers/_external.py
@@ -27,6 +27,8 @@

from __future__ import annotations

from lsst.daf.butler.remote_butler.server.handlers._utils import set_default_data_id

__all__ = ()

import uuid
@@ -55,8 +57,6 @@
)

from ...._exceptions import DatasetNotFoundError
from ....dimensions import DataCoordinate, SerializedDataId
from ....registry import RegistryDefaults
from .._dependencies import factory_dependency
from .._factory import Factory

@@ -134,7 +134,7 @@ def find_dataset(
factory: Factory = Depends(factory_dependency),
) -> FindDatasetResponseModel:
butler = factory.create_butler()
_set_default_data_id(butler, query.default_data_id)
set_default_data_id(butler, query.default_data_id)
ref = butler.find_dataset(
query.dataset_type,
query.data_id,
@@ -182,7 +182,7 @@ def get_file_by_data_id(
factory: Factory = Depends(factory_dependency),
) -> GetFileResponseModel:
butler = factory.create_butler()
_set_default_data_id(butler, request.default_data_id)
set_default_data_id(butler, request.default_data_id)
ref = butler._findDatasetRef(
datasetRefOrType=request.dataset_type,
dataId=request.data_id,
@@ -285,10 +285,3 @@ def query_dataset_types(
return QueryDatasetTypesResponseModel(
dataset_types=[dt.to_simple() for dt in dataset_types], missing=missing
)


def _set_default_data_id(butler: Butler, data_id: SerializedDataId) -> None:
"""Set the default data ID values used for lookups in the given Butler."""
butler.registry.defaults = RegistryDefaults.from_data_id(
DataCoordinate.standardize(data_id, universe=butler.dimensions)
)
Original file line number Diff line number Diff line change
@@ -35,8 +35,10 @@

from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from lsst.daf.butler import DataCoordinate, DimensionGroup
from lsst.daf.butler import Butler, DataCoordinate, DimensionGroup
from lsst.daf.butler.remote_butler.server_models import (
DatasetRefResultModel,
QueryAllDatasetsRequestModel,
QueryAnyRequestModel,
QueryAnyResponseModel,
QueryCountRequestModel,
@@ -48,11 +50,14 @@
QueryInputs,
)

from ...._query_all_datasets import QueryAllDatasetsParameters, query_all_datasets
from ....queries import Query
from ....queries.driver import QueryDriver, QueryTree
from .._dependencies import factory_dependency
from .._factory import Factory
from ._query_serialization import convert_query_page
from ._query_streaming import StreamingQuery, execute_streaming_query
from ._utils import set_default_data_id

query_router = APIRouter()

@@ -84,6 +89,50 @@ async def query_execute(
return execute_streaming_query(query)


class _QueryAllDatasetsContext(NamedTuple):
butler: Butler
query: Query


class _StreamQueryAllDatasets(StreamingQuery):
def __init__(self, request: QueryAllDatasetsRequestModel, factory: Factory) -> None:
self._request = request
self._factory = factory

@contextmanager
def setup(self) -> Iterator[_QueryAllDatasetsContext]:
butler = self._factory.create_butler()
set_default_data_id(butler, self._request.default_data_id)
with butler.query() as query:
yield _QueryAllDatasetsContext(butler, query)

def execute(self, ctx: _QueryAllDatasetsContext) -> Iterator[QueryExecuteResultData]:
request = self._request
bind = {k: v.get_literal_value() for k, v in request.bind.items()}
args = QueryAllDatasetsParameters(
collections=request.collections,
name=request.name,
find_first=request.find_first,
data_id=request.data_id,
where=request.where,
bind=bind,
limit=request.limit,
)
pages = query_all_datasets(ctx.butler, ctx.query, args)
for page in pages:
yield DatasetRefResultModel.from_refs(page.data)


@query_router.post(
"/v1/query/all_datasets", summary="Query the Butler database across multiple dataset types."
)
async def query_all_datasets_execute(
request: QueryAllDatasetsRequestModel, factory: Factory = Depends(factory_dependency)
) -> StreamingResponse:
query = _StreamQueryAllDatasets(request, factory)
return execute_streaming_query(query)


@query_router.post(
"/v1/query/count",
summary="Query the Butler database and return a count of rows that would be returned.",
Original file line number Diff line number Diff line change
@@ -69,7 +69,7 @@ def convert_query_page(spec: ResultSpec, page: ResultPage) -> QueryExecuteResultData:
return DataCoordinateResultModel(rows=[coordinate.to_simple() for coordinate in page.rows])
case "dataset_ref":
assert isinstance(page, DatasetRefResultPage)
return DatasetRefResultModel(rows=[ref.to_simple() for ref in page.rows])
return DatasetRefResultModel.from_refs(page.rows)
case "general":
assert isinstance(page, GeneralResultPage)
return _convert_general_result(page)
Expand Down