From 97d3e8811f17bb46197e555795628783df5cf151 Mon Sep 17 00:00:00 2001
From: Jim Bosch
Date: Wed, 31 Jan 2024 14:33:14 -0500
Subject: [PATCH] Rewrite and disable Butler._query interfaces.

This includes:

- replacing the Query and *QueryResults ABCs with concrete classes that
  delegate to a new QueryDriver ABC;
- substantially reworking the public Query and *QueryResults interfaces,
  mostly to minimize the number of different ways to do various things
  (and hence limit complexity);
- adding a large suite of Pydantic models that can describe complex
  under-construction queries, allowing us to send them over the wire in
  RemoteButler.

Because QueryDriver doesn't have any concrete implementations yet, this
change means Butler._query no longer works at all (previously it delegated
to the old registry.queries system).  A QueryDriver implementation for
DirectButler has been largely implemented on another branch and will be
added later.

For now, the only tests are those that rely on a mocked QueryDriver (or
don't require one at all).  These are in two files:

- test_query_interface.py tests the public interface objects, including
  the semi-public Pydantic models;
- test_query_utilities.py tests some utility classes (ColumnSet and
  OverlapsVisitor) that are expected to be used by all driver
  implementations to establish some behavioral invariants.

There is already substantial duplication with code in
lsst.daf.butler.registry.queries, and that will get worse when a
direct-SQL driver class is added.  Eventually the plan is to retire almost
all of lsst.daf.butler.registry.queries (except the string-expression
parser, which we'll move later), making the public registry query
interfaces delegate to lsst.daf.butler.queries instead, but that will
require both getting the latter fully functional and RFC'ing the removal
of some features we have no intention of supporting in the new system.
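As a rough sketch of how the new interface is intended to be used,
following the helper implementations added to _butler.py below (the
instrument, dataset type, and collection names are hypothetical, and on
this branch the calls would raise NotImplementedError because no
QueryDriver implementation exists yet):

    with butler._query() as query:
        # Data IDs, optionally ordered, limited, and expanded with
        # dimension records.
        data_ids = list(
            query.where(instrument="HSC")
            .data_ids(["visit", "detector"])
            .order_by("visit")
            .limit(100)
        )
        # Dataset references, with find-first resolution over the given
        # collections.
        refs = list(
            query.where(instrument="HSC")
            .datasets("raw", collections="HSC/raw/all", find_first=True)
            .with_dimension_records()
        )
        # Dimension records for a single element, plus diagnostics when a
        # query unexpectedly returns nothing.
        records_result = query.dimension_records("detector")
        records = list(records_result)
        if not records:
            print(list(records_result.explain_no_results()))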
--- python/lsst/daf/butler/__init__.py | 12 +- python/lsst/daf/butler/_butler.py | 109 +- python/lsst/daf/butler/_query.py | 254 --- python/lsst/daf/butler/_query_results.py | 728 ------ python/lsst/daf/butler/direct_butler.py | 115 +- python/lsst/daf/butler/direct_query.py | 127 -- .../lsst/daf/butler/direct_query_results.py | 298 --- python/lsst/daf/butler/queries/__init__.py | 32 + python/lsst/daf/butler/queries/_base.py | 287 +++ .../queries/_data_coordinate_query_results.py | 101 + .../butler/queries/_dataset_query_results.py | 236 ++ .../_dimension_record_query_results.py | 121 + python/lsst/daf/butler/queries/_query.py | 652 ++++++ .../lsst/daf/butler/queries/convert_args.py | 263 +++ python/lsst/daf/butler/queries/driver.py | 458 ++++ .../daf/butler/queries/expression_factory.py | 502 +++++ python/lsst/daf/butler/queries/overlaps.py | 466 ++++ .../lsst/daf/butler/queries/result_specs.py | 262 +++ .../lsst/daf/butler/queries/tree/__init__.py | 40 + python/lsst/daf/butler/queries/tree/_base.py | 188 ++ .../butler/queries/tree/_column_expression.py | 278 +++ .../butler/queries/tree/_column_literal.py | 372 +++ .../butler/queries/tree/_column_reference.py | 173 ++ .../daf/butler/queries/tree/_column_set.py | 357 +++ .../daf/butler/queries/tree/_predicate.py | 629 ++++++ .../daf/butler/queries/tree/_query_tree.py | 314 +++ python/lsst/daf/butler/queries/visitors.py | 540 +++++ .../butler/remote_butler/_remote_butler.py | 54 +- python/lsst/daf/butler/tests/hybrid_butler.py | 6 +- tests/test_query_interface.py | 2004 +++++++++++++++++ tests/test_query_utilities.py | 470 ++++ 31 files changed, 8840 insertions(+), 1608 deletions(-) delete mode 100644 python/lsst/daf/butler/_query.py delete mode 100644 python/lsst/daf/butler/_query_results.py delete mode 100644 python/lsst/daf/butler/direct_query.py delete mode 100644 python/lsst/daf/butler/direct_query_results.py create mode 100644 python/lsst/daf/butler/queries/__init__.py create mode 100644 python/lsst/daf/butler/queries/_base.py create mode 100644 python/lsst/daf/butler/queries/_data_coordinate_query_results.py create mode 100644 python/lsst/daf/butler/queries/_dataset_query_results.py create mode 100644 python/lsst/daf/butler/queries/_dimension_record_query_results.py create mode 100644 python/lsst/daf/butler/queries/_query.py create mode 100644 python/lsst/daf/butler/queries/convert_args.py create mode 100644 python/lsst/daf/butler/queries/driver.py create mode 100644 python/lsst/daf/butler/queries/expression_factory.py create mode 100644 python/lsst/daf/butler/queries/overlaps.py create mode 100644 python/lsst/daf/butler/queries/result_specs.py create mode 100644 python/lsst/daf/butler/queries/tree/__init__.py create mode 100644 python/lsst/daf/butler/queries/tree/_base.py create mode 100644 python/lsst/daf/butler/queries/tree/_column_expression.py create mode 100644 python/lsst/daf/butler/queries/tree/_column_literal.py create mode 100644 python/lsst/daf/butler/queries/tree/_column_reference.py create mode 100644 python/lsst/daf/butler/queries/tree/_column_set.py create mode 100644 python/lsst/daf/butler/queries/tree/_predicate.py create mode 100644 python/lsst/daf/butler/queries/tree/_query_tree.py create mode 100644 python/lsst/daf/butler/queries/visitors.py create mode 100644 tests/test_query_interface.py create mode 100644 tests/test_query_utilities.py diff --git a/python/lsst/daf/butler/__init__.py b/python/lsst/daf/butler/__init__.py index f15193ce87..b7d86f00ba 100644 --- a/python/lsst/daf/butler/__init__.py +++ 
b/python/lsst/daf/butler/__init__.py @@ -60,8 +60,6 @@ from ._named import * from ._quantum import * from ._quantum_backed import * -from ._query import * -from ._query_results import * from ._storage_class import * from ._storage_class_delegate import * from ._timespan import * @@ -80,6 +78,16 @@ # Only lift 'Progress' from 'progess'; the module is imported as-is above from .progress import Progress +# Only import the main public symbols from queries +from .queries import ( + ChainedDatasetQueryResults, + DataCoordinateQueryResults, + DatasetQueryResults, + DimensionRecordQueryResults, + Query, + SingleTypeDatasetQueryResults, +) + # Do not import or lift symbols from 'server' or 'server_models'. # Import the registry subpackage directly for other symbols. from .registry import ( diff --git a/python/lsst/daf/butler/_butler.py b/python/lsst/daf/butler/_butler.py index b3c3b07176..bf2cdbf079 100644 --- a/python/lsst/daf/butler/_butler.py +++ b/python/lsst/daf/butler/_butler.py @@ -32,16 +32,19 @@ from abc import abstractmethod from collections.abc import Collection, Iterable, Mapping, Sequence from contextlib import AbstractContextManager +from types import EllipsisType from typing import TYPE_CHECKING, Any, TextIO from lsst.resources import ResourcePath, ResourcePathExpression from lsst.utils import doImportType +from lsst.utils.iteration import ensure_iterable from lsst.utils.logging import getLogger from ._butler_config import ButlerConfig, ButlerType from ._butler_instance_options import ButlerInstanceOptions from ._butler_repo_index import ButlerRepoIndex from ._config import Config, ConfigSubset +from ._exceptions import EmptyQueryResultError from ._limited_butler import LimitedButler from .datastore import Datastore from .dimensions import DimensionConfig @@ -54,12 +57,12 @@ from ._dataset_type import DatasetType from ._deferredDatasetHandle import DeferredDatasetHandle from ._file_dataset import FileDataset - from ._query import Query from ._storage_class import StorageClass from ._timespan import Timespan from .datastore import DatasetRefURIs from .dimensions import DataCoordinate, DataId, DimensionGroup, DimensionRecord - from .registry import CollectionArgType, Registry + from .queries import Query + from .registry import Registry from .transfers import RepoExportContext _LOG = getLogger(__name__) @@ -1428,7 +1431,6 @@ def _query(self) -> AbstractContextManager[Query]: """ raise NotImplementedError() - @abstractmethod def _query_data_ids( self, dimensions: DimensionGroup | Iterable[str] | str, @@ -1436,7 +1438,7 @@ def _query_data_ids( data_id: DataId | None = None, where: str = "", bind: Mapping[str, Any] | None = None, - expanded: bool = False, + with_dimension_records: bool = False, order_by: Iterable[str] | str | None = None, limit: int | None = None, offset: int = 0, @@ -1466,7 +1468,7 @@ def _query_data_ids( Values of collection type can be expanded in some cases; see :ref:`daf_butler_dimension_expressions_identifiers` for more information. - expanded : `bool`, optional + with_dimension_records : `bool`, optional If `True` (default is `False`) then returned data IDs will have dimension records. order_by : `~collections.abc.Iterable` [`str`] or `str`, optional @@ -1511,19 +1513,32 @@ def _query_data_ids( Raised when the arguments are incompatible, e.g. ``offset`` is specified, but ``limit`` is not. 
""" - raise NotImplementedError() + if data_id is None: + data_id = DataCoordinate.make_empty(self.dimensions) + with self._query() as query: + result = ( + query.where(data_id, where, bind=bind, **kwargs) + .data_ids(dimensions) + .order_by(*ensure_iterable(order_by)) + .limit(limit, offset) + ) + if with_dimension_records: + result = result.with_dimension_records() + data_ids = list(result) + if explain and not data_ids: + raise EmptyQueryResultError(list(result.explain_no_results())) + return data_ids - @abstractmethod def _query_datasets( self, - dataset_type: Any, - collections: CollectionArgType | None = None, + dataset_type: str | Iterable[str] | DatasetType | Iterable[DatasetType] | EllipsisType, + collections: str | Iterable[str] | None = None, *, find_first: bool = True, data_id: DataId | None = None, where: str = "", bind: Mapping[str, Any] | None = None, - expanded: bool = False, + with_dimension_records: bool = False, explain: bool = True, **kwargs: Any, ) -> list[DatasetRef]: @@ -1533,17 +1548,13 @@ def _query_datasets( ---------- dataset_type : dataset type expression An expression that fully or partially identifies the dataset types - to be queried. Allowed types include `DatasetType`, `str`, - `re.Pattern`, and iterables thereof. The special value ``...`` can - be used to query all dataset types. See - :ref:`daf_butler_dataset_type_expressions` for more information. + to be queried. Allowed types include `DatasetType`, `str`, and + iterables thereof. The special value ``...`` can be used to query + all dataset types. See :ref:`daf_butler_dataset_type_expressions` + for more information. collections : collection expression, optional - An expression that identifies the collections to search, such as a - `str` (for full matches or partial matches via globs), `re.Pattern` - (for partial matches), or iterable thereof. ``...`` can be used to - search all collections (actually just all `~CollectionType.RUN` - collections, because this will still find all datasets). - If not provided, the default collections are used. See + A collection name or iterable of collection names to search. If not + provided, the default collections are used. See :ref:`daf_butler_collection_expressions` for more information. find_first : `bool`, optional If `True` (default), for each result data ID, only yield one @@ -1552,20 +1563,20 @@ def _query_datasets( order of ``collections`` passed in). If `True`, ``collections`` must not contain regular expressions and may not be ``...``. data_id : `dict` or `DataCoordinate`, optional - A data ID whose key-value pairs are used as equality constraints - in the query. + A data ID whose key-value pairs are used as equality constraints in + the query. where : `str`, optional - A string expression similar to a SQL WHERE clause. May involve - any column of a dimension table or (as a shortcut for the primary - key column of a dimension table) dimension name. See + A string expression similar to a SQL WHERE clause. May involve any + column of a dimension table or (as a shortcut for the primary key + column of a dimension table) dimension name. See :ref:`daf_butler_dimension_expressions` for more information. bind : `~collections.abc.Mapping`, optional Mapping containing literal values that should be injected into the - ``where`` expression, keyed by the identifiers they replace. - Values of collection type can be expanded in some cases; see + ``where`` expression, keyed by the identifiers they replace. 
Values + of collection type can be expanded in some cases; see :ref:`daf_butler_dimension_expressions_identifiers` for more information. - expanded : `bool`, optional + with_dimension_records : `bool`, optional If `True` (default is `False`) then returned data IDs will have dimension records. explain : `bool`, optional @@ -1586,8 +1597,8 @@ def _query_datasets( IDs are guaranteed to include values for all implied dimensions (i.e. `DataCoordinate.hasFull` will return `True`), but will not include dimension records (`DataCoordinate.hasRecords` will be - `False`) unless `~.queries.DatasetQueryResults.expanded` is - called on the result object (which returns a new one). + `False`) unless `~.queries.DatasetQueryResults.expanded` is called + on the result object (which returns a new one). Raises ------ @@ -1609,14 +1620,26 @@ def _query_datasets( Notes ----- - When multiple dataset types are queried in a single call, the - results of this operation are equivalent to querying for each dataset - type separately in turn, and no information about the relationships - between datasets of different types is included. + When multiple dataset types are queried in a single call, the results + of this operation are equivalent to querying for each dataset type + separately in turn, and no information about the relationships between + datasets of different types is included. """ - raise NotImplementedError() + if data_id is None: + data_id = DataCoordinate.make_empty(self.dimensions) + with self._query() as query: + result = query.where(data_id, where, bind=bind, **kwargs).datasets( + dataset_type, + collections=collections, + find_first=find_first, + ) + if with_dimension_records: + result = result.with_dimension_records() + refs = list(result) + if explain and not refs: + raise EmptyQueryResultError(list(result.explain_no_results())) + return refs - @abstractmethod def _query_dimension_records( self, element: str, @@ -1691,7 +1714,19 @@ def _query_dimension_records( when ``collections`` is `None` and default butler collections are not defined. """ - raise NotImplementedError() + if data_id is None: + data_id = DataCoordinate.make_empty(self.dimensions) + with self._query() as query: + result = ( + query.where(data_id, where, bind=bind, **kwargs) + .dimension_records(element) + .order_by(*ensure_iterable(order_by)) + .limit(limit, offset) + ) + dimension_records = list(result) + if explain and not dimension_records: + raise EmptyQueryResultError(list(result.explain_no_results())) + return dimension_records @abstractmethod def _clone( diff --git a/python/lsst/daf/butler/_query.py b/python/lsst/daf/butler/_query.py deleted file mode 100644 index d0ad5b3858..0000000000 --- a/python/lsst/daf/butler/_query.py +++ /dev/null @@ -1,254 +0,0 @@ -# This file is part of daf_butler. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (http://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# This software is dual licensed under the GNU General Public License and also -# under a 3-clause BSD license. Recipients may choose which of these licenses -# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, -# respectively. 
If you choose the GPL option then the following text applies -# (but note that there is still no warranty even if you opt for BSD instead): -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . - -from __future__ import annotations - -__all__ = ("Query",) - -from abc import ABC, abstractmethod -from collections.abc import Iterable, Mapping -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from ._query_results import DataCoordinateQueryResults, DatasetQueryResults, DimensionRecordQueryResults - from .dimensions import DataId, DimensionGroup - from .registry._registry import CollectionArgType - - -class Query(ABC): - """Interface for construction and execution of complex queries.""" - - @abstractmethod - def data_ids( - self, - dimensions: DimensionGroup | Iterable[str] | str, - *, - data_id: DataId | None = None, - where: str = "", - bind: Mapping[str, Any] | None = None, - **kwargs: Any, - ) -> DataCoordinateQueryResults: - """Query for data IDs matching user-provided criteria. - - Parameters - ---------- - dimensions : `DimensionGroup`, `str`, or \ - `~collections.abc.Iterable` [`str`] - The dimensions of the data IDs to yield, as either `DimensionGroup` - instances or `str`. Will be automatically expanded to a complete - `DimensionGroup`. - data_id : `dict` or `DataCoordinate`, optional - A data ID whose key-value pairs are used as equality constraints - in the query. - where : `str`, optional - A string expression similar to a SQL WHERE clause. May involve - any column of a dimension table or (as a shortcut for the primary - key column of a dimension table) dimension name. See - :ref:`daf_butler_dimension_expressions` for more information. - bind : `~collections.abc.Mapping`, optional - Mapping containing literal values that should be injected into the - ``where`` expression, keyed by the identifiers they replace. - Values of collection type can be expanded in some cases; see - :ref:`daf_butler_dimension_expressions_identifiers` for more - information. - **kwargs - Additional keyword arguments are forwarded to - `DataCoordinate.standardize` when processing the ``data_id`` - argument (and may be used to provide a constraining data ID even - when the ``data_id`` argument is `None`). - - Returns - ------- - dataIds : `DataCoordinateQueryResults` - Data IDs matching the given query parameters. These are guaranteed - to identify all dimensions (`DataCoordinate.hasFull` returns - `True`), but will not contain `DimensionRecord` objects - (`DataCoordinate.hasRecords` returns `False`). Call - `~DataCoordinateQueryResults.expanded` on the - returned object to fetch those (and consider using - `~DataCoordinateQueryResults.materialize` on the - returned object first if the expected number of rows is very - large). See documentation for those methods for additional - information. 
- - Raises - ------ - lsst.daf.butler.registry.DataIdError - Raised when ``data_id`` or keyword arguments specify unknown - dimensions or values, or when they contain inconsistent values. - lsst.daf.butler.registry.UserExpressionError - Raised when ``where`` expression is invalid. - """ - raise NotImplementedError() - - @abstractmethod - def datasets( - self, - dataset_type: Any, - collections: CollectionArgType | None = None, - *, - find_first: bool = True, - data_id: DataId | None = None, - where: str = "", - bind: Mapping[str, Any] | None = None, - **kwargs: Any, - ) -> DatasetQueryResults: - """Query for and iterate over dataset references matching user-provided - criteria. - - Parameters - ---------- - dataset_type : dataset type expression - An expression that fully or partially identifies the dataset types - to be queried. Allowed types include `DatasetType`, `str`, - `re.Pattern`, and iterables thereof. The special value ``...`` can - be used to query all dataset types. See - :ref:`daf_butler_dataset_type_expressions` for more information. - collections : collection expression, optional - An expression that identifies the collections to search, such as a - `str` (for full matches or partial matches via globs), `re.Pattern` - (for partial matches), or iterable thereof. ``...`` can be used to - search all collections (actually just all `~CollectionType.RUN` - collections, because this will still find all datasets). - If not provided, the default collections are used. See - :ref:`daf_butler_collection_expressions` for more information. - find_first : `bool`, optional - If `True` (default), for each result data ID, only yield one - `DatasetRef` of each `DatasetType`, from the first collection in - which a dataset of that dataset type appears (according to the - order of ``collections`` passed in). If `True`, ``collections`` - must not contain regular expressions and may not be ``...``. - data_id : `dict` or `DataCoordinate`, optional - A data ID whose key-value pairs are used as equality constraints - in the query. - where : `str`, optional - A string expression similar to a SQL WHERE clause. May involve - any column of a dimension table or (as a shortcut for the primary - key column of a dimension table) dimension name. See - :ref:`daf_butler_dimension_expressions` for more information. - bind : `~collections.abc.Mapping`, optional - Mapping containing literal values that should be injected into the - ``where`` expression, keyed by the identifiers they replace. - Values of collection type can be expanded in some cases; see - :ref:`daf_butler_dimension_expressions_identifiers` for more - information. - **kwargs - Additional keyword arguments are forwarded to - `DataCoordinate.standardize` when processing the ``data_id`` - argument (and may be used to provide a constraining data ID even - when the ``data_id`` argument is `None`). - - Returns - ------- - refs : `.queries.DatasetQueryResults` - Dataset references matching the given query criteria. Nested data - IDs are guaranteed to include values for all implied dimensions - (i.e. `DataCoordinate.hasFull` will return `True`), but will not - include dimension records (`DataCoordinate.hasRecords` will be - `False`) unless `~.queries.DatasetQueryResults.expanded` is - called on the result object (which returns a new one). - - Raises - ------ - lsst.daf.butler.registry.DatasetTypeExpressionError - Raised when ``dataset_type`` expression is invalid. 
- TypeError - Raised when the arguments are incompatible, such as when a - collection wildcard is passed when ``find_first`` is `True`, or - when ``collections`` is `None` and default butler collections are - not defined. - lsst.daf.butler.registry.DataIdError - Raised when ``data_id`` or keyword arguments specify unknown - dimensions or values, or when they contain inconsistent values. - lsst.daf.butler.registry.UserExpressionError - Raised when ``where`` expression is invalid. - - Notes - ----- - When multiple dataset types are queried in a single call, the - results of this operation are equivalent to querying for each dataset - type separately in turn, and no information about the relationships - between datasets of different types is included. - """ - raise NotImplementedError() - - @abstractmethod - def dimension_records( - self, - element: str, - *, - data_id: DataId | None = None, - where: str = "", - bind: Mapping[str, Any] | None = None, - **kwargs: Any, - ) -> DimensionRecordQueryResults: - """Query for dimension information matching user-provided criteria. - - Parameters - ---------- - element : `str` - The name of a dimension element to obtain records for. - data_id : `dict` or `DataCoordinate`, optional - A data ID whose key-value pairs are used as equality constraints - in the query. - where : `str`, optional - A string expression similar to a SQL WHERE clause. See - `queryDataIds` and :ref:`daf_butler_dimension_expressions` for more - information. - bind : `~collections.abc.Mapping`, optional - Mapping containing literal values that should be injected into the - ``where`` expression, keyed by the identifiers they replace. - Values of collection type can be expanded in some cases; see - :ref:`daf_butler_dimension_expressions_identifiers` for more - information. - **kwargs - Additional keyword arguments are forwarded to - `DataCoordinate.standardize` when processing the ``data_id`` - argument (and may be used to provide a constraining data ID even - when the ``data_id`` argument is `None`). - - Returns - ------- - records : `.queries.DimensionRecordQueryResults` - Data IDs matching the given query parameters. - - Raises - ------ - lsst.daf.butler.registry.NoDefaultCollectionError - Raised if ``collections`` is `None` and - ``self.defaults.collections`` is `None`. - lsst.daf.butler.registry.CollectionExpressionError - Raised when ``collections`` expression is invalid. - lsst.daf.butler.registry.DataIdError - Raised when ``data_id`` or keyword arguments specify unknown - dimensions or values, or when they contain inconsistent values. - lsst.daf.butler.registry.DatasetTypeExpressionError - Raised when ``datasetType`` expression is invalid. - lsst.daf.butler.registry.UserExpressionError - Raised when ``where`` expression is invalid. - """ - raise NotImplementedError() diff --git a/python/lsst/daf/butler/_query_results.py b/python/lsst/daf/butler/_query_results.py deleted file mode 100644 index 39e4b38760..0000000000 --- a/python/lsst/daf/butler/_query_results.py +++ /dev/null @@ -1,728 +0,0 @@ -# This file is part of daf_butler. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (http://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# This software is dual licensed under the GNU General Public License and also -# under a 3-clause BSD license. 
Recipients may choose which of these licenses -# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, -# respectively. If you choose the GPL option then the following text applies -# (but note that there is still no warranty even if you opt for BSD instead): -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . - -from __future__ import annotations - -__all__ = ( - "DataCoordinateQueryResults", - "DatasetQueryResults", - "DimensionRecordQueryResults", - "SingleTypeDatasetQueryResults", -) - -from abc import abstractmethod -from collections.abc import Iterable, Iterator -from contextlib import AbstractContextManager -from typing import TYPE_CHECKING, Any - -from ._dataset_ref import DatasetRef -from .dimensions import DataCoordinate, DimensionRecord - -if TYPE_CHECKING: - from ._dataset_type import DatasetType - from .dimensions import DimensionElement, DimensionGroup - - -class DataCoordinateQueryResults(Iterable[DataCoordinate]): - """An interface for objects that represent the results of queries for - data IDs. - """ - - @property - @abstractmethod - def dimensions(self) -> DimensionGroup: - """The dimensions of the data IDs returned by this query.""" - raise NotImplementedError() - - @abstractmethod - def has_full(self) -> bool: - """Indicate if all data IDs in this iterable identify all dimensions, - not just required dimensions. - - Returns - ------- - state : `bool` - If `True`, ``all(d.hasFull() for d in self)`` is guaranteed. - If `False`, no guarantees are made. - """ - raise NotImplementedError() - - @abstractmethod - def has_records(self) -> bool: - """Return whether all data IDs in this iterable contain records. - - Returns - ------- - state : `bool` - If `True`, ``all(d.hasRecords() for d in self)`` is guaranteed. - If `False`, no guarantees are made. - """ - raise NotImplementedError() - - @abstractmethod - def materialize(self) -> AbstractContextManager[DataCoordinateQueryResults]: - """Insert this query's results into a temporary table. - - Returns - ------- - context : `typing.ContextManager` [ `DataCoordinateQueryResults` ] - A context manager that ensures the temporary table is created and - populated in ``__enter__`` (returning a results object backed by - that table), and dropped in ``__exit__``. If ``self`` is already - materialized, the context manager may do nothing (reflecting the - fact that an outer context manager should already take care of - everything else). - - Notes - ----- - When using a very large result set to perform multiple queries (e.g. - multiple calls to `subset` with different arguments, or even a single - call to `expanded`), it may be much more efficient to start by - materializing the query and only then performing the follow up queries. 
- It may also be less efficient, depending on how well database engine's - query optimizer can simplify those particular follow-up queries and - how efficiently it caches query results even when the are not - explicitly inserted into a temporary table. See `expanded` and - `subset` for examples. - """ - raise NotImplementedError() - - @abstractmethod - def expanded(self) -> DataCoordinateQueryResults: - """Return a results object for which `has_records` returns `True`. - - This method may involve actually executing database queries to fetch - `DimensionRecord` objects. - - Returns - ------- - results : `DataCoordinateQueryResults` - A results object for which `has_records` returns `True`. May be - ``self`` if that is already the case. - - Notes - ----- - For very result sets, it may be much more efficient to call - `materialize` before calling `expanded`, to avoid performing the - original query multiple times (as a subquery) in the follow-up queries - that fetch dimension records. For example:: - - with butler.query() as query: - with query.data_ids(...).materialize() as tempDataIds: - dataIdsWithRecords = tempDataIds.expanded() - for dataId in dataIdsWithRecords: - ... - """ - raise NotImplementedError() - - @abstractmethod - def subset( - self, - dimensions: DimensionGroup | Iterable[str] | None = None, - *, - unique: bool = False, - ) -> DataCoordinateQueryResults: - """Return a results object containing a subset of the dimensions of - this one, and/or a unique near-subset of its rows. - - This method may involve actually executing database queries to fetch - `DimensionRecord` objects. - - Parameters - ---------- - dimensions : `DimensionGroup` or \ - `~collections.abc.Iterable` [ `str`], optional - Dimensions to include in the new results object. If `None`, - ``self.dimensions`` is used. - unique : `bool`, optional - If `True` (`False` is default), the query should only return unique - data IDs. This is implemented in the database; to obtain unique - results via Python-side processing (which may be more efficient in - some cases), use `toSet` to construct a `DataCoordinateSet` from - this results object instead. - - Returns - ------- - results : `DataCoordinateQueryResults` - A results object corresponding to the given criteria. May be - ``self`` if it already qualifies. - - Raises - ------ - ValueError - Raised when ``dimensions`` is not a subset of the dimensions in - this result. - - Notes - ----- - This method can only return a "near-subset" of the original result rows - in general because of subtleties in how spatial overlaps are - implemented; see `Query.projected` for more information. - - When calling `subset` multiple times on the same very large result set, - it may be much more efficient to call `materialize` first. For - example:: - - dimensions1 = DimensionGroup(...) - dimensions2 = DimensionGroup(...) - with butler.query(...)as query: - with query.data_ids(...).materialize() as data_ids: - for dataId1 in data_ids.subset(dimensions1, unique=True): - ... - for dataId2 in data_ids.subset(dimensions2, unique=True): - ... - """ - raise NotImplementedError() - - @abstractmethod - def find_datasets( - self, dataset_type: DatasetType | str, collections: Any, *, find_first: bool = True - ) -> DatasetQueryResults: - """Find datasets using the data IDs identified by this query. - - Parameters - ---------- - dataset_type : `DatasetType` or `str` - Dataset type or the name of one to search for. Must have - dimensions that are a subset of ``self.dimensions``. 
- collections : `Any` - An expression that fully or partially identifies the collections - to search for the dataset, such as a `str`, `re.Pattern`, or - iterable thereof. ``...`` can be used to return all collections. - See :ref:`daf_butler_collection_expressions` for more information. - find_first : `bool`, optional - If `True` (default), for each result data ID, only yield one - `DatasetRef`, from the first collection in which a dataset of that - dataset type appears (according to the order of ``collections`` - passed in). If `True`, ``collections`` must not contain regular - expressions and may not be ``...``. - - Returns - ------- - datasets : `ParentDatasetQueryResults` - A lazy-evaluation object representing dataset query results, - iterable over `DatasetRef` objects. If ``self.has_records()``, all - nested data IDs in those dataset references will have records as - well. - - Raises - ------ - MissingDatasetTypeError - Raised if the given dataset type is not registered. - """ - raise NotImplementedError() - - @abstractmethod - def find_related_datasets( - self, - dataset_type: DatasetType | str, - collections: Any, - *, - find_first: bool = True, - dimensions: DimensionGroup | Iterable[str] | None = None, - ) -> Iterable[tuple[DataCoordinate, DatasetRef]]: - """Find datasets using the data IDs identified by this query, and - return them along with the original data IDs. - - This is a variant of `find_datasets` that is often more useful when - the target dataset type does not have all of the dimensions of the - original data ID query, as is generally the case with calibration - lookups. - - Parameters - ---------- - dataset_type : `DatasetType` or `str` - Dataset type or the name of one to search for. Must have - dimensions that are a subset of ``self.dimensions``. - collections : `Any` - An expression that fully or partially identifies the collections - to search for the dataset, such as a `str`, `re.Pattern`, or - iterable thereof. ``...`` can be used to return all collections. - See :ref:`daf_butler_collection_expressions` for more information. - find_first : `bool`, optional - If `True` (default), for each data ID in ``self``, only yield one - `DatasetRef`, from the first collection in which a dataset of that - dataset type appears (according to the order of ``collections`` - passed in). If `True`, ``collections`` must not contain regular - expressions and may not be ``...``. Note that this is not the - same as yielding one `DatasetRef` for each yielded data ID if - ``dimensions`` is not `None`. - dimensions : `DimensionGroup`, or \ - `~collections.abc.Iterable` [ `str` ], optional - The dimensions of the data IDs returned. Must be a subset of - ``self.dimensions``. - - Returns - ------- - pairs : `~collections.abc.Iterable` [ `tuple` [ `DataCoordinate`, \ - `DatasetRef` ] ] - An iterable of (data ID, dataset reference) pairs. - - Raises - ------ - MissingDatasetTypeError - Raised if the given dataset type is not registered. - """ - raise NotImplementedError() - - @abstractmethod - def count(self, *, exact: bool = True, discard: bool = False) -> int: - """Count the number of rows this query would return. - - Parameters - ---------- - exact : `bool`, optional - If `True`, run the full query and perform post-query filtering if - needed to account for that filtering in the count. If `False`, the - result may be an upper bound. 
- discard : `bool`, optional - If `True`, compute the exact count even if it would require running - the full query and then throwing away the result rows after - counting them. If `False`, this is an error, as the user would - usually be better off executing the query first to fetch its rows - into a new query (or passing ``exact=False``). Ignored if - ``exact=False``. - - Returns - ------- - count : `int` - The number of rows the query would return, or an upper bound if - ``exact=False``. - - Notes - ----- - This counts the number of rows returned, not the number of unique rows - returned, so even with ``exact=True`` it may provide only an upper - bound on the number of *deduplicated* result rows. - """ - raise NotImplementedError() - - @abstractmethod - def any(self, *, execute: bool = True, exact: bool = True) -> bool: - """Test whether this query returns any results. - - Parameters - ---------- - execute : `bool`, optional - If `True`, execute at least a ``LIMIT 1`` query if it cannot be - determined prior to execution that the query would return no rows. - exact : `bool`, optional - If `True`, run the full query and perform post-query filtering if - needed, until at least one result row is found. If `False`, the - returned result does not account for post-query filtering, and - hence may be `True` even when all result rows would be filtered - out. - - Returns - ------- - any : `bool` - `True` if the query would (or might, depending on arguments) yield - result rows. `False` if it definitely would not. - """ - raise NotImplementedError() - - @abstractmethod - def explain_no_results(self, execute: bool = True) -> Iterable[str]: - """Return human-readable messages that may help explain why the query - yields no results. - - Parameters - ---------- - execute : `bool`, optional - If `True` (default) execute simplified versions (e.g. ``LIMIT 1``) - of aspects of the tree to more precisely determine where rows were - filtered out. - - Returns - ------- - messages : `~collections.abc.Iterable` [ `str` ] - String messages that describe reasons the query might not yield any - results. - """ - raise NotImplementedError() - - @abstractmethod - def order_by(self, *args: str) -> DataCoordinateQueryResults: - """Make the iterator return ordered results. - - Parameters - ---------- - *args : `str` - Names of the columns/dimensions to use for ordering. Column name - can be prefixed with minus (``-``) to use descending ordering. - - Returns - ------- - result : `DataCoordinateQueryResults` - Returns ``self`` instance which is updated to return ordered - result. - - Notes - ----- - This method modifies the iterator in place and returns the same - instance to support method chaining. - """ - raise NotImplementedError() - - @abstractmethod - def limit(self, limit: int | None = None, offset: int = 0) -> DataCoordinateQueryResults: - """Make the iterator return limited number of records. - - Parameters - ---------- - limit : `int` or `None`, optional - Upper limit on the number of returned records. `None` (default) is - no limit. - offset : `int`, optional - The number of records to skip before returning at most ``limit`` - records. - - Returns - ------- - result : `DataCoordinateQueryResults` - Returns ``self`` instance which is updated to return limited set - of records. - - Notes - ----- - This method modifies the iterator in place and returns the same - instance to support method chaining. Normally this method is used - together with `order_by` method. 
- """ - raise NotImplementedError() - - -class DatasetQueryResults(Iterable[DatasetRef]): - """An interface for objects that represent the results of queries for - datasets. - """ - - @abstractmethod - def by_dataset_type(self) -> Iterator[SingleTypeDatasetQueryResults]: - """Group results by dataset type. - - Returns - ------- - iter : `~collections.abc.Iterator` [ `SingleTypeDatasetQueryResults` ] - An iterator over `DatasetQueryResults` instances that are each - responsible for a single dataset type. - """ - raise NotImplementedError() - - @abstractmethod - def materialize(self) -> AbstractContextManager[DatasetQueryResults]: - """Insert this query's results into a temporary table. - - Returns - ------- - context : `typing.ContextManager` [ `DatasetQueryResults` ] - A context manager that ensures the temporary table is created and - populated in ``__enter__`` (returning a results object backed by - that table), and dropped in ``__exit__``. If ``self`` is already - materialized, the context manager may do nothing (reflecting the - fact that an outer context manager should already take care of - everything else). - """ - raise NotImplementedError() - - @abstractmethod - def expanded(self) -> DatasetQueryResults: - """Return a `DatasetQueryResults` for which `DataCoordinate.hasRecords` - returns `True` for all data IDs in returned `DatasetRef` objects. - - Returns - ------- - expanded : `DatasetQueryResults` - Either a new `DatasetQueryResults` instance or ``self``, if it is - already expanded. - - Notes - ----- - As with `DataCoordinateQueryResults.expanded`, it may be more efficient - to call `materialize` before expanding data IDs for very large result - sets. - """ - raise NotImplementedError() - - @abstractmethod - def count(self, *, exact: bool = True, discard: bool = False) -> int: - """Count the number of rows this query would return. - - Parameters - ---------- - exact : `bool`, optional - If `True`, run the full query and perform post-query filtering if - needed to account for that filtering in the count. If `False`, the - result may be an upper bound. - discard : `bool`, optional - If `True`, compute the exact count even if it would require running - the full query and then throwing away the result rows after - counting them. If `False`, this is an error, as the user would - usually be better off executing the query first to fetch its rows - into a new query (or passing ``exact=False``). Ignored if - ``exact=False``. - - Returns - ------- - count : `int` - The number of rows the query would return, or an upper bound if - ``exact=False``. - - Notes - ----- - This counts the number of rows returned, not the number of unique rows - returned, so even with ``exact=True`` it may provide only an upper - bound on the number of *deduplicated* result rows. - """ - raise NotImplementedError() - - @abstractmethod - def any(self, *, execute: bool = True, exact: bool = True) -> bool: - """Test whether this query returns any results. - - Parameters - ---------- - execute : `bool`, optional - If `True`, execute at least a ``LIMIT 1`` query if it cannot be - determined prior to execution that the query would return no rows. - exact : `bool`, optional - If `True`, run the full query and perform post-query filtering if - needed, until at least one result row is found. If `False`, the - returned result does not account for post-query filtering, and - hence may be `True` even when all result rows would be filtered - out. 
- - Returns - ------- - any : `bool` - `True` if the query would (or might, depending on arguments) yield - result rows. `False` if it definitely would not. - """ - raise NotImplementedError() - - @abstractmethod - def explain_no_results(self, execute: bool = True) -> Iterable[str]: - """Return human-readable messages that may help explain why the query - yields no results. - - Parameters - ---------- - execute : `bool`, optional - If `True` (default) execute simplified versions (e.g. ``LIMIT 1``) - of aspects of the tree to more precisely determine where rows were - filtered out. - - Returns - ------- - messages : `~collections.abc.Iterable` [ `str` ] - String messages that describe reasons the query might not yield any - results. - """ - raise NotImplementedError() - - -class SingleTypeDatasetQueryResults(DatasetQueryResults): - """An object that represents results from a query for datasets with a - single parent `DatasetType`. - """ - - @abstractmethod - def materialize(self) -> AbstractContextManager[SingleTypeDatasetQueryResults]: - # Docstring inherited from DatasetQueryResults. - raise NotImplementedError() - - @property - @abstractmethod - def dataset_type(self) -> DatasetType: - """The parent dataset type for all datasets in this iterable - (`DatasetType`). - """ - raise NotImplementedError() - - @property - @abstractmethod - def data_ids(self) -> DataCoordinateQueryResults: - """A lazy-evaluation object representing a query for just the data - IDs of the datasets that would be returned by this query - (`DataCoordinateQueryResults`). - - The returned object is not in general `zip`-iterable with ``self``; - it may be in a different order or have (or not have) duplicates. - """ - raise NotImplementedError() - - def expanded(self) -> SingleTypeDatasetQueryResults: - # Docstring inherited from DatasetQueryResults. - raise NotImplementedError() - - -class DimensionRecordQueryResults(Iterable[DimensionRecord]): - """An interface for objects that represent the results of queries for - dimension records. - """ - - @property - @abstractmethod - def element(self) -> DimensionElement: - """Dimension element for this result (`DimensionElement`).""" - raise NotImplementedError() - - @abstractmethod - def run(self) -> DimensionRecordQueryResults: - """Execute the query and return an instance with data held in memory. - - Returns - ------- - result : `DimensionRecordQueryResults` - Query results, may return ``self`` if it has all data in memory - already. - """ - raise NotImplementedError() - - @abstractmethod - def count(self, *, exact: bool = True, discard: bool = False) -> int: - """Count the number of rows this query would return. - - Parameters - ---------- - exact : `bool`, optional - If `True`, run the full query and perform post-query filtering if - needed to account for that filtering in the count. If `False`, the - result may be an upper bound. - discard : `bool`, optional - If `True`, compute the exact count even if it would require running - the full query and then throwing away the result rows after - counting them. If `False`, this is an error, as the user would - usually be better off executing the query first to fetch its rows - into a new query (or passing ``exact=False``). Ignored if - ``exact=False``. - - Returns - ------- - count : `int` - The number of rows the query would return, or an upper bound if - ``exact=False``. 
- - Notes - ----- - This counts the number of rows returned, not the number of unique rows - returned, so even with ``exact=True`` it may provide only an upper - bound on the number of *deduplicated* result rows. - """ - raise NotImplementedError() - - @abstractmethod - def any(self, *, execute: bool = True, exact: bool = True) -> bool: - """Test whether this query returns any results. - - Parameters - ---------- - execute : `bool`, optional - If `True`, execute at least a ``LIMIT 1`` query if it cannot be - determined prior to execution that the query would return no rows. - exact : `bool`, optional - If `True`, run the full query and perform post-query filtering if - needed, until at least one result row is found. If `False`, the - returned result does not account for post-query filtering, and - hence may be `True` even when all result rows would be filtered - out. - - Returns - ------- - any : `bool` - `True` if the query would (or might, depending on arguments) yield - result rows. `False` if it definitely would not. - """ - raise NotImplementedError() - - @abstractmethod - def order_by(self, *args: str) -> DimensionRecordQueryResults: - """Make the iterator return ordered result. - - Parameters - ---------- - *args : `str` - Names of the columns/dimensions to use for ordering. Column name - can be prefixed with minus (``-``) to use descending ordering. - - Returns - ------- - result : `DimensionRecordQueryResults` - Returns ``self`` instance which is updated to return ordered - result. - - Notes - ----- - This method can modify the iterator in place and return the same - instance. - """ - raise NotImplementedError() - - @abstractmethod - def limit(self, limit: int | None = None, offset: int = 0) -> DimensionRecordQueryResults: - """Make the iterator return limited number of records. - - Parameters - ---------- - limit : `int` or `None`, optional - Upper limit on the number of returned records. - offset : `int`, optional - The number of records to skip before returning at most ``limit`` - records. - - Returns - ------- - result : `DimensionRecordQueryResults` - Returns ``self`` instance which is updated to return limited set of - records. - - Notes - ----- - This method can modify the iterator in place and return the same - instance. Normally this method is used together with `order_by` method. - """ - raise NotImplementedError() - - @abstractmethod - def explain_no_results(self, execute: bool = True) -> Iterable[str]: - """Return human-readable messages that may help explain why the query - yields no results. - - Parameters - ---------- - execute : `bool`, optional - If `True` (default) execute simplified versions (e.g. ``LIMIT 1``) - of aspects of the tree to more precisely determine where rows were - filtered out. - - Returns - ------- - messages : `~collections.abc.Iterable` [ `str` ] - String messages that describe reasons the query might not yield any - results. 
- """ - raise NotImplementedError() diff --git a/python/lsst/daf/butler/direct_butler.py b/python/lsst/daf/butler/direct_butler.py index 36e2146641..a27c2e668e 100644 --- a/python/lsst/daf/butler/direct_butler.py +++ b/python/lsst/daf/butler/direct_butler.py @@ -43,12 +43,11 @@ import os import warnings from collections import Counter, defaultdict -from collections.abc import Iterable, Iterator, Mapping, MutableMapping, Sequence +from collections.abc import Iterable, Iterator, MutableMapping, Sequence from typing import TYPE_CHECKING, Any, ClassVar, TextIO, cast from lsst.resources import ResourcePath, ResourcePathExpression from lsst.utils.introspection import get_class_of -from lsst.utils.iteration import ensure_iterable from lsst.utils.logging import VERBOSE, getLogger from sqlalchemy.exc import IntegrityError @@ -59,15 +58,15 @@ from ._dataset_ref import DatasetRef from ._dataset_type import DatasetType from ._deferredDatasetHandle import DeferredDatasetHandle -from ._exceptions import EmptyQueryResultError, ValidationError +from ._exceptions import ValidationError from ._limited_butler import LimitedButler from ._registry_shim import RegistryShim from ._storage_class import StorageClass, StorageClassFactory from ._timespan import Timespan from .datastore import Datastore, NullDatastore from .dimensions import DataCoordinate, Dimension -from .direct_query import DirectQuery from .progress import Progress +from .queries import Query from .registry import ( CollectionType, ConflictingDefinitionError, @@ -85,17 +84,9 @@ from ._dataset_ref import DatasetId from ._file_dataset import FileDataset - from ._query import Query from .datastore import DatasetRefURIs - from .dimensions import ( - DataId, - DataIdValue, - DimensionElement, - DimensionGroup, - DimensionRecord, - DimensionUniverse, - ) - from .registry import CollectionArgType, Registry + from .dimensions import DataId, DataIdValue, DimensionElement, DimensionRecord, DimensionUniverse + from .registry import Registry from .transfers import RepoImportBackend _LOG = getLogger(__name__) @@ -1728,7 +1719,8 @@ def _extract_all_dimension_records_from_data_ids( ) records = source_butler.registry.queryDimensionRecords( # type: ignore - element.name, **data_id.mapping # type: ignore + element.name, + **data_id.mapping, # type: ignore ) for record in records: additional_records[record.definition].setdefault(record.dataId, record) @@ -2110,98 +2102,7 @@ def dimensions(self) -> DimensionUniverse: @contextlib.contextmanager def _query(self) -> Iterator[Query]: # Docstring inherited. - with self._caching_context(): - yield DirectQuery(self._registry) - - def _query_data_ids( - self, - dimensions: DimensionGroup | Iterable[str] | str, - *, - data_id: DataId | None = None, - where: str = "", - bind: Mapping[str, Any] | None = None, - expanded: bool = False, - order_by: Iterable[str] | str | None = None, - limit: int | None = None, - offset: int = 0, - explain: bool = True, - **kwargs: Any, - ) -> list[DataCoordinate]: - # Docstring inherited. 
- query = DirectQuery(self._registry) - result = query.data_ids(dimensions, data_id=data_id, where=where, bind=bind, **kwargs) - if expanded: - result = result.expanded() - if order_by: - result = result.order_by(*ensure_iterable(order_by)) - if limit is not None: - result = result.limit(limit, offset) - else: - if offset: - raise TypeError("offset is specified without limit") - data_ids = list(result) - if explain and not data_ids: - raise EmptyQueryResultError(list(result.explain_no_results())) - return data_ids - - def _query_datasets( - self, - dataset_type: Any, - collections: CollectionArgType | None = None, - *, - find_first: bool = True, - data_id: DataId | None = None, - where: str = "", - bind: Mapping[str, Any] | None = None, - expanded: bool = False, - explain: bool = True, - **kwargs: Any, - ) -> list[DatasetRef]: - # Docstring inherited. - query = DirectQuery(self._registry) - result = query.datasets( - dataset_type, - collections, - find_first=find_first, - data_id=data_id, - where=where, - bind=bind, - **kwargs, - ) - if expanded: - result = result.expanded() - refs = list(result) - if explain and not refs: - raise EmptyQueryResultError(list(result.explain_no_results())) - return refs - - def _query_dimension_records( - self, - element: str, - *, - data_id: DataId | None = None, - where: str = "", - bind: Mapping[str, Any] | None = None, - order_by: Iterable[str] | str | None = None, - limit: int | None = None, - offset: int = 0, - explain: bool = True, - **kwargs: Any, - ) -> list[DimensionRecord]: - # Docstring inherited. - query = DirectQuery(self._registry) - result = query.dimension_records(element, data_id=data_id, where=where, bind=bind, **kwargs) - if order_by: - result = result.order_by(*ensure_iterable(order_by)) - if limit is not None: - result = result.limit(limit, offset) - else: - if offset: - raise TypeError("offset is specified without limit") - data_ids = list(result) - if explain and not data_ids: - raise EmptyQueryResultError(list(result.explain_no_results())) - return data_ids + raise NotImplementedError("TODO DM-41159") def _preload_cache(self) -> None: """Immediately load caches that are used for common operations.""" diff --git a/python/lsst/daf/butler/direct_query.py b/python/lsst/daf/butler/direct_query.py deleted file mode 100644 index 374f8dddbd..0000000000 --- a/python/lsst/daf/butler/direct_query.py +++ /dev/null @@ -1,127 +0,0 @@ -# This file is part of daf_butler. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (http://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# This software is dual licensed under the GNU General Public License and also -# under a 3-clause BSD license. Recipients may choose which of these licenses -# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, -# respectively. If you choose the GPL option then the following text applies -# (but note that there is still no warranty even if you opt for BSD instead): -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . - -from __future__ import annotations - -__all__ = ["DirectQuery"] - -from collections.abc import Iterable, Mapping -from typing import TYPE_CHECKING, Any - -from ._query import Query -from .direct_query_results import ( - DirectDataCoordinateQueryResults, - DirectDatasetQueryResults, - DirectDimensionRecordQueryResults, - DirectSingleTypeDatasetQueryResults, -) -from .registry import queries as registry_queries -from .registry.sql_registry import SqlRegistry - -if TYPE_CHECKING: - from ._query_results import DataCoordinateQueryResults, DatasetQueryResults, DimensionRecordQueryResults - from .dimensions import DataId, DimensionGroup - from .registry._registry import CollectionArgType - - -class DirectQuery(Query): - """Implementation of `Query` interface used by `DirectButler`. - - Parameters - ---------- - registry : `SqlRegistry` - The object that manages dataset metadata and relationships. - """ - - _registry: SqlRegistry - - def __init__(self, registry: SqlRegistry): - self._registry = registry - - def data_ids( - self, - dimensions: DimensionGroup | Iterable[str] | str, - *, - data_id: DataId | None = None, - where: str = "", - bind: Mapping[str, Any] | None = None, - **kwargs: Any, - ) -> DataCoordinateQueryResults: - # Docstring inherited. - registry_query_result = self._registry.queryDataIds( - dimensions, - dataId=data_id, - where=where, - bind=bind, - **kwargs, - ) - return DirectDataCoordinateQueryResults(registry_query_result) - - def datasets( - self, - dataset_type: Any, - collections: CollectionArgType | None = None, - *, - find_first: bool = True, - data_id: DataId | None = None, - where: str = "", - bind: Mapping[str, Any] | None = None, - **kwargs: Any, - ) -> DatasetQueryResults: - # Docstring inherited. - registry_query_result = self._registry.queryDatasets( - dataset_type, - collections=collections, - dataId=data_id, - where=where, - findFirst=find_first, - bind=bind, - **kwargs, - ) - if isinstance(registry_query_result, registry_queries.ParentDatasetQueryResults): - return DirectSingleTypeDatasetQueryResults(registry_query_result) - else: - return DirectDatasetQueryResults(registry_query_result) - - def dimension_records( - self, - element: str, - *, - data_id: DataId | None = None, - where: str = "", - bind: Mapping[str, Any] | None = None, - **kwargs: Any, - ) -> DimensionRecordQueryResults: - # Docstring inherited. - registry_query_result = self._registry.queryDimensionRecords( - element, - dataId=data_id, - where=where, - bind=bind, - **kwargs, - ) - return DirectDimensionRecordQueryResults(registry_query_result) diff --git a/python/lsst/daf/butler/direct_query_results.py b/python/lsst/daf/butler/direct_query_results.py deleted file mode 100644 index 04a72ff875..0000000000 --- a/python/lsst/daf/butler/direct_query_results.py +++ /dev/null @@ -1,298 +0,0 @@ -# This file is part of daf_butler. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (http://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# This software is dual licensed under the GNU General Public License and also -# under a 3-clause BSD license. Recipients may choose which of these licenses -# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, -# respectively. 
If you choose the GPL option then the following text applies -# (but note that there is still no warranty even if you opt for BSD instead): -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . - -from __future__ import annotations - -__all__ = [ - "DirectDataCoordinateQueryResults", - "DirectDatasetQueryResults", - "DirectDimensionRecordQueryResults", - "DirectSingleTypeDatasetQueryResults", -] - -import contextlib -from collections.abc import Iterable, Iterator -from typing import TYPE_CHECKING, Any - -from ._query_results import ( - DataCoordinateQueryResults, - DatasetQueryResults, - DimensionRecordQueryResults, - SingleTypeDatasetQueryResults, -) -from .registry import queries as registry_queries - -if TYPE_CHECKING: - from ._dataset_ref import DatasetRef - from ._dataset_type import DatasetType - from .dimensions import DataCoordinate, DimensionElement, DimensionGroup, DimensionRecord - - -class DirectDataCoordinateQueryResults(DataCoordinateQueryResults): - """Implementation of `DataCoordinateQueryResults` using query result - obtained from registry. - - Parameters - ---------- - registry_query_result : \ - `~lsst.daf.butler.registry.queries.DataCoordinateQueryResults` - Query result from Registry. - """ - - def __init__(self, registry_query_result: registry_queries.DataCoordinateQueryResults): - self._registry_query_result = registry_query_result - - def __iter__(self) -> Iterator[DataCoordinate]: - return iter(self._registry_query_result) - - @property - def dimensions(self) -> DimensionGroup: - # Docstring inherited. - return self._registry_query_result.dimensions - - def has_full(self) -> bool: - # Docstring inherited. - return self._registry_query_result.hasFull() - - def has_records(self) -> bool: - # Docstring inherited. - return self._registry_query_result.hasRecords() - - @contextlib.contextmanager - def materialize(self) -> Iterator[DataCoordinateQueryResults]: - with self._registry_query_result.materialize() as result: - yield DirectDataCoordinateQueryResults(result) - - def expanded(self) -> DataCoordinateQueryResults: - # Docstring inherited. - if self.has_records(): - return self - return DirectDataCoordinateQueryResults(self._registry_query_result.expanded()) - - def subset( - self, - dimensions: DimensionGroup | Iterable[str] | None = None, - *, - unique: bool = False, - ) -> DataCoordinateQueryResults: - # Docstring inherited. - return DirectDataCoordinateQueryResults(self._registry_query_result.subset(dimensions, unique=unique)) - - def find_datasets( - self, - dataset_type: DatasetType | str, - collections: Any, - *, - find_first: bool = True, - ) -> DatasetQueryResults: - # Docstring inherited. 
- return DirectDatasetQueryResults( - self._registry_query_result.findDatasets(dataset_type, collections, findFirst=find_first) - ) - - def find_related_datasets( - self, - dataset_type: DatasetType | str, - collections: Any, - *, - find_first: bool = True, - dimensions: DimensionGroup | Iterable[str] | None = None, - ) -> Iterable[tuple[DataCoordinate, DatasetRef]]: - # Docstring inherited. - return self._registry_query_result.findRelatedDatasets( - dataset_type, collections, findFirst=find_first, dimensions=dimensions - ) - - def count(self, *, exact: bool = True, discard: bool = False) -> int: - # Docstring inherited. - return self._registry_query_result.count(exact=exact, discard=discard) - - def any(self, *, execute: bool = True, exact: bool = True) -> bool: - # Docstring inherited. - return self._registry_query_result.any(execute=execute, exact=exact) - - def explain_no_results(self, execute: bool = True) -> Iterable[str]: - # Docstring inherited. - return self._registry_query_result.explain_no_results(execute=execute) - - def order_by(self, *args: str) -> DataCoordinateQueryResults: - # Docstring inherited. - return DirectDataCoordinateQueryResults(self._registry_query_result.order_by(*args)) - - def limit(self, limit: int | None = None, offset: int = 0) -> DataCoordinateQueryResults: - # Docstring inherited. - if limit is None: - raise NotImplementedError("Offset without limit is temporarily unsupported.") - return DirectDataCoordinateQueryResults(self._registry_query_result.limit(limit, offset)) - - -class DirectDatasetQueryResults(DatasetQueryResults): - """Implementation of `DatasetQueryResults` using query result - obtained from registry. - - Parameters - ---------- - registry_query_result : \ - `~lsst.daf.butler.registry.queries.DatasetQueryResults` - Query result from Registry. - """ - - def __init__(self, registry_query_result: registry_queries.DatasetQueryResults): - self._registry_query_result = registry_query_result - - def __iter__(self) -> Iterator[DatasetRef]: - return iter(self._registry_query_result) - - def by_dataset_type(self) -> Iterator[SingleTypeDatasetQueryResults]: - # Docstring inherited. - for by_parent in self._registry_query_result.byParentDatasetType(): - yield DirectSingleTypeDatasetQueryResults(by_parent) - - @contextlib.contextmanager - def materialize(self) -> Iterator[DatasetQueryResults]: - # Docstring inherited. - with self._registry_query_result.materialize() as result: - yield DirectDatasetQueryResults(result) - - def expanded(self) -> DatasetQueryResults: - # Docstring inherited. - return DirectDatasetQueryResults(self._registry_query_result.expanded()) - - def count(self, *, exact: bool = True, discard: bool = False) -> int: - # Docstring inherited. - return self._registry_query_result.count(exact=exact, discard=discard) - - def any(self, *, execute: bool = True, exact: bool = True) -> bool: - # Docstring inherited. - return self._registry_query_result.any(execute=execute, exact=exact) - - def explain_no_results(self, execute: bool = True) -> Iterable[str]: - # Docstring inherited. - return self._registry_query_result.explain_no_results(execute=execute) - - -class DirectSingleTypeDatasetQueryResults(SingleTypeDatasetQueryResults): - """Implementation of `SingleTypeDatasetQueryResults` using query result - obtained from registry. - - Parameters - ---------- - registry_query_result : \ - `~lsst.daf.butler.registry.queries.ParentDatasetQueryResults` - Query result from Registry. 
- """ - - def __init__(self, registry_query_result: registry_queries.ParentDatasetQueryResults): - self._registry_query_result = registry_query_result - - def __iter__(self) -> Iterator[DatasetRef]: - return iter(self._registry_query_result) - - def by_dataset_type(self) -> Iterator[SingleTypeDatasetQueryResults]: - # Docstring inherited. - yield self - - @contextlib.contextmanager - def materialize(self) -> Iterator[SingleTypeDatasetQueryResults]: - # Docstring inherited. - with self._registry_query_result.materialize() as result: - yield DirectSingleTypeDatasetQueryResults(result) - - @property - def dataset_type(self) -> DatasetType: - # Docstring inherited. - return self._registry_query_result.parentDatasetType - - @property - def data_ids(self) -> DataCoordinateQueryResults: - # Docstring inherited. - return DirectDataCoordinateQueryResults(self._registry_query_result.dataIds) - - def expanded(self) -> SingleTypeDatasetQueryResults: - # Docstring inherited. - return DirectSingleTypeDatasetQueryResults(self._registry_query_result.expanded()) - - def count(self, *, exact: bool = True, discard: bool = False) -> int: - # Docstring inherited. - return self._registry_query_result.count(exact=exact, discard=discard) - - def any(self, *, execute: bool = True, exact: bool = True) -> bool: - # Docstring inherited. - return self._registry_query_result.any(execute=execute, exact=exact) - - def explain_no_results(self, execute: bool = True) -> Iterable[str]: - # Docstring inherited. - return self._registry_query_result.explain_no_results(execute=execute) - - -class DirectDimensionRecordQueryResults(DimensionRecordQueryResults): - """Implementation of `DimensionRecordQueryResults` using query result - obtained from registry. - - Parameters - ---------- - registry_query_result : \ - `~lsst.daf.butler.registry.queries.DimensionRecordQueryResults` - Query result from Registry. - """ - - def __init__(self, registry_query_result: registry_queries.DimensionRecordQueryResults): - self._registry_query_result = registry_query_result - - def __iter__(self) -> Iterator[DimensionRecord]: - return iter(self._registry_query_result) - - @property - def element(self) -> DimensionElement: - # Docstring inherited. - return self._registry_query_result.element - - def run(self) -> DimensionRecordQueryResults: - # Docstring inherited. - return DirectDimensionRecordQueryResults(self._registry_query_result.run()) - - def count(self, *, exact: bool = True, discard: bool = False) -> int: - # Docstring inherited. - return self._registry_query_result.count(exact=exact, discard=discard) - - def any(self, *, execute: bool = True, exact: bool = True) -> bool: - # Docstring inherited. - return self._registry_query_result.any(execute=execute, exact=exact) - - def order_by(self, *args: str) -> DimensionRecordQueryResults: - # Docstring inherited. - return DirectDimensionRecordQueryResults(self._registry_query_result.order_by(*args)) - - def limit(self, limit: int | None = None, offset: int = 0) -> DimensionRecordQueryResults: - # Docstring inherited. - if limit is None: - raise NotImplementedError("Offset without limit is temporarily unsupported.") - return DirectDimensionRecordQueryResults(self._registry_query_result.limit(limit, offset)) - - def explain_no_results(self, execute: bool = True) -> Iterable[str]: - # Docstring inherited. 
- return self._registry_query_result.explain_no_results(execute=execute) diff --git a/python/lsst/daf/butler/queries/__init__.py b/python/lsst/daf/butler/queries/__init__.py new file mode 100644 index 0000000000..15743f291f --- /dev/null +++ b/python/lsst/daf/butler/queries/__init__.py @@ -0,0 +1,32 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from ._base import * +from ._data_coordinate_query_results import * +from ._dataset_query_results import * +from ._dimension_record_query_results import * +from ._query import * diff --git a/python/lsst/daf/butler/queries/_base.py b/python/lsst/daf/butler/queries/_base.py new file mode 100644 index 0000000000..6b64003b0f --- /dev/null +++ b/python/lsst/daf/butler/queries/_base.py @@ -0,0 +1,287 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
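Since `__init__.py` above re-exports each submodule's `__all__`, the public classes defined in the modules that follow can be imported from the package root. A minimal sketch (assuming this branch is installed; nothing here is part of the patch itself):

```python
# These names are re-exported by lsst.daf.butler.queries/__init__.py via its
# wildcard imports of the submodules added in this change.
from lsst.daf.butler.queries import (
    DataCoordinateQueryResults,
    DatasetQueryResults,
    DimensionRecordQueryResults,
    Query,
)
```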
+ +from __future__ import annotations + +__all__ = ("QueryBase", "HomogeneousQueryBase", "CountableQueryBase", "QueryResultsBase") + +from abc import ABC, abstractmethod +from collections.abc import Iterable, Mapping, Set +from typing import Any, Self + +from ..dimensions import DataId, DimensionGroup +from .convert_args import convert_order_by_args, convert_where_args +from .driver import QueryDriver +from .expression_factory import ExpressionProxy +from .tree import OrderExpression, Predicate, QueryTree + + +class QueryBase(ABC): + """Common base class for `Query` and all `QueryResult` objects. + + This class should rarely be referenced directly; it is public only because + it provides public methods to its subclasses. + """ + + @abstractmethod + def any(self, *, execute: bool = True, exact: bool = True) -> bool: + """Test whether the query would return any rows. + + Parameters + ---------- + execute : `bool`, optional + If `True`, execute at least a ``LIMIT 1`` query if it cannot be + determined prior to execution that the query would return no rows. + exact : `bool`, optional + If `True`, run the full query and perform post-query filtering if + needed, until at least one result row is found. If `False`, the + returned result does not account for post-query filtering, and + hence may be `True` even when all result rows would be filtered + out. + + Returns + ------- + any : `bool` + `True` if the query would (or might, depending on arguments) yield + result rows. `False` if it definitely would not. + """ + raise NotImplementedError() + + @abstractmethod + def explain_no_results(self, execute: bool = True) -> Iterable[str]: + """Return human-readable messages that may help explain why the query + yields no results. + + Parameters + ---------- + execute : `bool`, optional + If `True` (default) execute simplified versions (e.g. ``LIMIT 1``) + of aspects of the tree to more precisely determine where rows were + filtered out. + + Returns + ------- + messages : `~collections.abc.Iterable` [ `str` ] + String messages that describe reasons the query might not yield any + results. + """ + raise NotImplementedError() + + @abstractmethod + def where( + self, + *args: str | Predicate | DataId, + bind: Mapping[str, Any] | None = None, + **kwargs: Any, + ) -> Self: + """Return a query with a boolean-expression filter on its rows. + + Parameters + ---------- + *args + Constraints to apply, combined with logical AND. Arguments may be + `str` expressions to parse, `Predicate` objects (these are + typically constructed via `expression_factory`) or data IDs. + bind : `~collections.abc.Mapping` + Mapping from string identifier appearing in a string expression to + a literal value that should be substituted for it. This is + recommended instead of embedding literals directly into the + expression, especially for strings, timespans, or other types where + quoting or formatting is nontrivial. + **kwargs + Data ID key value pairs that extend and override any present in + ``*args``. + + Returns + ------- + query : `QueryBase` + A new query object with the given row filters (as well as any + already present in ``self``). All row filters are combined with + logical AND. + + Notes + ----- + If an expression references a dimension or dimension element that is + not already present in the query, it will be joined in, but dataset + searches must already be joined into a query in order to reference + their fields in expressions. 
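To make the `where` contract above concrete, here is a minimal sketch; the `query` object, dimension names, and literal values are illustrative assumptions, and only the call pattern comes from the docstring above:

```python
# Minimal sketch, assuming `query` was obtained via `with butler._query() as query:`
# and that the repository uses the default dimension universe.
constrained = query.where(
    "visit > min_visit",        # string expression parsed by the query system
    bind={"min_visit": 900},    # literal supplied via `bind` rather than inlined
    instrument="HSC",           # data-ID kwarg, ANDed with the other constraints
)
```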
+ + Data ID values are not checked for consistency; they are extracted from + ``args`` and then ``kwargs`` and combined, with later values overriding + earlier ones. + """ + raise NotImplementedError() + + +class HomogeneousQueryBase(QueryBase): + """Common base class for `Query` and query result classes that are + iterables with consistent dimensions throughout. + + This class should rarely be referenced directly; it is public only because + it provides public methods to its subclasses. + + Parameters + ---------- + driver : `QueryDriver` + Implementation object that knows how to actually execute queries. + tree : `QueryTree` + Description of the query as a tree of joins and column expressions. + """ + + def __init__(self, driver: QueryDriver, tree: QueryTree): + self._driver = driver + self._tree = tree + + @property + def dimensions(self) -> DimensionGroup: + """All dimensions included in the query's columns.""" + return self._tree.dimensions + + def any(self, *, execute: bool = True, exact: bool = True) -> bool: + # Docstring inherited. + return self._driver.any(self._tree, execute=execute, exact=exact) + + def explain_no_results(self, execute: bool = True) -> Iterable[str]: + # Docstring inherited. + return self._driver.explain_no_results(self._tree, execute=execute) + + +class CountableQueryBase(QueryBase): + """Common base class for query result objects for which the number of + result rows is a well-defined concept. + + This class should rarely be referenced directly; it is public only because + it provides public methods to its subclasses. + """ + + @abstractmethod + def count(self, *, exact: bool = True, discard: bool = False) -> int: + """Count the number of rows this query would return. + + Parameters + ---------- + exact : `bool`, optional + If `True`, run the full query and perform post-query filtering if + needed to account for that filtering in the count. If `False`, the + result may be an upper bound. + discard : `bool`, optional + If `True`, compute the exact count even if it would require running + the full query and then throwing away the result rows after + counting them. If `False`, this is an error, as the user would + usually be better off executing the query first to fetch its rows + into a new query (or passing ``exact=False``). Ignored if + ``exact=False``. + + Returns + ------- + count : `int` + The number of rows the query would return, or an upper bound if + ``exact=False``. + """ + raise NotImplementedError() + + +class QueryResultsBase(HomogeneousQueryBase, CountableQueryBase): + """Common base class for query result objects with homogeneous dimensions + and countable rows. + """ + + def order_by(self, *args: str | OrderExpression | ExpressionProxy) -> Self: + """Return a new query that yields ordered results. + + Parameters + ---------- + *args : `str` + Names of the columns/dimensions to use for ordering. Column name + can be prefixed with minus (``-``) to use descending ordering. + + Returns + ------- + result : `QueryResultsBase` + An ordered version of this query results object. + + Notes + ----- + If this method is called multiple times, the new sort terms replace + the old ones. + """ + return self._copy( + self._tree, order_by=convert_order_by_args(self.dimensions, self._get_datasets(), *args) + ) + + def limit(self, limit: int | None = None, offset: int = 0) -> Self: + """Return a new query that slices its result rows positionally. + + Parameters + ---------- + limit : `int` or `None`, optional + Upper limit on the number of returned records. 
`None` (default) + means no limit. + offset : `int`, optional + The number of records to skip before returning at most ``limit`` + records. + + Returns + ------- + result : `QueryResultsBase` + A sliced version of this query results object. + + Notes + ----- + If this method is called multiple times, the new slice parameters + replace the old ones. Slicing always occurs after sorting, even if + `limit` is called before `order_by`. + """ + return self._copy(self._tree, limit=limit, offset=offset) + + def where( + self, + *args: str | Predicate | DataId, + bind: Mapping[str, Any] | None = None, + **kwargs: Any, + ) -> Self: + # Docstring inherited. + return self._copy( + tree=self._tree.where( + convert_where_args(self.dimensions, self._get_datasets(), *args, bind=bind, **kwargs) + ), + driver=self._driver, + ) + + @abstractmethod + def _get_datasets(self) -> Set[str]: + """Return all dataset types included in the query's result rows.""" + raise NotImplementedError() + + @abstractmethod + def _copy(self, tree: QueryTree, **kwargs: Any) -> Self: + """Return a modified copy of ``self``. + + Implementations should validate modifications, not assume they are + correct. + """ + raise NotImplementedError() diff --git a/python/lsst/daf/butler/queries/_data_coordinate_query_results.py b/python/lsst/daf/butler/queries/_data_coordinate_query_results.py new file mode 100644 index 0000000000..2aeacbab18 --- /dev/null +++ b/python/lsst/daf/butler/queries/_data_coordinate_query_results.py @@ -0,0 +1,101 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("DataCoordinateQueryResults",) + +from collections.abc import Iterator +from typing import TYPE_CHECKING, Any, final + +from ..dimensions import DataCoordinate +from ._base import QueryResultsBase +from .driver import QueryDriver +from .tree import QueryTree + +if TYPE_CHECKING: + from .result_specs import DataCoordinateResultSpec + + +@final +class DataCoordinateQueryResults(QueryResultsBase): + """A query for `DataCoordinate` results. + + Parameters + ---------- + driver : `QueryDriver` + Implementation object that knows how to actually execute queries. + tree : `QueryTree` + Description of the query as a tree of joins and column expressions.
The + instance returned directly by the `Butler._query` entry point should be + constructed via `make_unit_query_tree`. + spec : `DataCoordinateResultSpec` + Specification of the query result rows, including output columns, + ordering, and slicing. + + Notes + ----- + This class should never be constructed directly by users; use + `Query.data_ids` instead. + """ + + def __init__(self, driver: QueryDriver, tree: QueryTree, spec: DataCoordinateResultSpec): + spec.validate_tree(tree) + super().__init__(driver, tree) + self._spec = spec + + def __iter__(self) -> Iterator[DataCoordinate]: + page = self._driver.execute(self._spec, self._tree) + yield from page.rows + while page.next_key is not None: + page = self._driver.fetch_next_page(self._spec, page.next_key) + yield from page.rows + + @property + def has_dimension_records(self) -> bool: + """Whether all data IDs in this iterable contain dimension records.""" + return self._spec.include_dimension_records + + def with_dimension_records(self) -> DataCoordinateQueryResults: + """Return a results object for which `has_dimension_records` is + `True`. + """ + if self.has_dimension_records: + return self + return self._copy(tree=self._tree, include_dimension_records=True) + + def count(self, *, exact: bool = True, discard: bool = False) -> int: + # Docstring inherited. + return self._driver.count(self._tree, self._spec, exact=exact, discard=discard) + + def _copy(self, tree: QueryTree, **kwargs: Any) -> DataCoordinateQueryResults: + # Docstring inherited. + return DataCoordinateQueryResults(self._driver, tree, spec=self._spec.model_copy(update=kwargs)) + + def _get_datasets(self) -> frozenset[str]: + # Docstring inherited. + return frozenset() diff --git a/python/lsst/daf/butler/queries/_dataset_query_results.py b/python/lsst/daf/butler/queries/_dataset_query_results.py new file mode 100644 index 0000000000..3495745b7a --- /dev/null +++ b/python/lsst/daf/butler/queries/_dataset_query_results.py @@ -0,0 +1,236 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
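Iteration over these result objects drives the execute/fetch_next_page paging shown above. A usage sketch, assuming a `query` object from `Butler._query` and illustrative dimension names:

```python
# Minimal sketch, assuming `query` comes from `with butler._query() as query:`.
# Each call returns a new, immutable results object; iteration fetches pages
# from the QueryDriver transparently.
data_ids = query.data_ids(["visit", "detector"]).with_dimension_records()
print("row count (may be an upper bound):", data_ids.count(exact=False))
for data_id in data_ids.order_by("visit", "-detector").limit(10):
    # Dimension records were requested above, so the visit record is attached.
    print(data_id, data_id.records["visit"])
```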
+ +from __future__ import annotations + +__all__ = ( + "DatasetQueryResults", + "ChainedDatasetQueryResults", + "SingleTypeDatasetQueryResults", +) + +import itertools +from abc import abstractmethod +from collections.abc import Iterable, Iterator, Mapping +from typing import TYPE_CHECKING, Any, final + +from .._dataset_ref import DatasetRef +from .._dataset_type import DatasetType +from ..dimensions import DataId +from ._base import CountableQueryBase, QueryResultsBase +from .driver import QueryDriver +from .result_specs import DataCoordinateResultSpec, DatasetRefResultSpec +from .tree import Predicate, QueryTree + +if TYPE_CHECKING: + from ._data_coordinate_query_results import DataCoordinateQueryResults + + +class DatasetQueryResults(CountableQueryBase, Iterable[DatasetRef]): + """A query for `DatasetRef` results.""" + + @abstractmethod + def by_dataset_type(self) -> Iterator[SingleTypeDatasetQueryResults]: + """Group results by dataset type. + + Returns + ------- + iter : `~collections.abc.Iterator` [ `SingleTypeDatasetQueryResults` ] + An iterator over `DatasetQueryResults` instances that are each + responsible for a single dataset type. + """ + raise NotImplementedError() + + @property + @abstractmethod + def has_dimension_records(self) -> bool: + """Whether all data IDs in this iterable contain dimension records.""" + raise NotImplementedError() + + @abstractmethod + def with_dimension_records(self) -> DatasetQueryResults: + """Return a results object for which `has_dimension_records` is + `True`. + """ + raise NotImplementedError() + + +@final +class SingleTypeDatasetQueryResults(DatasetQueryResults, QueryResultsBase): + """A query for `DatasetRef` results with a single dataset type. + + Parameters + ---------- + driver : `QueryDriver` + Implementation object that knows how to actually execute queries. + tree : `QueryTree` + Description of the query as a tree of joins and column expressions. The + instance returned directly by the `Butler._query` entry point should be + constructed via `make_unit_query_tree`. + spec : `DatasetRefResultSpec` + Specification of the query result rows, including output columns, + ordering, and slicing. + + Notes + ----- + This class should never be constructed directly by users; use + `Query.datasets` instead. + """ + + def __init__(self, driver: QueryDriver, tree: QueryTree, spec: DatasetRefResultSpec): + spec.validate_tree(tree) + super().__init__(driver, tree) + self._spec = spec + + def __iter__(self) -> Iterator[DatasetRef]: + page = self._driver.execute(self._spec, self._tree) + yield from page.rows + while page.next_key is not None: + page = self._driver.fetch_next_page(self._spec, page.next_key) + yield from page.rows + + @property + def dataset_type(self) -> DatasetType: + # Docstring inherited. + return DatasetType(self._spec.dataset_type_name, self._spec.dimensions, self._spec.storage_class_name) + + @property + def data_ids(self) -> DataCoordinateQueryResults: + # Docstring inherited. + from ._data_coordinate_query_results import DataCoordinateQueryResults + + return DataCoordinateQueryResults( + self._driver, + tree=self._tree, + spec=DataCoordinateResultSpec.model_construct( + dimensions=self.dataset_type.dimensions.as_group(), + include_dimension_records=self._spec.include_dimension_records, + ), + ) + + @property + def has_dimension_records(self) -> bool: + # Docstring inherited. + return self._spec.include_dimension_records + + def with_dimension_records(self) -> SingleTypeDatasetQueryResults: + # Docstring inherited. 
+ if self.has_dimension_records: + return self + return self._copy(tree=self._tree, include_dimension_records=True) + + def by_dataset_type(self) -> Iterator[SingleTypeDatasetQueryResults]: + # Docstring inherited. + return iter((self,)) + + def count(self, *, exact: bool = True, discard: bool = False) -> int: + # Docstring inherited. + return self._driver.count(self._tree, self._spec, exact=exact, discard=discard) + + def _copy(self, tree: QueryTree, **kwargs: Any) -> SingleTypeDatasetQueryResults: + # Docstring inherited. + return SingleTypeDatasetQueryResults(self._driver, tree, self._spec.model_copy(update=kwargs)) + + def _get_datasets(self) -> frozenset[str]: + # Docstring inherited. + return frozenset({self.dataset_type.name}) + + +@final +class ChainedDatasetQueryResults(DatasetQueryResults): + """A query for `DatasetRef` results with multiple dataset types. + + Parameters + ---------- + by_dataset_type : `tuple` [ `SingleTypeDatasetQueryResults` ] + Tuple of single-dataset-type query result objects to combine. + + Notes + ----- + This class should never be constructed directly by users; use + `Query.datasets` instead. + """ + + def __init__(self, by_dataset_type: tuple[SingleTypeDatasetQueryResults, ...]): + self._by_dataset_type = by_dataset_type + + def __iter__(self) -> Iterator[DatasetRef]: + return itertools.chain.from_iterable(self._by_dataset_type) + + def by_dataset_type(self) -> Iterator[SingleTypeDatasetQueryResults]: + # Docstring inherited. + return iter(self._by_dataset_type) + + @property + def has_dimension_records(self) -> bool: + # Docstring inherited. + return all(single_type_results.has_dimension_records for single_type_results in self._by_dataset_type) + + def with_dimension_records(self) -> ChainedDatasetQueryResults: + # Docstring inherited. + return ChainedDatasetQueryResults( + tuple( + [ + single_type_results.with_dimension_records() + for single_type_results in self._by_dataset_type + ] + ) + ) + + def any(self, *, execute: bool = True, exact: bool = True) -> bool: + # Docstring inherited. + return any( + single_type_results.any(execute=execute, exact=exact) + for single_type_results in self._by_dataset_type + ) + + def explain_no_results(self, execute: bool = True) -> Iterable[str]: + # Docstring inherited. + messages: list[str] = [] + for single_type_results in self._by_dataset_type: + messages.extend(single_type_results.explain_no_results(execute=execute)) + return messages + + def count(self, *, exact: bool = True, discard: bool = False) -> int: + # Docstring inherited. + return sum( + single_type_results.count(exact=exact, discard=discard) + for single_type_results in self._by_dataset_type + ) + + def where( + self, *args: DataId | str | Predicate, bind: Mapping[str, Any] | None = None, **kwargs: Any + ) -> ChainedDatasetQueryResults: + # Docstring inherited. + return ChainedDatasetQueryResults( + tuple( + [ + single_type_results.where(*args, bind=bind, **kwargs) + for single_type_results in self._by_dataset_type + ] + ) + ) diff --git a/python/lsst/daf/butler/queries/_dimension_record_query_results.py b/python/lsst/daf/butler/queries/_dimension_record_query_results.py new file mode 100644 index 0000000000..601b2a356e --- /dev/null +++ b/python/lsst/daf/butler/queries/_dimension_record_query_results.py @@ -0,0 +1,121 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). 
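A sketch of consuming a multi-dataset-type query via the grouping method above; the dataset type and collection names are illustrative, and `query` is again assumed to come from `Butler._query`:

```python
# Hypothetical dataset types and collection; only the call pattern is taken
# from the result classes above.
refs = query.datasets(["raw", "calexp"], collections="HSC/defaults")
for single_type_results in refs.by_dataset_type():
    print(single_type_results.dataset_type.name, single_type_results.count(exact=False))
```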
+# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("DimensionRecordQueryResults",) + +from collections.abc import Iterator +from typing import Any, final + +from ..dimensions import DimensionElement, DimensionRecord, DimensionRecordSet, DimensionRecordTable +from ._base import QueryResultsBase +from .driver import QueryDriver +from .result_specs import DimensionRecordResultSpec +from .tree import QueryTree + + +@final +class DimensionRecordQueryResults(QueryResultsBase): + """A query for `DimensionRecord` results. + + Parameters + ---------- + driver : `QueryDriver` + Implementation object that knows how to actually execute queries. + tree : `QueryTree` + Description of the query as a tree of joins and column expressions. + The instance returned directly by the `Butler._query` entry point + should be constructed via `make_unit_query_tree`. + spec : `DimensionRecordResultSpec` + Specification of the query result rows, including output columns, + ordering, and slicing. + + Notes + ----- + This class should never be constructed directly by users; use + `Query.dimension_records` instead. + """ + + def __init__(self, driver: QueryDriver, tree: QueryTree, spec: DimensionRecordResultSpec): + spec.validate_tree(tree) + super().__init__(driver, tree) + self._spec = spec + + def __iter__(self) -> Iterator[DimensionRecord]: + page = self._driver.execute(self._spec, self._tree) + yield from page.rows + while page.next_key is not None: + page = self._driver.fetch_next_page(self._spec, page.next_key) + yield from page.rows + + def iter_table_pages(self) -> Iterator[DimensionRecordTable]: + """Return an iterator over individual pages of results as table-backed + collections. + + Yields + ------ + table : `DimensionRecordTable` + A table-backed collection of dimension records. + """ + page = self._driver.execute(self._spec, self._tree) + yield page.as_table() + while page.next_key is not None: + page = self._driver.fetch_next_page(self._spec, page.next_key) + yield page.as_table() + + def iter_set_pages(self) -> Iterator[DimensionRecordSet]: + """Return an iterator over individual pages of results as set-backed + collections. + + Yields + ------ + table : `DimensionRecordSet` + A set-backed collection of dimension records. 
+ """ + page = self._driver.execute(self._spec, self._tree) + yield page.as_set() + while page.next_key is not None: + page = self._driver.fetch_next_page(self._spec, page.next_key) + yield page.as_set() + + @property + def element(self) -> DimensionElement: + # Docstring inherited. + return self._spec.element + + def count(self, *, exact: bool = True, discard: bool = False) -> int: + # Docstring inherited. + return self._driver.count(self._tree, self._spec, exact=exact, discard=discard) + + def _copy(self, tree: QueryTree, **kwargs: Any) -> DimensionRecordQueryResults: + # Docstring inherited. + return DimensionRecordQueryResults(self._driver, tree, self._spec.model_copy(update=kwargs)) + + def _get_datasets(self) -> frozenset[str]: + # Docstring inherited. + return frozenset() diff --git a/python/lsst/daf/butler/queries/_query.py b/python/lsst/daf/butler/queries/_query.py new file mode 100644 index 0000000000..332d964a93 --- /dev/null +++ b/python/lsst/daf/butler/queries/_query.py @@ -0,0 +1,652 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("Query",) + +from collections.abc import Iterable, Mapping, Set +from types import EllipsisType +from typing import Any, final, overload + +from lsst.utils.iteration import ensure_iterable + +from .._dataset_type import DatasetType +from .._storage_class import StorageClassFactory +from ..dimensions import DataCoordinate, DataId, DataIdValue, DimensionGroup +from ..registry import DatasetTypeError, MissingDatasetTypeError +from ._base import HomogeneousQueryBase +from ._data_coordinate_query_results import DataCoordinateQueryResults +from ._dataset_query_results import ( + ChainedDatasetQueryResults, + DatasetQueryResults, + SingleTypeDatasetQueryResults, +) +from ._dimension_record_query_results import DimensionRecordQueryResults +from .convert_args import convert_where_args +from .driver import QueryDriver +from .expression_factory import ExpressionFactory +from .result_specs import DataCoordinateResultSpec, DatasetRefResultSpec, DimensionRecordResultSpec +from .tree import DatasetSearch, InvalidQueryError, Predicate, QueryTree, make_identity_query_tree + + +@final +class Query(HomogeneousQueryBase): + """A method-chaining builder for butler queries. 
+ + Parameters + ---------- + driver : `QueryDriver` + Implementation object that knows how to actually execute queries. + tree : `QueryTree` + Description of the query as a tree of joins and column expressions. The + instance returned directly by the `Butler._query` entry point should be + constructed via `make_identity_query_tree`. + + Notes + ----- + `Query` objects should never be constructed directly by users; use + `Butler._query` instead. + + A `Query` object represents the first stage of query construction, in which + constraints and joins are defined (roughly corresponding to the WHERE and + FROM clauses in SQL). The various "results" objects represent the second + (and final) stage, where the columns returned are specified and any sorting + or integer slicing can be applied. Result objects are obtained from the + `data_ids`, `datasets`, and `dimension_records` methods. + + `Query` and query-result objects are always immutable (except for caching + information fetched from the database or server), so modifier methods + always return a new object without modifying the current one. + """ + + def __init__(self, driver: QueryDriver, tree: QueryTree): + # __init__ defined here because there are multiple base classes and + # not all define __init__ (and hence inherit object.__init__, which + # just ignores its args). Even if we just delegate to super(), it + # seems less fragile to make it explicit here. + super().__init__(driver, tree) + + @property + def constraint_dataset_types(self) -> Set[str]: + """The names of all dataset types joined into the query. + + The existence of datasets of these types constrains the data IDs of any + type of result. Fields for these dataset types are also usable in + 'where' expressions. + + Note that this includes only dataset type names, not `DatasetType` + instances; the `DatasetQueryResults` adapter returned by the `datasets` + method does include `DatasetType` instances, since it is in a better + position to track and respect any storage class override specified. + """ + return self._tree.datasets.keys() + + @property + def constraint_dimensions(self) -> DimensionGroup: + """Dimensions currently present in the query, either directly or + indirectly. + + This includes dimensions that are present in any joined subquery (such + as a dataset search, materialization, or data ID upload) or `where` + argument, as well as any required or implied dependency of those + dimensions. + """ + return self._tree.dimensions + + @property + def expression_factory(self) -> ExpressionFactory: + """A factory for column expressions using overloaded operators. + + Notes + ----- + Typically this attribute will be assigned to a single-character local + variable, and then its (dynamic) attributes can be used to obtain + references to columns that can be included in a query:: + + with butler._query() as query: + x = query.expression_factory + query = query.where( + x.instrument == "LSSTCam", + x.visit.day_obs > 20240701, + x.any(x.band == 'u', x.band == 'y'), + ) + + As shown above, the returned object also has an `any` method to + combine expressions with logical OR (as well as `not_` and `all`, + though the latter is rarely necessary since `where` already combines + its arguments with AND).
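A small variation on the docstring example above, additionally using `not_`; the dimension fields and values are illustrative assumptions:

```python
# Sketch only: assumes `butler` is a Butler and that the default dimension
# universe defines these dimensions and fields.
with butler._query() as query:
    x = query.expression_factory
    query = query.where(
        x.instrument == "HSC",
        x.any(x.band == "g", x.band == "r"),
        x.not_(x.exposure.observation_type == "bias"),
    )
```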
+ + Proxies for fields associated with dataset types (``dataset_id``, + ``ingest_date``, ``run``, ``collection``, as well as ``timespan`` for + `~CollectionType.CALIBRATION` collection searches) can be obtained with + dict-like access instead:: + + with butler._query() as query: + query = query.order_by(x["raw"].ingest_date) + + Expression proxy objects that correspond to scalar columns overload the + standard comparison operators (``==``, ``!=``, ``<``, ``>``, ``<=``, + ``>=``) and provide `~ScalarExpressionProxy.in_range`, + `~ScalarExpressionProxy.in_iterable`, and + `~ScalarExpressionProxy.in_query` methods for membership tests. For + `order_by` contexts, they also have a `~ScalarExpressionProxy.desc` + property to indicate that the sort order for that expression should be + reversed. + + Proxy objects for region and timespan fields have an `overlaps` method, + and timespans also have `~TimespanProxy.begin` and `~TimespanProxy.end` + properties to access scalar expression proxies for the bounds. + + All proxy objects also have a `~ExpressionProxy.is_null` property. + + Literal values can be created by calling `ExpressionFactory.literal`, + but can almost always be created implicitly via overloaded operators + instead. + """ + return ExpressionFactory(self._driver.universe) + + def data_ids( + self, dimensions: DimensionGroup | Iterable[str] | str | None = None + ) -> DataCoordinateQueryResults: + """Return a result object that is a `DataCoordinate` iterable. + + Parameters + ---------- + dimensions : `DimensionGroup`, `str`, or \ + `~collections.abc.Iterable` [`str`], optional + The dimensions of the data IDs to yield, as either `DimensionGroup` + instances or `str` names. Will be automatically expanded to a + complete `DimensionGroup`. These dimensions do not need to match + the query's current `dimensions`. Default is + `constraint_dimensions`. + + Returns + ------- + data_ids : `DataCoordinateQueryResults` + Data IDs matching the given query parameters. These are guaranteed + to identify all dimensions (`DataCoordinate.hasFull` returns + `True`), but will not contain `DimensionRecord` objects + (`DataCoordinate.hasRecords` returns `False`). Call + `~DataCoordinateQueryResults.with_dimension_records` on the + returned object to include dimension records as well. + """ + tree = self._tree + if dimensions is None: + dimensions = self._tree.dimensions + else: + dimensions = self._driver.universe.conform(dimensions) + if not dimensions <= self._tree.dimensions: + tree = tree.join_dimensions(dimensions) + result_spec = DataCoordinateResultSpec(dimensions=dimensions, include_dimension_records=False) + return DataCoordinateQueryResults(self._driver, tree, result_spec) + + @overload + def datasets( + self, + dataset_type: str | DatasetType, + collections: str | Iterable[str] | None = None, + *, + find_first: bool = True, + ) -> SingleTypeDatasetQueryResults: ... # pragma: no cover + + @overload + def datasets( + self, + dataset_type: Iterable[str | DatasetType] | EllipsisType, + collections: str | Iterable[str] | None = None, + *, + find_first: bool = True, + ) -> DatasetQueryResults: ... # pragma: no cover + + def datasets( + self, + dataset_type: str | DatasetType | Iterable[str | DatasetType] | EllipsisType, + collections: str | Iterable[str] | None = None, + *, + find_first: bool = True, + ) -> DatasetQueryResults: + """Return a result object that is a `DatasetRef` iterable. 
+ + Parameters + ---------- + dataset_type : `str`, `DatasetType`, \ + `~collections.abc.Iterable` [ `str` or `DatasetType` ], \ + or ``...`` + The dataset type or types to search for. Passing ``...`` searches + for all datasets in the given collections. + collections : `str` or `~collections.abc.Iterable` [ `str` ], optional + The collection or collections to search, in order. If not provided + or `None`, and the dataset has not already been joined into the + query, the default collection search path for this butler is used. + find_first : `bool`, optional + If `True` (default), for each result data ID, only yield one + `DatasetRef` of each `DatasetType`, from the first collection in + which a dataset of that dataset type appears (according to the + order of ``collections`` passed in). If `True`, ``collections`` + must not be ``...``. + + Returns + ------- + refs : `.queries.DatasetQueryResults` + Dataset references matching the given query criteria. Nested data + IDs are guaranteed to include values for all implied dimensions + (i.e. `DataCoordinate.hasFull` will return `True`), but will not + include dimension records (`DataCoordinate.hasRecords` will be + `False`) unless + `~.queries.DatasetQueryResults.with_dimension_records` is + called on the result object (which returns a new one). + + Raises + ------ + lsst.daf.butler.registry.DatasetTypeExpressionError + Raised when the ``dataset_type`` expression is invalid. + lsst.daf.butler.registry.NoDefaultCollectionError + Raised when ``collections`` is `None` and default butler + collections are not defined. + TypeError + Raised when the arguments are incompatible, such as when a + collection wildcard is passed when ``find_first`` is `True` + + Notes + ----- + When multiple dataset types are queried in a single call, the + results of this operation are equivalent to querying for each dataset + type separately in turn, and no information about the relationships + between datasets of different types is included. + """ + queries: dict[str, Query] = {} + if dataset_type is ...: + if collections is None: + collections = self._driver.get_default_collections() + else: + collections = tuple(ensure_iterable(collections)) + for _, summary in self._driver.resolve_collection_path(collections): + for dataset_type_name in summary.dataset_types.names: + queries[dataset_type_name] = self.join_dataset_search(dataset_type_name, collections) + else: + for arg in ensure_iterable(dataset_type): + dataset_type_name, query = self._join_dataset_search_impl(arg, collections) + queries[dataset_type_name] = query + + single_type_results: list[SingleTypeDatasetQueryResults] = [] + for dataset_type_name in sorted(queries): + query = queries[dataset_type_name] + dataset_search = query._tree.datasets[dataset_type_name] + if dataset_search.storage_class_name is None: + raise MissingDatasetTypeError( + f"No storage class provided for unregistered dataset type {dataset_type_name!r}. " + "Provide a complete DatasetType object instead of a string name to turn this error " + "into an empty result set." 
+ ) + spec = DatasetRefResultSpec.model_construct( + dataset_type_name=dataset_type_name, + dimensions=dataset_search.dimensions, + storage_class_name=dataset_search.storage_class_name, + include_dimension_records=False, + find_first=find_first, + ) + single_type_results.append( + SingleTypeDatasetQueryResults(self._driver, tree=query._tree, spec=spec) + ) + if len(single_type_results) == 1: + return single_type_results[0] + else: + return ChainedDatasetQueryResults(tuple(single_type_results)) + + def dimension_records(self, element: str) -> DimensionRecordQueryResults: + """Return a result object that is a `DimensionRecord` iterable. + + Parameters + ---------- + element : `str` + The name of a dimension element to obtain records for. + + Returns + ------- + records : `.queries.DimensionRecordQueryResults` + Data IDs matching the given query parameters. + """ + tree = self._tree + if element not in tree.dimensions.elements: + tree = tree.join_dimensions(self._driver.universe[element].minimal_group) + result_spec = DimensionRecordResultSpec(element=self._driver.universe[element]) + return DimensionRecordQueryResults(self._driver, tree, result_spec) + + def materialize( + self, + *, + dimensions: Iterable[str] | DimensionGroup | None = None, + datasets: Iterable[str] | None = None, + ) -> Query: + """Execute the query, save its results to a temporary location, and + return a new query that represents fetching or joining against those + saved results. + + Parameters + ---------- + dimensions : `~collections.abc.Iterable` [ `str` ] or \ + `DimensionGroup`, optional + Dimensions to include in the temporary results. Default is to + include all dimensions in the query. + datasets : `~collections.abc.Iterable` [ `str` ], optional + Names of dataset types that should be included in the new query; + default is to include `constraint_dataset_types`. + + Returns + ------- + query : `Query` + A new query object whose that represents the materialized rows. + + Notes + ----- + Only dimension key columns and (at the discretion of the + implementation) certain dataset columns are actually materialized, + since at this stage we do not know which dataset or dimension record + fields are actually needed in result rows, and these can be joined back + in on the materialized dimension keys. But all constraints on those + dimension keys (including dataset existence) are applied to the + materialized rows. + """ + if datasets is None: + datasets = frozenset(self.constraint_dataset_types) + else: + datasets = frozenset(datasets) + if not (datasets <= self.constraint_dataset_types): + raise InvalidQueryError( + f"Dataset(s) {datasets - self.constraint_dataset_types} are present in the query." + ) + if dimensions is None: + dimensions = self._tree.dimensions + else: + dimensions = self._driver.universe.conform(dimensions) + key = self._driver.materialize(self._tree, dimensions, datasets) + tree = make_identity_query_tree(self._driver.universe).join_materialization( + key, dimensions=dimensions + ) + for dataset_type_name in datasets: + dataset_search = self._tree.datasets[dataset_type_name] + if not (dataset_search.dimensions <= tree.dimensions): + raise InvalidQueryError( + f"Materialization-backed query has dimensions {tree.dimensions}, which do not " + f"cover the dimensions {dataset_search.dimensions} of dataset {dataset_type_name!r}. " + "Expand the dimensions or drop this dataset type in the arguments to materialize to " + "avoid this error." 
+ ) + tree = tree.join_dataset(dataset_type_name, self._tree.datasets[dataset_type_name]) + return Query(self._driver, tree) + + def join_dataset_search( + self, + dataset_type: str | DatasetType, + collections: Iterable[str] | None = None, + dimensions: DimensionGroup | None = None, + ) -> Query: + """Return a new query with a search for a dataset joined in. + + Parameters + ---------- + dataset_type : `str` or `DatasetType` + Dataset type or name. May not refer to a dataset component. + collections : `~collections.abc.Iterable` [ `str` ], optional + Iterable of collections to search. Order is preserved, but will + not matter if the dataset search is only used as a constraint on + dimensions or if ``find_first=False`` when requesting results. If + not present or `None`, the default collection search path will be + used. + dimensions : `DimensionGroup`, optional + The dimensions to assume for the dataset type if it is not + registered, or to check against if it is registered. When the + dataset is not registered and this is not provided, + `MissingDatasetTypeError` is raised, since we cannot construct a + query without knowing the dataset's dimensions. Providing this + argument causes the returned query to instead return no rows (as it + does when the dataset type is registered but no matching datasets + are found). + + Returns + ------- + query : `Query` + A new query object with dataset columns available and rows + restricted to those consistent with the found data IDs. + + Raises + ------ + DatasetTypeError + Raised if the dimensions were provided but they do not match the + registered dataset type. + MissingDatasetTypeError + Raised if the dimensions were not provided and the dataset type was + not registered. + + Notes + ----- + This method may require communication with the server unless the + dataset type and collections have already been referenced by the same + query context. + """ + _, query = self._join_dataset_search_impl(dataset_type, collections, dimensions) + return query + + def join_data_coordinates(self, iterable: Iterable[DataCoordinate]) -> Query: + """Return a new query that joins in an explicit table of data IDs. + + Parameters + ---------- + iterable : `~collections.abc.Iterable` [ `DataCoordinate` ] + Iterable of `DataCoordinate`. All items must have the same + dimensions. Must have at least one item. + + Returns + ------- + query : `Query` + A new query object with the data IDs joined in. + """ + rows: set[tuple[DataIdValue, ...]] = set() + dimensions: DimensionGroup | None = None + for data_coordinate in iterable: + if dimensions is None: + dimensions = data_coordinate.dimensions + elif dimensions != data_coordinate.dimensions: + raise InvalidQueryError( + f"Inconsistent dimensions: {dimensions} != {data_coordinate.dimensions}." + ) + rows.add(data_coordinate.required_values) + if dimensions is None: + raise InvalidQueryError("Cannot upload an empty data coordinate set.") + key = self._driver.upload_data_coordinates(dimensions, rows) + return Query( + tree=self._tree.join_data_coordinate_upload(dimensions=dimensions, key=key), driver=self._driver + ) + + def join_dimensions(self, dimensions: Iterable[str] | DimensionGroup) -> Query: + """Return a new query that joins the logical tables for additional + dimensions. + + Parameters + ---------- + dimensions : `~collections.abc.Iterable` [ `str` ] or `DimensionGroup` + Names of dimensions to join in. + + Returns + ------- + query : `Query` + A new query object with the dimensions joined in. 
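A sketch of `join_data_coordinates` as described above, assuming `butler` is a `Butler` whose query driver is available; the instrument and visit values are illustrative:

```python
# Upload an explicit set of data IDs and use it to constrain the query.
from lsst.daf.butler import DataCoordinate

with butler._query() as query:
    uploaded = [
        DataCoordinate.standardize(
            instrument="HSC", visit=visit, universe=butler.dimensions
        )
        for visit in (903334, 903336)
    ]
    query = query.join_data_coordinates(uploaded)
    visit_records = list(query.dimension_records("visit"))
```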
+ + Notes + ----- + Dimensions are automatically joined in whenever needed, so this method + should rarely need to be called directly. + """ + dimensions = self._driver.universe.conform(dimensions) + return Query(tree=self._tree.join_dimensions(dimensions), driver=self._driver) + + def where( + self, + *args: str | Predicate | DataId, + bind: Mapping[str, Any] | None = None, + **kwargs: Any, + ) -> Query: + """Return a query with a boolean-expression filter on its rows. + + Parameters + ---------- + *args + Constraints to apply, combined with logical AND. Arguments may be + `str` expressions to parse, `Predicate` objects (these are + typically constructed via `expression_factory`) or data IDs. + bind : `~collections.abc.Mapping` + Mapping from string identifier appearing in a string expression to + a literal value that should be substituted for it. This is + recommended instead of embedding literals directly into the + expression, especially for strings, timespans, or other types where + quoting or formatting is nontrivial. + **kwargs + Data ID key value pairs that extend and override any present in + ``*args``. + + Returns + ------- + query : `Query` + A new query object with the given row filters (as well as any + already present in ``self``). All row filters are combined with + logical AND. + + Notes + ----- + If an expression references a dimension or dimension element that is + not already present in the query, it will be joined in, but dataset + searches must already be joined into a query in order to reference + their fields in expressions. + + Data ID values are not checked for consistency; they are extracted from + ``args`` and then ``kwargs`` and combined, with later values overriding + earlier ones. + """ + return Query( + tree=self._tree.where( + convert_where_args(self.dimensions, self.constraint_dataset_types, *args, bind=bind, **kwargs) + ), + driver=self._driver, + ) + + def _join_dataset_search_impl( + self, + dataset_type: str | DatasetType, + collections: Iterable[str] | None = None, + dimensions: DimensionGroup | None = None, + ) -> tuple[str, Query]: + """Implement `join_dataset_search`, and also return the dataset type + name. + """ + # In this method we need the dimensions of the dataset type, but we + # don't necessarily need the storage class, since the dataset may only + # be used as an existence constraint. But we also want to remember the + # storage class if it's passed in, so users don't get frustrated having + # to pass it twice if they do want DatasetRefs back. + storage_class_name: str | None = None + # Handle DatasetType vs. str arg. + if isinstance(dataset_type, DatasetType): + dataset_type_name = dataset_type.name + if dimensions is not None: + raise TypeError("Cannot provide a full DatasetType object and separate dimensions.") + dimensions = dataset_type.dimensions.as_group() + storage_class_name = dataset_type.storageClass_name + elif isinstance(dataset_type, str): + dataset_type_name = dataset_type + else: + raise TypeError(f"Invalid dataset type argument {dataset_type!r}.") + # See if this dataset has already been joined into the query. 
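# --- Example (illustrative sketch) -----------------------------------------
# The different kinds of `where` arguments, all combined with logical AND.
# Dimension and field names are placeholders.
with butler._query() as query:
    x = query.expression_factory
    query = query.where(
        x.exposure.observation_type == "science",  # a Predicate
        {"instrument": "HSC"},                     # a data ID mapping
        detector=42,                               # data ID key-value kwargs
    )
# ---------------------------------------------------------------------------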
+ if existing_search := self._tree.datasets.get(dataset_type_name): + if collections is None: + collections = existing_search.collections + else: + collections = tuple(ensure_iterable(collections)) + if collections != existing_search.collections: + raise InvalidQueryError( + f"Dataset type {dataset_type_name!r} was already joined into this " + "query with a different collection search path (previously " + f"[{', '.join(existing_search.collections)}], now [{', '.join(collections)}])." + ) + if dimensions is None: + dimensions = existing_search.dimensions + elif dimensions != existing_search.dimensions: + raise DatasetTypeError( + f"Given dimensions {dimensions} for dataset type {dataset_type_name!r} do not match the " + f"previously-joined dimensions {existing_search.dimensions}." + ) + if storage_class_name is None or storage_class_name == existing_search.storage_class_name: + # Nothing to do; this dataset has already been joined in with + # the parameters we want. We don't need to check against the + # registered dataset type since that will have been done the + # first time we joined this dataset type in. + return dataset_type_name, self + else: + if collections is None: + collections = self._driver.get_default_collections() + collections = tuple(ensure_iterable(collections)) + # See if the dataset type is registered, to look up and/or check + # dimensions, and get a storage class if there isn't one already. + try: + resolved_dataset_type = self._driver.get_dataset_type(dataset_type_name) + resolved_dimensions = resolved_dataset_type.dimensions.as_group() + if storage_class_name is None: + storage_class_name = resolved_dataset_type.storageClass_name + except MissingDatasetTypeError: + if dimensions is None: + raise + resolved_dimensions = dimensions + else: + if dimensions is not None and dimensions != resolved_dimensions: + raise DatasetTypeError( + f"Given dimensions {dimensions} for dataset type {dataset_type_name!r} do not match the " + f"registered dimensions {resolved_dimensions}." + ) + if ( + storage_class_name is not None + and storage_class_name != resolved_dataset_type.storageClass_name + ): + if not ( + StorageClassFactory() + .getStorageClass(storage_class_name) + .can_convert(resolved_dataset_type.storageClass) + ): + raise DatasetTypeError( + f"Given storage class {storage_class_name!r} for {dataset_type_name!r} is not " + f"compatible with repository storage class {resolved_dataset_type.storageClass_name}." + ) + # We do not check the storage class for consistency with the registered + # storage class at this point, because it's not going to be used for + # anything yet other than a default that can still be overridden. + dataset_search = DatasetSearch.model_construct( + collections=collections, + dimensions=resolved_dimensions, + storage_class_name=storage_class_name, + ) + return dataset_type_name, Query( + self._driver, self._tree.join_dataset(dataset_type_name, dataset_search) + ) diff --git a/python/lsst/daf/butler/queries/convert_args.py b/python/lsst/daf/butler/queries/convert_args.py new file mode 100644 index 0000000000..07ab3b1e82 --- /dev/null +++ b/python/lsst/daf/butler/queries/convert_args.py @@ -0,0 +1,263 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. 
+# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ( + "convert_where_args", + "convert_order_by_args", +) + +import itertools +from collections.abc import Mapping, Set +from typing import Any, cast + +from ..dimensions import DataCoordinate, DataId, Dimension, DimensionGroup +from .expression_factory import ExpressionFactory, ExpressionProxy +from .tree import ( + DATASET_FIELD_NAMES, + ColumnExpression, + DatasetFieldName, + DatasetFieldReference, + DimensionFieldReference, + DimensionKeyReference, + InvalidQueryError, + OrderExpression, + Predicate, + Reversed, + UnaryExpression, + make_column_literal, + validate_order_expression, +) + + +def convert_where_args( + dimensions: DimensionGroup, + datasets: Set[str], + *args: str | Predicate | DataId, + bind: Mapping[str, Any] | None = None, + **kwargs: Any, +) -> Predicate: + """Convert ``where`` arguments to a sequence of column expressions. + + Parameters + ---------- + dimensions : `DimensionGroup` + Dimensions already present in the query this filter is being applied + to. Returned predicates may reference dimensions outside this set. + datasets : `~collections.abc.Set` [ `str` ] + Dataset types already present in the query this filter is being applied + to. Returned predicates may reference datasets outside this set; this + may be an error at a higher level, but it is not necessarily checked + here. + *args : `str`, `Predicate`, `DataCoordinate`, or `~collections.abc.Mapping` + Expressions to convert into predicates. + bind : `~collections.abc.Mapping`, optional + Mapping from identifier to literal value used when parsing string + expressions. + **kwargs : `object` + Additional data ID key-value pairs. + + Returns + ------- + predicate : `Predicate` + Standardized predicate object. + + Notes + ----- + Data ID values are not checked for consistency; they are extracted from + args and then kwargs and combined, with later extractions taking + precedence. 
+ """ + result = Predicate.from_bool(True) + data_id_dict: dict[str, Any] = {} + for arg in args: + match arg: + case str(): + raise NotImplementedError("TODO: plug in registry.queries.expressions.parser") + case Predicate(): + result = result.logical_and(arg) + case DataCoordinate(): + data_id_dict.update(arg.mapping) + case _: + data_id_dict.update(arg) + data_id_dict.update(kwargs) + for k, v in data_id_dict.items(): + result = result.logical_and( + Predicate.compare( + DimensionKeyReference.model_construct(dimension=dimensions.universe.dimensions[k]), + "==", + make_column_literal(v), + ) + ) + return result + + +def convert_order_by_args( + dimensions: DimensionGroup, datasets: Set[str], *args: str | OrderExpression | ExpressionProxy +) -> tuple[OrderExpression, ...]: + """Convert ``order_by`` arguments to a sequence of column expressions. + + Parameters + ---------- + dimensions : `DimensionGroup` + Dimensions already present in the query whose rows are being sorted. + Returned expressions may reference dimensions outside this set; this + may be an error at a higher level, but it is not necessarily checked + here. + datasets : `~collections.abc.Set` [ `str` ] + Dataset types already present in the query whose rows are being sorted. + Returned expressions may reference datasets outside this set; this may + be an error at a higher level, but it is not necessarily checked here. + *args : `OrderExpression`, `str`, or `ExpressionObject` + Expression or column names to sort by. + + Returns + ------- + expressions : `tuple` [ `OrderExpression`, ... ] + Standardized expression objects. + """ + result: list[OrderExpression] = [] + for arg in args: + match arg: + case str(): + reverse = False + if arg.startswith("-"): + reverse = True + arg = arg[1:] + arg = interpret_identifier(dimensions, datasets, arg, {}) + if reverse: + arg = Reversed(operand=arg) + case ExpressionProxy(): + arg = ExpressionFactory.unwrap(arg) + if not hasattr(arg, "expression_type"): + raise TypeError(f"Unrecognized order-by argument: {arg!r}.") + result.append(validate_order_expression(arg)) + return tuple(result) + + +def interpret_identifier( + dimensions: DimensionGroup, datasets: Set[str], identifier: str, bind: Mapping[str, Any] +) -> ColumnExpression: + """Associate an identifier in a ``where`` or ``order_by`` expression with + a query column or bind literal. + + Parameters + ---------- + dimensions : `DimensionGroup` + Dimensions already present in the query this filter is being applied + to. Returned expressions may reference dimensions outside this set. + datasets : `~collections.abc.Set` [ `str` ] + Dataset types already present in the query this filter is being applied + to. Returned expressions may reference datasets outside this set. + identifier : `str` + String identifier to process. + bind : `~collections.abc.Mapping` [ `str`, `object` ] + Dictionary of bind literals to match identifiers against first. + + Returns + ------- + expression : `ColumnExpression` + Column expression corresponding to the identifier. + """ + if identifier in bind: + return make_column_literal(bind[identifier]) + terms = identifier.split(".") + match len(terms): + case 1: + if identifier in dimensions.universe.dimensions: + return DimensionKeyReference.model_construct( + dimension=dimensions.universe.dimensions[identifier] + ) + # This is an unqualified reference to a field of a dimension + # element or datasets; this is okay if it's unambiguous. 
+ element_matches: set[str] = set() + for element_name in dimensions.elements: + element = dimensions.universe[element_name] + if identifier in element.schema.names: + element_matches.add(element_name) + if identifier in DATASET_FIELD_NAMES: + dataset_matches = set(datasets) + else: + dataset_matches = set() + if len(element_matches) + len(dataset_matches) > 1: + match_str = ", ".join( + f"'{x}.{identifier}'" for x in sorted(itertools.chain(element_matches, dataset_matches)) + ) + raise InvalidQueryError( + f"Ambiguous identifier {identifier!r} matches multiple fields: {match_str}." + ) + elif element_matches: + element = dimensions.universe[element_matches.pop()] + return DimensionFieldReference.model_construct(element=element, field=identifier) + elif dataset_matches: + return DatasetFieldReference.model_construct( + dataset_type=dataset_matches.pop(), field=cast(DatasetFieldName, identifier) + ) + case 2: + first, second = terms + if first in dimensions.universe.elements.names: + element = dimensions.universe[first] + if second in element.schema.dimensions.names: + if isinstance(element, Dimension) and second == element.primary_key.name: + # Identifier is something like "visit.id" which we want + # to interpret the same way as just "visit". + return DimensionKeyReference.model_construct(dimension=element) + else: + # Identifier is something like "visit.instrument", + # which we want to interpret the same way as just + # "instrument". + dimension = dimensions.universe.dimensions[second] + return DimensionKeyReference.model_construct(dimension=dimension) + elif second in element.schema.remainder.names: + return DimensionFieldReference.model_construct(element=element, field=second) + else: + raise InvalidQueryError(f"Unrecognized field {second!r} for {first}.") + elif second in DATASET_FIELD_NAMES: + # We just assume the dataset type is okay; it's the job of + # higher-level code to complain otherwise. + return DatasetFieldReference.model_construct( + dataset_type=first, field=cast(DatasetFieldName, second) + ) + if first == "timespan": + base = interpret_identifier(dimensions, datasets, "timespan", bind) + if second == "begin": + return UnaryExpression(operand=base, operator="begin_of") + if second == "end": + return UnaryExpression(operand=base, operator="end_of") + elif first in datasets: + raise InvalidQueryError( + f"Identifier {identifier!r} references dataset type {first!r} but field " + f"{second!r} is not valid for datasets." + ) + case 3: + base = interpret_identifier(dimensions, datasets, ".".join(terms[:2]), bind) + if terms[2] == "begin": + return UnaryExpression(operand=base, operator="begin_of") + if terms[2] == "end": + return UnaryExpression(operand=base, operator="end_of") + raise InvalidQueryError(f"Unrecognized identifier {identifier!r}.") diff --git a/python/lsst/daf/butler/queries/driver.py b/python/lsst/daf/butler/queries/driver.py new file mode 100644 index 0000000000..9abb710a5f --- /dev/null +++ b/python/lsst/daf/butler/queries/driver.py @@ -0,0 +1,458 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. 
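# --- Example (illustrative sketch) -----------------------------------------
# How identifiers are resolved by interpret_identifier, assuming the default
# dimension universe and a query with a "raw" dataset search joined in:
#
#   "visit"                 -> DimensionKeyReference(visit)
#   "visit.id"              -> DimensionKeyReference(visit)   (primary key)
#   "detector.purpose"      -> DimensionFieldReference(detector, "purpose")
#   "raw.run"               -> DatasetFieldReference("raw", "run")
#   "visit.timespan.begin"  -> UnaryExpression("begin_of", visit.timespan)
#   any key present in `bind` is returned as a literal before these rules run.
# ---------------------------------------------------------------------------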
Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ( + "QueryDriver", + "PageKey", + "ResultPage", + "DataCoordinateResultPage", + "DimensionRecordResultPage", + "DatasetRefResultPage", + "GeneralResultPage", +) + +import dataclasses +import uuid +from abc import abstractmethod +from collections.abc import Iterable, Sequence +from contextlib import AbstractContextManager +from typing import Any, TypeAlias, Union, overload + +from .._dataset_ref import DatasetRef +from .._dataset_type import DatasetType +from ..dimensions import ( + DataCoordinate, + DataIdValue, + DimensionGroup, + DimensionRecord, + DimensionRecordSet, + DimensionRecordTable, + DimensionUniverse, +) +from ..registry import CollectionSummary +from ..registry.interfaces import CollectionRecord +from .result_specs import ( + DataCoordinateResultSpec, + DatasetRefResultSpec, + DimensionRecordResultSpec, + GeneralResultSpec, + ResultSpec, +) +from .tree import DataCoordinateUploadKey, MaterializationKey, QueryTree + +PageKey: TypeAlias = uuid.UUID + + +# The Page types below could become Pydantic models instead of dataclasses if +# that makes them more directly usable by RemoteButler (at least once we have +# Pydantic-friendly containers for all of them). We may want to add a +# discriminator annotation to the ResultPage union if we do that. + + +@dataclasses.dataclass +class DataCoordinateResultPage: + """A single page of results from a data coordinate query.""" + + spec: DataCoordinateResultSpec + next_key: PageKey | None + + # TODO: On DM-41114 this will become a custom container that normalizes out + # attached DimensionRecords and is Pydantic-friendly. + rows: list[DataCoordinate] + + +@dataclasses.dataclass +class DimensionRecordResultPage: + """A single page of results from a dimension record query.""" + + spec: DimensionRecordResultSpec + next_key: PageKey | None + rows: Iterable[DimensionRecord] + + def as_table(self) -> DimensionRecordTable: + if isinstance(self.rows, DimensionRecordTable): + return self.rows + else: + return DimensionRecordTable(self.spec.element, self.rows) + + def as_set(self) -> DimensionRecordSet: + if isinstance(self.rows, DimensionRecordSet): + return self.rows + else: + return DimensionRecordSet(self.spec.element, self.rows) + + +@dataclasses.dataclass +class DatasetRefResultPage: + """A single page of results from a dataset query.""" + + spec: DatasetRefResultSpec + next_key: PageKey | None + + # TODO: On DM-41115 this will become a custom container that normalizes out + # attached DimensionRecords and is Pydantic-friendly. 
+ rows: list[DatasetRef] + + +@dataclasses.dataclass +class GeneralResultPage: + """A single page of results from a general query.""" + + spec: GeneralResultSpec + next_key: PageKey | None + + # Raw tabular data, with columns in the same order as spec.columns. + rows: list[tuple[Any, ...]] + + +ResultPage: TypeAlias = Union[ + DataCoordinateResultPage, DimensionRecordResultPage, DatasetRefResultPage, GeneralResultPage +] + + +class QueryDriver(AbstractContextManager[None]): + """Base class for the implementation object inside `Query2` objects + that is specialized for DirectButler vs. RemoteButler. + + Notes + ----- + Implementations should be context managers. This allows them to manage the + lifetime of server-side state, such as: + + - a SQL transaction, when necessary (DirectButler); + - SQL cursors for queries that were not fully iterated over (DirectButler); + - temporary database tables (DirectButler); + - result-page Parquet files that were never fetched (RemoteButler); + - uploaded Parquet files used to fill temporary database tables + (RemoteButler); + - cached content needed to construct query trees, like collection summaries + (potentially all Butlers). + + When possible, these sorts of things should be cleaned up earlier when they + are no longer needed, and the Butler server will still have to guard + against the context manager's ``__exit__`` signal never reaching it, but a + context manager will take care of these much more often than relying on + garbage collection and ``__del__`` would. + """ + + @property + @abstractmethod + def universe(self) -> DimensionUniverse: + """Object that defines all dimensions.""" + raise NotImplementedError() + + @overload + def execute(self, result_spec: DataCoordinateResultSpec, tree: QueryTree) -> DataCoordinateResultPage: ... + + @overload + def execute( + self, result_spec: DimensionRecordResultSpec, tree: QueryTree + ) -> DimensionRecordResultPage: ... + + @overload + def execute(self, result_spec: DatasetRefResultSpec, tree: QueryTree) -> DatasetRefResultPage: ... + + @overload + def execute(self, result_spec: GeneralResultSpec, tree: QueryTree) -> GeneralResultPage: ... + + @abstractmethod + def execute(self, result_spec: ResultSpec, tree: QueryTree) -> ResultPage: + """Execute a query and return the first result page. + + Parameters + ---------- + result_spec : `ResultSpec` + The kind of results the user wants from the query. This can affect + the actual query (i.e. SQL and Python postprocessing) that is run, + e.g. by changing what is in the SQL SELECT clause and even what + tables are joined in, but it never changes the number or order of + result rows. + tree : `QueryTree` + Query tree to evaluate. + + Returns + ------- + first_page : `ResultPage` + A page whose type corresponds to the type of ``result_spec``, with + at least the initial rows from the query. This should have an + empty ``rows`` attribute if the query returned no results, and a + ``next_key`` attribute that is not `None` if there were more + results than could be returned in a single page. + """ + raise NotImplementedError() + + @overload + def fetch_next_page( + self, result_spec: DataCoordinateResultSpec, key: PageKey + ) -> DataCoordinateResultPage: ... + + @overload + def fetch_next_page( + self, result_spec: DimensionRecordResultSpec, key: PageKey + ) -> DimensionRecordResultPage: ... + + @overload + def fetch_next_page(self, result_spec: DatasetRefResultSpec, key: PageKey) -> DatasetRefResultPage: ... 
+ + @overload + def fetch_next_page(self, result_spec: GeneralResultSpec, key: PageKey) -> GeneralResultPage: ... + + @abstractmethod + def fetch_next_page(self, result_spec: ResultSpec, key: PageKey) -> ResultPage: + """Fetch the next page of results from an already-executed query. + + Parameters + ---------- + result_spec : `ResultSpec` + The kind of results the user wants from the query. This must be + identical to the ``result_spec`` passed to `execute`, but + implementations are not *required* to check this. + key : `PageKey` + Key included in the previous page from this query. This key may + become unusable or even be reused after this call. + + Returns + ------- + next_page : `ResultPage` + The next page of query results. + """ + # We can put off dealing with pagination initially by just making an + # implementation of this method raise. + # + # In RemoteButler I expect this to work by having the call to execute + # continue to write Parquet files (or whatever) to some location until + # its cursor is exhausted, and then delete those files as they are + # fetched (or, failing that, when receiving a signal from + # ``__exit__``). + # + # In DirectButler I expect to have a dict[PageKey, Cursor], fetch a + # blocks of rows from it, and just reuse the page key for the next page + # until the cursor is exactly. + raise NotImplementedError() + + @abstractmethod + def materialize( + self, + tree: QueryTree, + dimensions: DimensionGroup, + datasets: frozenset[str], + ) -> MaterializationKey: + """Execute a query tree, saving results to temporary storage for use + in later queries. + + Parameters + ---------- + tree : `QueryTree` + Query tree to evaluate. + dimensions : `DimensionGroup` + Dimensions whose key columns should be preserved. + datasets : `frozenset` [ `str` ] + Names of dataset types whose ID columns may be materialized. It + is implementation-defined whether they actually are. + + Returns + ------- + key : `MaterializationKey` + Unique identifier for the result rows that allows them to be + referenced in a `QueryTree`. + """ + raise NotImplementedError() + + @abstractmethod + def upload_data_coordinates( + self, dimensions: DimensionGroup, rows: Iterable[tuple[DataIdValue, ...]] + ) -> DataCoordinateUploadKey: + """Upload a table of data coordinates for use in later queries. + + Parameters + ---------- + dimensions : `DimensionGroup` + Dimensions of the data coordinates. + rows : `Iterable` [ `tuple` ] + Tuples of data coordinate values, covering just the "required" + subset of ``dimensions``. + + Returns + ------- + key + Unique identifier for the upload that allows it to be referenced in + a `QueryTree`. + """ + raise NotImplementedError() + + @abstractmethod + def count( + self, + tree: QueryTree, + result_spec: ResultSpec, + *, + exact: bool, + discard: bool, + ) -> int: + """Return the number of rows a query would return. + + Parameters + ---------- + tree : `QueryTree` + Query tree to evaluate. + result_spec : `ResultSpec` + The kind of results the user wants to count. + exact : `bool`, optional + If `True`, run the full query and perform post-query filtering if + needed to account for that filtering in the count. If `False`, the + result may be an upper bound. + discard : `bool`, optional + If `True`, compute the exact count even if it would require running + the full query and then throwing away the result rows after + counting them. 
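# --- Example (illustrative sketch) -----------------------------------------
# The paging protocol a result object is expected to follow; `driver`, `spec`,
# and `tree` stand for a concrete QueryDriver, a ResultSpec, and a QueryTree.
with driver:
    page = driver.execute(spec, tree)
    while True:
        for row in page.rows:
            ...  # consume the row
        if page.next_key is None:
            break
        page = driver.fetch_next_page(spec, page.next_key)
# ---------------------------------------------------------------------------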
If `False`, this is an error, as the user would + usually be better off executing the query first to fetch its rows + into a new query (or passing ``exact=False``). Ignored if + ``exact=False``. + """ + raise NotImplementedError() + + @abstractmethod + def any(self, tree: QueryTree, *, execute: bool, exact: bool) -> bool: + """Test whether the query would return any rows. + + Parameters + ---------- + tree : `QueryTree` + Query tree to evaluate. + execute : `bool`, optional + If `True`, execute at least a ``LIMIT 1`` query if it cannot be + determined prior to execution that the query would return no rows. + exact : `bool`, optional + If `True`, run the full query and perform post-query filtering if + needed, until at least one result row is found. If `False`, the + returned result does not account for post-query filtering, and + hence may be `True` even when all result rows would be filtered + out. + + Returns + ------- + any : `bool` + `True` if the query would (or might, depending on arguments) yield + result rows. `False` if it definitely would not. + """ + raise NotImplementedError() + + @abstractmethod + def explain_no_results(self, tree: QueryTree, execute: bool) -> Iterable[str]: + """Return human-readable messages that may help explain why the query + yields no results. + + Parameters + ---------- + tree : `QueryTree` + Query tree to evaluate. + execute : `bool`, optional + If `True` (default) execute simplified versions (e.g. ``LIMIT 1``) + of aspects of the tree to more precisely determine where rows were + filtered out. + + Returns + ------- + messages : `~collections.abc.Iterable` [ `str` ] + String messages that describe reasons the query might not yield any + results. + """ + raise NotImplementedError() + + @abstractmethod + def get_default_collections(self) -> tuple[str, ...]: + """Return the default collection search path. + + Returns + ------- + collections : `tuple` [ `str`, ... ] + The default collection search path as a tuple of `str`. + + Raises + ------ + NoDefaultCollectionError + Raised if there are no default collections. + """ + raise NotImplementedError() + + @abstractmethod + def resolve_collection_path( + self, collections: Sequence[str] + ) -> list[tuple[CollectionRecord, CollectionSummary]]: + """Process a collection search path argument into a `list` of + collection records and summaries. + + Parameters + ---------- + collections : `~collections.abc.Sequence` [ `str` ] + The collection or collections to search. + + Returns + ------- + collection_info : `list` [ `tuple` [ `CollectionRecord`, \ + `CollectionSummary` ] ] + A `list` of pairs of `CollectionRecord` and `CollectionSummary` + that flattens out all `~CollectionType.CHAINED` collections into + their children while maintaining the same order and avoiding + duplicates. + + Raises + ------ + MissingCollectionError + Raised if any collection in ``collections`` does not exist. + + Notes + ----- + Implementations are generally expected to cache the collection records + and summaries they obtain (including the records for + `~CollectionType.CHAINED` collections that are not returned) in order + to optimize multiple calls with collections in common. + """ + raise NotImplementedError() + + @abstractmethod + def get_dataset_type(self, name: str) -> DatasetType: + """Return the dimensions for a dataset type. + + Parameters + ---------- + name : `str` + Name of the dataset type. + + Returns + ------- + dataset_type : `DatasetType` + Dimensions of the dataset type. 
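# --- Example (illustrative sketch) -----------------------------------------
# Cheap upper bound vs. exact count, and an existence check; `driver`, `spec`,
# and `tree` are placeholders as in the paging sketch above.
upper_bound = driver.count(tree, spec, exact=False, discard=False)
exact_count = driver.count(tree, spec, exact=True, discard=True)
might_have_rows = driver.any(tree, execute=True, exact=False)
# ---------------------------------------------------------------------------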
+ + Raises + ------ + MissingDatasetTypeError + Raised if the dataset type is not registered. + """ + raise NotImplementedError() diff --git a/python/lsst/daf/butler/queries/expression_factory.py b/python/lsst/daf/butler/queries/expression_factory.py new file mode 100644 index 0000000000..e995f05ff7 --- /dev/null +++ b/python/lsst/daf/butler/queries/expression_factory.py @@ -0,0 +1,502 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("ExpressionFactory", "ExpressionProxy", "ScalarExpressionProxy", "TimespanProxy", "RegionProxy") + +from collections.abc import Iterable +from typing import TYPE_CHECKING, cast + +from lsst.sphgeom import Region + +from ..dimensions import Dimension, DimensionElement, DimensionUniverse +from . import tree + +if TYPE_CHECKING: + from .._timespan import Timespan + from ._query import Query + +# This module uses ExpressionProxy and its subclasses to wrap ColumnExpression, +# but it just returns OrderExpression and Predicate objects directly, because +# we don't need to overload any operators or define any methods on those. + + +class ExpressionProxy: + """A wrapper for column expressions that overloads comparison operators + to return new expression proxies. + + Parameters + ---------- + expression : `tree.ColumnExpression` + Underlying expression object. 
+ """ + + def __init__(self, expression: tree.ColumnExpression): + self._expression = expression + + def __repr__(self) -> str: + return str(self._expression) + + @property + def is_null(self) -> tree.Predicate: + """A boolean expression that tests whether this expression is NULL.""" + return tree.Predicate.is_null(self._expression) + + @staticmethod + def _make_expression(other: object) -> tree.ColumnExpression: + if isinstance(other, ExpressionProxy): + return other._expression + else: + return tree.make_column_literal(other) + + def _make_comparison(self, other: object, operator: tree.ComparisonOperator) -> tree.Predicate: + return tree.Predicate.compare(a=self._expression, b=self._make_expression(other), operator=operator) + + +class ScalarExpressionProxy(ExpressionProxy): + """An `ExpressionProxy` specialized for simple single-value columns.""" + + @property + def desc(self) -> tree.Reversed: + """An ordering expression that indicates that the sort on this + expression should be reversed. + """ + return tree.Reversed(operand=self._expression) + + def __eq__(self, other: object) -> tree.Predicate: # type: ignore[override] + return self._make_comparison(other, "==") + + def __ne__(self, other: object) -> tree.Predicate: # type: ignore[override] + return self._make_comparison(other, "!=") + + def __lt__(self, other: object) -> tree.Predicate: # type: ignore[override] + return self._make_comparison(other, "<") + + def __le__(self, other: object) -> tree.Predicate: # type: ignore[override] + return self._make_comparison(other, "<=") + + def __gt__(self, other: object) -> tree.Predicate: # type: ignore[override] + return self._make_comparison(other, ">") + + def __ge__(self, other: object) -> tree.Predicate: # type: ignore[override] + return self._make_comparison(other, ">=") + + def __neg__(self) -> ScalarExpressionProxy: + return ScalarExpressionProxy(tree.UnaryExpression(operand=self._expression, operator="-")) + + def __add__(self, other: object) -> ScalarExpressionProxy: + return ScalarExpressionProxy( + tree.BinaryExpression(a=self._expression, b=self._make_expression(other), operator="+") + ) + + def __radd__(self, other: object) -> ScalarExpressionProxy: + return ScalarExpressionProxy( + tree.BinaryExpression(a=self._make_expression(other), b=self._expression, operator="+") + ) + + def __sub__(self, other: object) -> ScalarExpressionProxy: + return ScalarExpressionProxy( + tree.BinaryExpression(a=self._expression, b=self._make_expression(other), operator="-") + ) + + def __rsub__(self, other: object) -> ScalarExpressionProxy: + return ScalarExpressionProxy( + tree.BinaryExpression(a=self._make_expression(other), b=self._expression, operator="-") + ) + + def __mul__(self, other: object) -> ScalarExpressionProxy: + return ScalarExpressionProxy( + tree.BinaryExpression(a=self._expression, b=self._make_expression(other), operator="*") + ) + + def __rmul__(self, other: object) -> ScalarExpressionProxy: + return ScalarExpressionProxy( + tree.BinaryExpression(a=self._make_expression(other), b=self._expression, operator="*") + ) + + def __truediv__(self, other: object) -> ScalarExpressionProxy: + return ScalarExpressionProxy( + tree.BinaryExpression(a=self._expression, b=self._make_expression(other), operator="/") + ) + + def __rtruediv__(self, other: object) -> ScalarExpressionProxy: + return ScalarExpressionProxy( + tree.BinaryExpression(a=self._make_expression(other), b=self._expression, operator="/") + ) + + def __mod__(self, other: object) -> ScalarExpressionProxy: + return 
ScalarExpressionProxy( + tree.BinaryExpression(a=self._expression, b=self._make_expression(other), operator="%") + ) + + def __rmod__(self, other: object) -> ScalarExpressionProxy: + return ScalarExpressionProxy( + tree.BinaryExpression(a=self._make_expression(other), b=self._expression, operator="%") + ) + + def in_range(self, start: int = 0, stop: int | None = None, step: int = 1) -> tree.Predicate: + """Return a boolean expression that tests whether this expression is + within a literal integer range. + + Parameters + ---------- + start : `int`, optional + Lower bound (inclusive) for the slice. + stop : `int` or `None`, optional + Upper bound (exclusive) for the slice, or `None` for no bound. + step : `int`, optional + Spacing between integers in the range. + + Returns + ------- + predicate : `tree.Predicate` + Boolean expression object. + """ + return tree.Predicate.in_range(self._expression, start=start, stop=stop, step=step) + + def in_iterable(self, others: Iterable) -> tree.Predicate: + """Return a boolean expression that tests whether this expression + evaluates to a value that is in an iterable of other expressions. + + Parameters + ---------- + others : `collections.abc.Iterable` + An iterable of `ExpressionProxy` or values to be interpreted as + literals. + + Returns + ------- + predicate : `tree.Predicate` + Boolean expression object. + """ + return tree.Predicate.in_container(self._expression, [self._make_expression(item) for item in others]) + + def in_query(self, column: ExpressionProxy, query: Query) -> tree.Predicate: + """Return a boolean expression that test whether this expression + evaluates to a value that is in a single-column selection from another + query. + + Parameters + ---------- + column : `ExpressionProxy` + Proxy for the column to extract from ``query``. + query : `Query` + Query to select from. + + Returns + ------- + predicate : `tree.Predicate` + Boolean expression object. + """ + return tree.Predicate.in_query(self._expression, column._expression, query._tree) + + +class TimespanProxy(ExpressionProxy): + """An `ExpressionProxy` specialized for timespan columns and literals.""" + + @property + def begin(self) -> ExpressionProxy: + """An expression representing the lower bound (inclusive).""" + return ExpressionProxy(tree.UnaryExpression(operand=self._expression, operator="begin_of")) + + @property + def end(self) -> ExpressionProxy: + """An expression representing the upper bound (exclusive).""" + return ExpressionProxy(tree.UnaryExpression(operand=self._expression, operator="end_of")) + + def overlaps(self, other: TimespanProxy | Timespan) -> tree.Predicate: + """Return a boolean expression representing an overlap test between + this timespan and another. + + Parameters + ---------- + other : `TimespanProxy` or `Timespan` + Expression or literal to compare to. + + Returns + ------- + predicate : `tree.Predicate` + Boolean expression object. + """ + return self._make_comparison(other, "overlaps") + + +class RegionProxy(ExpressionProxy): + """An `ExpressionProxy` specialized for region columns and literals.""" + + def overlaps(self, other: RegionProxy | Region) -> tree.Predicate: + """Return a boolean expression representing an overlap test between + this region and another. + + Parameters + ---------- + other : `RegionProxy` or `Region` + Expression or literal to compare to. + + Returns + ------- + predicate : `tree.Predicate` + Boolean expression object. 
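# --- Example (illustrative sketch) -----------------------------------------
# Building predicates from proxies obtained via the expression factory;
# dimension and field names are placeholders, and `a_timespan` stands for a
# `Timespan` literal.
x = query.expression_factory
query = query.where(
    x.detector.in_range(0, 100),
    x.band.in_iterable(["g", "r", "i"]),
    x.exposure.exposure_time > 30.0,
    x.visit.timespan.overlaps(a_timespan),
)
# ---------------------------------------------------------------------------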
+ """ + return self._make_comparison(other, "overlaps") + + +class DimensionElementProxy: + """An expression-creation proxy for a dimension element logical table. + + Parameters + ---------- + element : `DimensionElement` + Element this object wraps. + + Notes + ----- + The (dynamic) attributes of this object are expression proxies for the + non-dimension fields of the element's records. + """ + + def __init__(self, element: DimensionElement): + self._element = element + + def __repr__(self) -> str: + return self._element.name + + def __getattr__(self, field: str) -> ExpressionProxy: + if field in self._element.schema.dimensions.names: + return DimensionProxy(self._element.dimensions[field]) + try: + expression = tree.DimensionFieldReference(element=self._element, field=field) + except tree.InvalidQueryError: + raise AttributeError(field) + match expression.column_type: + case "region": + return RegionProxy(expression) + case "timespan": + return TimespanProxy(expression) + return ScalarExpressionProxy(expression) + + def __dir__(self) -> list[str]: + result = list(super().__dir__()) + result.extend(self._element.schema.names) + return result + + +class DimensionProxy(ScalarExpressionProxy, DimensionElementProxy): + """An expression-creation proxy for a dimension logical table. + + Parameters + ---------- + dimension : `Dimension` + Dimension this object wraps. + + Notes + ----- + This class combines record-field attribute access from `DimensionElement` + proxy with direct interpretation as a dimension key column via + `ScalarExpressionProxy`. For example:: + + x = query.expression_factory + query.where( + x.detector.purpose == "SCIENCE", # field access + x.detector > 100, # direct usage as an expression + ) + """ + + def __init__(self, dimension: Dimension): + ScalarExpressionProxy.__init__(self, tree.DimensionKeyReference(dimension=dimension)) + DimensionElementProxy.__init__(self, dimension) + + def __getattr__(self, field: str) -> ExpressionProxy: + if field == self._element.primary_key.name: + return self + return super().__getattr__(field) + + _element: Dimension + + +class DatasetTypeProxy: + """An expression-creation proxy for a dataset type's logical table. + + Parameters + ---------- + dataset_type : `str` + Dataset type name or wildcard. Wildcards are usable only when the + query contains exactly one dataset type or a wildcard. + + Notes + ----- + The attributes of this object are expression proxies for the fields + associated with datasets. + """ + + def __init__(self, dataset_type: str): + self._dataset_type = dataset_type + + def __repr__(self) -> str: + return self._dataset_type + + # Attributes are actually fixed, but we implement them with __getattr__ + # and __dir__ to avoid repeating the list. And someday they might expand + # to include Datastore record fields. + + def __getattr__(self, field: str) -> ExpressionProxy: + if field not in tree.DATASET_FIELD_NAMES: + raise AttributeError(field) + expression = tree.DatasetFieldReference(dataset_type=self._dataset_type, field=field) + match expression.column_type: + case "timespan": + return TimespanProxy(expression) + return ScalarExpressionProxy(expression) + + def __dir__(self) -> list[str]: + result = list(super().__dir__()) + result.extend(tree.DATASET_FIELD_NAMES) + return result + + +class ExpressionFactory: + """A factory for creating column expressions that uses operator overloading + to form a mini-language. 
+ + Instances of this class are usually obtained from + `Query.expression_factory`; see that property's documentation for more + information. + + Parameters + ---------- + universe : `DimensionUniverse` + Object that describes all dimensions. + """ + + def __init__(self, universe: DimensionUniverse): + self._universe = universe + + def __getattr__(self, name: str) -> DimensionElementProxy: + element = self._universe.elements[name] + if element in self._universe.dimensions: + return DimensionProxy(cast(Dimension, element)) + return DimensionElementProxy(element) + + def __getitem__(self, name: str) -> DatasetTypeProxy: + return DatasetTypeProxy(name) + + def not_(self, operand: tree.Predicate) -> tree.Predicate: + """Apply a logical NOT operation to a boolean expression. + + Parameters + ---------- + operand : `tree.Predicate` + Expression to invert. + + Returns + ------- + logical_not : `tree.Predicate` + A boolean expression that evaluates to the opposite of ``operand``. + """ + return operand.logical_not() + + def all(self, first: tree.Predicate, /, *args: tree.Predicate) -> tree.Predicate: + """Combine a sequence of boolean expressions with logical AND. + + Parameters + ---------- + first : `tree.Predicate` + First operand (required). + *args + Additional operands. + + Returns + ------- + logical_and : `tree.Predicate` + A boolean expression that evaluates to `True` only if all operands + evaluate to `True`. + """ + return first.logical_and(*args) + + def any(self, first: tree.Predicate, /, *args: tree.Predicate) -> tree.Predicate: + """Combine a sequence of boolean expressions with logical OR. + + Parameters + ---------- + first : `tree.Predicate` + First operand (required). + *args + Additional operands. + + Returns + ------- + logical_or : `tree.Predicate` + A boolean expression that evaluates to `True` if any operand + evaluates to `True`. + """ + return first.logical_or(*args) + + @staticmethod + def literal(value: object) -> ExpressionProxy: + """Return an expression proxy that represents a literal value. + + Expression proxy objects obtained from this factory can generally be + compared directly to literals, so calling this method directly in user + code should rarely be necessary. + + Parameters + ---------- + value : `object` + Value to include as a literal in an expression tree. + + Returns + ------- + expression : `ExpressionProxy` + Expression wrapper for this literal. + """ + expression = tree.make_column_literal(value) + match expression.expression_type: + case "timespan": + return TimespanProxy(expression) + case "region": + return RegionProxy(expression) + case "bool": + raise NotImplementedError("Boolean literals are not supported.") + case _: + return ScalarExpressionProxy(expression) + + @staticmethod + def unwrap(proxy: ExpressionProxy) -> tree.ColumnExpression: + """Return the column expression object that backs a proxy. + + Parameters + ---------- + proxy : `ExpressionProxy` + Proxy constructed via an `ExpressionFactory`. + + Returns + ------- + expression : `tree.ColumnExpression` + Underlying column expression object. + """ + return proxy._expression diff --git a/python/lsst/daf/butler/queries/overlaps.py b/python/lsst/daf/butler/queries/overlaps.py new file mode 100644 index 0000000000..2e68acd1b0 --- /dev/null +++ b/python/lsst/daf/butler/queries/overlaps.py @@ -0,0 +1,466 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org).
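# --- Example (illustrative sketch) -----------------------------------------
# Explicit boolean composition with the factory's helpers; the field names
# are placeholders.
x = query.expression_factory
is_science = x.exposure.observation_type == "science"
is_long = x.exposure.exposure_time >= 30.0
query = query.where(x.any(x.all(is_science, is_long), x.not_(is_science)))
# ---------------------------------------------------------------------------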
+# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("OverlapsVisitor",) + +import itertools +from collections.abc import Hashable, Iterable, Sequence, Set +from typing import Generic, Literal, TypeVar, cast + +from lsst.sphgeom import Region + +from .._topology import TopologicalFamily +from ..dimensions import DimensionElement, DimensionGroup +from . import tree +from .visitors import PredicateVisitFlags, SimplePredicateVisitor + +_T = TypeVar("_T", bound=Hashable) + + +class _NaiveDisjointSet(Generic[_T]): + """A very naive (but simple) implementation of a "disjoint set" data + structure for strings, with mostly O(N) performance. + + This class should not be used in any context where the number of elements + in the data structure is large. It intentionally implements a subset of + the interface of `scipy.cluster.DisJointSet` so that non-naive + implementation could be swapped in if desired. + + Parameters + ---------- + superset : `~collections.abc.Iterable` [ `str` ] + Elements to initialize the disjoint set, with each in its own + single-element subset. + """ + + def __init__(self, superset: Iterable[_T]): + self._subsets = [{k} for k in superset] + self._subsets.sort(key=len, reverse=True) + + def add(self, k: _T) -> bool: # numpydoc ignore=PR04 + """Add a new element as its own single-element subset unless it is + already present. + + Parameters + ---------- + k + Value to add. + + Returns + ------- + added : `bool`: + `True` if the value was actually added, `False` if it was already + present. + """ + for subset in self._subsets: + if k in subset: + return False + self._subsets.append({k}) + return True + + def merge(self, a: _T, b: _T) -> bool: # numpydoc ignore=PR04 + """Merge the subsets containing the given elements. + + Parameters + ---------- + a : + Element whose subset should be merged. + b : + Element whose subset should be merged. + + Returns + ------- + merged : `bool` + `True` if a merge occurred, `False` if the elements were already in + the same subset. 
+ """ + for i, subset in enumerate(self._subsets): + if a in subset: + break + else: + raise KeyError(f"Merge argument {a!r} not in disjoin set {self._subsets}.") + for j, subset in enumerate(self._subsets): + if b in subset: + break + else: + raise KeyError(f"Merge argument {b!r} not in disjoin set {self._subsets}.") + if i == j: + return False + i, j = sorted((i, j)) + self._subsets[i].update(self._subsets[j]) + del self._subsets[j] + self._subsets.sort(key=len, reverse=True) + return True + + def subsets(self) -> Sequence[Set[_T]]: + """Return the current subsets, ordered from largest to smallest.""" + return self._subsets + + @property + def n_subsets(self) -> int: + """The number of subsets.""" + return len(self._subsets) + + +class OverlapsVisitor(SimplePredicateVisitor): + """A helper class for dealing with spatial and temporal overlaps in a + query. + + Parameters + ---------- + dimensions : `DimensionGroup` + Dimensions of the query. + + Notes + ----- + This class includes logic for extracting explicit spatial and temporal + joins from a WHERE-clause predicate and computing automatic joins given the + dimensions of the query. It is designed to be subclassed by query driver + implementations that want to rewrite the predicate at the same time. + """ + + def __init__(self, dimensions: DimensionGroup): + self.dimensions = dimensions + self._spatial_connections = _NaiveDisjointSet(self.dimensions.spatial) + self._temporal_connections = _NaiveDisjointSet(self.dimensions.temporal) + + def run(self, predicate: tree.Predicate, join_operands: Iterable[DimensionGroup]) -> tree.Predicate: + """Process the given predicate to extract spatial and temporal + overlaps. + + Parameters + ---------- + predicate : `tree.Predicate` + Predicate to process. + join_operands : `~collections.abc.Iterable` [ `DimensionGroup` ] + The dimensions of logical tables being joined into this query; + these can included embedded spatial and temporal joins that can + make it unnecessary to add new ones. + + Returns + ------- + predicate : `tree.Predicate` + A possibly-modified predicate that should replace the original. + """ + result = predicate.visit(self) + if result is None: + result = predicate + for join_operand_dimensions in join_operands: + self.add_join_operand_connections(join_operand_dimensions) + for a, b in self.compute_automatic_spatial_joins(): + join_predicate = self.visit_spatial_join(a, b, PredicateVisitFlags.HAS_AND_SIBLINGS) + if join_predicate is None: + join_predicate = tree.Predicate.compare( + tree.DimensionFieldReference.model_construct(element=a, field="region"), + "overlaps", + tree.DimensionFieldReference.model_construct(element=b, field="region"), + ) + result = result.logical_and(join_predicate) + for a, b in self.compute_automatic_temporal_joins(): + join_predicate = self.visit_temporal_dimension_join(a, b, PredicateVisitFlags.HAS_AND_SIBLINGS) + if join_predicate is None: + join_predicate = tree.Predicate.compare( + tree.DimensionFieldReference.model_construct(element=a, field="timespan"), + "overlaps", + tree.DimensionFieldReference.model_construct(element=b, field="timespan"), + ) + result = result.logical_and(join_predicate) + return result + + def visit_comparison( + self, + a: tree.ColumnExpression, + operator: tree.ComparisonOperator, + b: tree.ColumnExpression, + flags: PredicateVisitFlags, + ) -> tree.Predicate | None: + # Docstring inherited. 
+ if operator == "overlaps": + if a.column_type == "region": + return self.visit_spatial_overlap(a, b, flags) + elif b.column_type == "timespan": + return self.visit_temporal_overlap(a, b, flags) + else: + raise AssertionError(f"Unexpected column type {a.column_type} for overlap.") + return None + + def add_join_operand_connections(self, operand_dimensions: DimensionGroup) -> None: + """Add overlap connections implied by a table or subquery. + + Parameters + ---------- + operand_dimensions : `DimensionGroup` + Dimensions of of the table or subquery. + + Notes + ----- + We assume each join operand to a `tree.Select` has its own + complete set of spatial and temporal joins that went into generating + its rows. That will naturally be true for relations originating from + the butler database, like dataset searches and materializations, and if + it isn't true for a data ID upload, that would represent an intentional + association between non-overlapping things that we'd want to respect by + *not* adding a more restrictive automatic join. + """ + for a_family, b_family in itertools.pairwise(operand_dimensions.spatial): + self._spatial_connections.merge(a_family, b_family) + for a_family, b_family in itertools.pairwise(operand_dimensions.temporal): + self._temporal_connections.merge(a_family, b_family) + + def compute_automatic_spatial_joins(self) -> list[tuple[DimensionElement, DimensionElement]]: + """Return pairs of dimension elements that should be spatially joined. + + Returns + ------- + joins : `list` [ `tuple` [ `DimensionElement`, `DimensionElement` ] ] + Automatic joins. + + Notes + ----- + This takes into account explicit joins extracted by `run` and implicit + joins added by `add_join_operand_connections`, and only returns + additional joins if there is an unambiguous way to spatially connect + any dimensions that are not already spatially connected. Automatic + joins are always the most fine-grained join between sets of dimensions + (i.e. ``visit_detector_region`` and ``patch`` instead of ``visit`` and + ``tract``), but explicitly adding a coarser join between sets of + elements will prevent the fine-grained join from being added. + """ + return self._compute_automatic_joins("spatial", self._spatial_connections) + + def compute_automatic_temporal_joins(self) -> list[tuple[DimensionElement, DimensionElement]]: + """Return pairs of dimension elements that should be spatially joined. + + Returns + ------- + joins : `list` [ `tuple` [ `DimensionElement`, `DimensionElement` ] ] + Automatic joins. + + Notes + ----- + See `compute_automatic_spatial_joins` for information on how automatic + joins are determined. Joins to dataset validity ranges are never + automatic. + """ + return self._compute_automatic_joins("temporal", self._temporal_connections) + + def _compute_automatic_joins( + self, kind: Literal["spatial", "temporal"], connections: _NaiveDisjointSet[TopologicalFamily] + ) -> list[tuple[DimensionElement, DimensionElement]]: + if connections.n_subsets <= 1: + # All of the joins we need are already present. + return [] + if connections.n_subsets > 2: + raise tree.InvalidQueryError( + f"Too many disconnected sets of {kind} families for an automatic " + f"join: {connections.subsets()}. Add explicit {kind} joins to avoid this error." + ) + a_subset, b_subset = connections.subsets() + if len(a_subset) > 1 or len(b_subset) > 1: + raise tree.InvalidQueryError( + f"A {kind} join is needed between {a_subset} and {b_subset}, but which join to " + "add is ambiguous. 
Add an explicit spatial join to avoid this error." + ) + # We have a pair of families that are not explicitly or implicitly + # connected to any other families; add an automatic join between their + # most fine-grained members. + (a_family,) = a_subset + (b_family,) = b_subset + return [ + ( + cast(DimensionElement, a_family.choose(self.dimensions.elements, self.dimensions.universe)), + cast(DimensionElement, b_family.choose(self.dimensions.elements, self.dimensions.universe)), + ) + ] + + def visit_spatial_overlap( + self, a: tree.ColumnExpression, b: tree.ColumnExpression, flags: PredicateVisitFlags + ) -> tree.Predicate | None: + """Dispatch a spatial overlap comparison predicate to handlers. + + This method should rarely (if ever) need to be overridden. + + Parameters + ---------- + a : `tree.ColumnExpression` + First operand. + b : `tree.ColumnExpression` + Second operand. + flags : `tree.PredicateLeafFlags` + Information about where this overlap comparison appears in the + larger predicate tree. + + Returns + ------- + replaced : `tree.Predicate` or `None` + The predicate to be inserted instead in the processed tree, or + `None` if no substitution is needed. + """ + match a, b: + case tree.DimensionFieldReference(element=a_element), tree.DimensionFieldReference( + element=b_element + ): + return self.visit_spatial_join(a_element, b_element, flags) + case tree.DimensionFieldReference(element=element), region_expression: + pass + case region_expression, tree.DimensionFieldReference(element=element): + pass + case _: + raise AssertionError(f"Unexpected arguments for spatial overlap: {a}, {b}.") + if region := region_expression.get_literal_value(): + return self.visit_spatial_constraint(element, region, flags) + raise AssertionError(f"Unexpected argument for spatial overlap: {region_expression}.") + + def visit_temporal_overlap( + self, a: tree.ColumnExpression, b: tree.ColumnExpression, flags: PredicateVisitFlags + ) -> tree.Predicate | None: + """Dispatch a temporal overlap comparison predicate to handlers. + + This method should rarely (if ever) need to be overridden. + + Parameters + ---------- + a : `tree.ColumnExpression`- + First operand. + b : `tree.ColumnExpression` + Second operand. + flags : `tree.PredicateLeafFlags` + Information about where this overlap comparison appears in the + larger predicate tree. + + Returns + ------- + replaced : `tree.Predicate` or `None` + The predicate to be inserted instead in the processed tree, or + `None` if no substitution is needed. + """ + match a, b: + case tree.DimensionFieldReference(element=a_element), tree.DimensionFieldReference( + element=b_element + ): + return self.visit_temporal_dimension_join(a_element, b_element, flags) + case _: + # We don't bother differentiating any other kind of temporal + # comparison, because in all foreseeable database schemas we + # wouldn't have to do anything special with them, since they + # don't participate in automatic join calculations and they + # should be straightforwardly convertible to SQL. + return None + + def visit_spatial_join( + self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags + ) -> tree.Predicate | None: + """Handle a spatial overlap comparison between two dimension elements. + + The default implementation updates the set of known spatial connections + (for use by `compute_automatic_spatial_joins`) and returns `None`. + + Parameters + ---------- + a : `DimensionElement` + One element in the join. + b : `DimensionElement` + The other element in the join. 
+ flags : `tree.PredicateLeafFlags` + Information about where this overlap comparison appears in the + larger predicate tree. + + Returns + ------- + replaced : `tree.Predicate` or `None` + The predicate to be inserted instead in the processed tree, or + `None` if no substitution is needed. + """ + if a.spatial == b.spatial: + raise tree.InvalidQueryError(f"Spatial join between {a} and {b} is not necessary.") + self._spatial_connections.merge( + cast(TopologicalFamily, a.spatial), cast(TopologicalFamily, b.spatial) + ) + return None + + def visit_spatial_constraint( + self, + element: DimensionElement, + region: Region, + flags: PredicateVisitFlags, + ) -> tree.Predicate | None: + """Handle a spatial overlap comparison between a dimension element and + a literal region. + + The default implementation just returns `None`. + + Parameters + ---------- + element : `DimensionElement` + The dimension element in the comparison. + region : `lsst.sphgeom.Region` + The literal region in the comparison. + flags : `tree.PredicateLeafFlags` + Information about where this overlap comparison appears in the + larger predicate tree. + + Returns + ------- + replaced : `tree.Predicate` or `None` + The predicate to be inserted instead in the processed tree, or + `None` if no substitution is needed. + """ + return None + + def visit_temporal_dimension_join( + self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags + ) -> tree.Predicate | None: + """Handle a temporal overlap comparison between two dimension elements. + + The default implementation updates the set of known temporal + connections (for use by `compute_automatic_temporal_joins`) and returns + `None`. + + Parameters + ---------- + a : `DimensionElement` + One element in the join. + b : `DimensionElement` + The other element in the join. + flags : `tree.PredicateLeafFlags` + Information about where this overlap comparison appears in the + larger predicate tree. + + Returns + ------- + replaced : `tree.Predicate` or `None` + The predicate to be inserted instead in the processed tree, or + `None` if no substitution is needed. + """ + if a.temporal == b.temporal: + raise tree.InvalidQueryError(f"Temporal join between {a} and {b} is not necessary.") + self._temporal_connections.merge( + cast(TopologicalFamily, a.temporal), cast(TopologicalFamily, b.temporal) + ) + return None diff --git a/python/lsst/daf/butler/queries/result_specs.py b/python/lsst/daf/butler/queries/result_specs.py new file mode 100644 index 0000000000..d452c221a9 --- /dev/null +++ b/python/lsst/daf/butler/queries/result_specs.py @@ -0,0 +1,262 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
+# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ( + "ResultSpecBase", + "DataCoordinateResultSpec", + "DimensionRecordResultSpec", + "DatasetRefResultSpec", +) + +from abc import ABC, abstractmethod +from collections.abc import Mapping +from typing import Annotated, Literal, TypeAlias, Union, cast + +import pydantic + +from ..dimensions import DimensionElement, DimensionGroup +from .tree import ColumnSet, DatasetFieldName, InvalidQueryError, OrderExpression, QueryTree + + +class ResultSpecBase(pydantic.BaseModel, ABC): + """Base class for all query-result specification objects. + + A result specification is a struct that is combined with a `QueryTree` to + represent a serializable query-results object. + """ + + result_type: str + """String literal that corresponds to a concrete derived type.""" + + order_by: tuple[OrderExpression, ...] = () + """Expressions to sort the rows by.""" + + offset: int = 0 + """Index of the first row to return.""" + + limit: int | None = None + """Maximum number of rows to return, or `None` for no bound.""" + + def validate_tree(self, tree: QueryTree) -> None: + """Check that this result object is consistent with a query tree. + + Parameters + ---------- + tree : `QueryTree` + Query tree that defines the joins and row-filtering that these + results will come from. + """ + spec = cast(ResultSpec, self) + if not spec.dimensions <= tree.dimensions: + raise InvalidQueryError( + f"Query result specification has dimensions {spec.dimensions} that are not a subset of the " + f"query's dimensions {tree.dimensions}." + ) + result_columns = spec.get_result_columns() + assert result_columns.dimensions == spec.dimensions, "enforced by ResultSpec implementations" + for dataset_type in result_columns.dataset_fields: + if dataset_type not in tree.datasets: + raise InvalidQueryError(f"Dataset {dataset_type!r} is not available from this query.") + order_by_columns = ColumnSet(spec.dimensions) + for term in spec.order_by: + term.gather_required_columns(order_by_columns) + if not (order_by_columns.dimensions <= spec.dimensions): + raise InvalidQueryError( + "Order-by expression may not reference columns that are not in the result dimensions." + ) + for dataset_type in order_by_columns.dataset_fields.keys(): + if dataset_type not in tree.datasets: + raise InvalidQueryError( + f"Dataset type {dataset_type!r} in order-by expression is not part of the query." + ) + + @property + def find_first_dataset(self) -> str | None: + """The dataset type for which find-first resolution is required, if + any. + """ + return None + + @abstractmethod + def get_result_columns(self) -> ColumnSet: + """Return the columns included in the actual result rows. + + This does not necessarily include all columns required by the + `order_by` terms that are also a part of this spec. 
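+
+        Returns
+        -------
+        columns : `ColumnSet`
+            The columns included in the result rows.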
+ """ + raise NotImplementedError() + + +class DataCoordinateResultSpec(ResultSpecBase): + """Specification for a query that yields `DataCoordinate` objects.""" + + result_type: Literal["data_coordinate"] = "data_coordinate" + + dimensions: DimensionGroup + """The dimensions of the data IDs returned by this query.""" + + include_dimension_records: bool = False + """Whether the returned data IDs include dimension records.""" + + def get_result_columns(self) -> ColumnSet: + # Docstring inherited. + result = ColumnSet(self.dimensions) + if self.include_dimension_records: + for element_name in self.dimensions.elements: + element = self.dimensions.universe[element_name] + if not element.is_cached: + result.dimension_fields[element_name].update(element.schema.remainder.names) + return result + + +class DimensionRecordResultSpec(ResultSpecBase): + """Specification for a query that yields `DimensionRecord` objects.""" + + result_type: Literal["dimension_record"] = "dimension_record" + + element: DimensionElement + """The name and definition of the dimension records returned by this query. + """ + + @property + def dimensions(self) -> DimensionGroup: + """The dimensions that are required or implied (directly or indirectly) + by this dimension element. + """ + return self.element.minimal_group + + def get_result_columns(self) -> ColumnSet: + # Docstring inherited. + result = ColumnSet(self.element.minimal_group) + result.dimension_fields[self.element.name].update(self.element.schema.remainder.names) + result.drop_dimension_keys(self.element.minimal_group.names - self.element.dimensions.names) + return result + + +class DatasetRefResultSpec(ResultSpecBase): + """Specification for a query that yields `DatasetRef` objects.""" + + result_type: Literal["dataset_ref"] = "dataset_ref" + + dataset_type_name: str + """The dataset type name of the datasets returned by this query.""" + + dimensions: DimensionGroup + """The dimensions of the datasets returned by this query.""" + + storage_class_name: str + """The name of the storage class of the datasets returned by this query.""" + + include_dimension_records: bool = False + """Whether the data IDs returned by this query include dimension records. + """ + + find_first: bool + """Whether this query should resolve data ID duplicates according to the + order of the collections to be searched. + """ + + @property + def find_first_dataset(self) -> str | None: + # Docstring inherited. + return self.dataset_type_name if self.find_first else None + + def get_result_columns(self) -> ColumnSet: + # Docstring inherited. + result = ColumnSet(self.dimensions) + result.dataset_fields[self.dataset_type_name].update({"dataset_id", "run"}) + if self.include_dimension_records: + for element_name in self.dimensions.elements: + element = self.dimensions.universe[element_name] + if not element.is_cached: + result.dimension_fields[element_name].update(element.schema.remainder.names) + return result + + +class GeneralResultSpec(ResultSpecBase): + """Specification for a query that yields a table with + an explicit list of columns. + """ + + result_type: Literal["general"] = "general" + + dimensions: DimensionGroup + """The dimensions that span all fields returned by this query.""" + + dimension_fields: Mapping[str, set[str]] + """Dimension record fields included in this query.""" + + dataset_fields: Mapping[str, set[DatasetFieldName]] + """Dataset fields included in this query.""" + + find_first: bool + """Whether this query requires find-first resolution for a dataset. 
+ + This can only be `True` if exactly one dataset type's fields are included + in the results. + """ + + @property + def find_first_dataset(self) -> str | None: + # Docstring inherited. + if self.find_first: + (dataset_type,) = self.dataset_fields.keys() + return dataset_type + return None + + def get_result_columns(self) -> ColumnSet: + # Docstring inherited. + result = ColumnSet(self.dimensions) + for element_name, fields_for_element in self.dimension_fields.items(): + result.dimension_fields[element_name].update(fields_for_element) + for dataset_type, fields_for_dataset in self.dataset_fields.items(): + result.dataset_fields[dataset_type].update(fields_for_dataset) + return result + + @pydantic.model_validator(mode="after") + def _validate(self) -> GeneralResultSpec: + if self.find_first and len(self.dataset_fields) != 1: + raise InvalidQueryError("find_first=True requires exactly one result dataset type.") + for element_name, fields_for_element in self.dimension_fields.items(): + if element_name not in self.dimensions.elements: + raise InvalidQueryError(f"Dimension element {element_name} is not in {self.dimensions}.") + if not fields_for_element: + raise InvalidQueryError( + f"Empty dimension element field set for {element_name!r} is not permitted." + ) + for dataset_type, fields_for_dataset in self.dataset_fields.items(): + if not fields_for_dataset: + raise InvalidQueryError(f"Empty dataset field set for {dataset_type!r} is not permitted.") + return self + + +ResultSpec: TypeAlias = Annotated[ + Union[DataCoordinateResultSpec, DimensionRecordResultSpec, DatasetRefResultSpec, GeneralResultSpec], + pydantic.Field(discriminator="result_type"), +] diff --git a/python/lsst/daf/butler/queries/tree/__init__.py b/python/lsst/daf/butler/queries/tree/__init__.py new file mode 100644 index 0000000000..e320695f62 --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/__init__.py @@ -0,0 +1,40 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
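+
+# The wildcard imports below assemble the serializable query-tree primitives
+# into a single public namespace.  The model_rebuild() calls at the end of
+# this module resolve the forward references in the recursive Predicate and
+# LogicalNot Pydantic models once all of their component types have been
+# imported.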
+ +from ._base import * +from ._column_expression import * +from ._column_literal import * +from ._column_reference import * +from ._column_set import * +from ._predicate import * +from ._predicate import LogicalNot +from ._query_tree import * + +LogicalNot.model_rebuild() +del LogicalNot + +Predicate.model_rebuild() diff --git a/python/lsst/daf/butler/queries/tree/_base.py b/python/lsst/daf/butler/queries/tree/_base.py new file mode 100644 index 0000000000..1cd97d096c --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/_base.py @@ -0,0 +1,188 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ( + "QueryTreeBase", + "ColumnExpressionBase", + "DatasetFieldName", + "InvalidQueryError", + "DATASET_FIELD_NAMES", +) + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeAlias, TypeVar, cast, get_args + +import pydantic + +from ...column_spec import ColumnType + +if TYPE_CHECKING: + from ..visitors import ColumnExpressionVisitor + from ._column_literal import ColumnLiteral + from ._column_set import ColumnSet + + +# Type annotation for string literals that can be used as dataset fields in +# the public API. The 'collection' and 'run' fields are string collection +# names. Internal interfaces may define other dataset field strings (e.g. +# collection primary key values) and hence should use `str` rather than this +# type. +DatasetFieldName: TypeAlias = Literal["dataset_id", "ingest_date", "run", "collection", "timespan"] + +# Tuple of the strings that can be use as dataset fields in public APIs. +DATASET_FIELD_NAMES: tuple[DatasetFieldName, ...] = tuple(get_args(DatasetFieldName)) + +_T = TypeVar("_T") +_L = TypeVar("_L") +_A = TypeVar("_A") +_O = TypeVar("_O") + + +class InvalidQueryError(RuntimeError): + """Exception raised when a query is not valid.""" + + +class QueryTreeBase(pydantic.BaseModel): + """Base class for all non-primitive types in a query tree.""" + + model_config = pydantic.ConfigDict(frozen=True, extra="forbid", strict=True) + + +class ColumnExpressionBase(QueryTreeBase, ABC): + """Base class for objects that represent non-boolean column expressions in + a query tree. 
+ + Notes + ----- + This is a closed hierarchy whose concrete, `~typing.final` derived classes + are members of the `ColumnExpression` union. That union should be used in + type annotations rather than the technically-open base class. + """ + + expression_type: str + """String literal corresponding to a concrete expression type.""" + + is_literal: ClassVar[bool] = False + """Whether this expression wraps a literal Python value.""" + + @property + @abstractmethod + def precedence(self) -> int: + """Operator precedence for this operation. + + Lower values bind more tightly, so parentheses are needed when printing + an expression where an operand has a higher value than the expression + itself. + """ + raise NotImplementedError() + + @property + @abstractmethod + def column_type(self) -> ColumnType: + """A string enumeration value representing the type of the column + expression. + """ + raise NotImplementedError() + + def get_literal_value(self) -> Any | None: + """Return the literal value wrapped by this expression, or `None` if + it is not a literal. + """ + return None + + @abstractmethod + def gather_required_columns(self, columns: ColumnSet) -> None: + """Add any columns required to evaluate this expression to the + given column set. + + Parameters + ---------- + columns : `ColumnSet` + Set of columns to modify in place. + """ + raise NotImplementedError() + + @abstractmethod + def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T: + """Invoke the visitor interface. + + Parameters + ---------- + visitor : `ColumnExpressionVisitor` + Visitor to invoke a method on. + + Returns + ------- + result : `object` + Forwarded result from the visitor. + """ + raise NotImplementedError() + + +class ColumnLiteralBase(ColumnExpressionBase): + """Base class for objects that represent literal values as column + expressions in a query tree. + + Notes + ----- + This is a closed hierarchy whose concrete, `~typing.final` derived classes + are members of the `ColumnLiteral` union. That union should be used in + type annotations rather than the technically-open base class. The concrete + members of that union are only semi-public; they appear in the serialized + form of a column expression tree, but should only be constructed via the + `make_column_literal` factory function. All concrete members of the union + are also guaranteed to have a read-only ``value`` attribute holding the + wrapped literal, but it is unspecified whether that is a regular attribute + or a `property` (and hence cannot be type-annotated). + """ + + is_literal: ClassVar[bool] = True + """Whether this expression wraps a literal Python value.""" + + @property + def precedence(self) -> int: + # Docstring inherited. + return 0 + + def get_literal_value(self) -> Any: + # Docstring inherited. + return cast("ColumnLiteral", self).value + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + pass + + @property + def column_type(self) -> ColumnType: + # Docstring inherited. + return cast(ColumnType, self.expression_type) + + def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T: + # Docstring inherited + return visitor.visit_literal(cast("ColumnLiteral", self)) diff --git a/python/lsst/daf/butler/queries/tree/_column_expression.py b/python/lsst/daf/butler/queries/tree/_column_expression.py new file mode 100644 index 0000000000..44085c58b5 --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/_column_expression.py @@ -0,0 +1,278 @@ +# This file is part of daf_butler. 
+# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ( + "ColumnExpression", + "OrderExpression", + "UnaryExpression", + "BinaryExpression", + "Reversed", + "UnaryOperator", + "BinaryOperator", + "validate_order_expression", +) + +from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias, TypeVar, Union, final + +import pydantic + +from ...column_spec import ColumnType +from ._base import ColumnExpressionBase, InvalidQueryError +from ._column_literal import ColumnLiteral +from ._column_reference import ColumnReference +from ._column_set import ColumnSet + +if TYPE_CHECKING: + from ..visitors import ColumnExpressionVisitor + + +_T = TypeVar("_T") + + +UnaryOperator: TypeAlias = Literal["-", "begin_of", "end_of"] +BinaryOperator: TypeAlias = Literal["+", "-", "*", "/", "%"] + + +@final +class UnaryExpression(ColumnExpressionBase): + """A unary operation on a column expression that returns a non-boolean + value. + """ + + expression_type: Literal["unary"] = "unary" + + operand: ColumnExpression + """Expression this one operates on.""" + + operator: UnaryOperator + """Operator this expression applies.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + self.operand.gather_required_columns(columns) + + @property + def precedence(self) -> int: + # Docstring inherited. + return 1 + + @property + def column_type(self) -> ColumnType: + # Docstring inherited. + match self.operator: + case "-": + return self.operand.column_type + case "begin_of" | "end_of": + return "datetime" + raise AssertionError(f"Invalid unary expression operator {self.operator}.") + + def __str__(self) -> str: + s = str(self.operand) + match self.operator: + case "-": + return f"-{s}" + case "begin_of": + return f"{s}.begin" + case "end_of": + return f"{s}.end" + + @pydantic.model_validator(mode="after") + def _validate_types(self) -> UnaryExpression: + match (self.operator, self.operand.column_type): + case ("-", "int" | "float"): + pass + case ("begin_of" | "end_of", "timespan"): + pass + case _: + raise InvalidQueryError( + f"Invalid column type {self.operand.column_type} for operator {self.operator!r}." + ) + return self + + def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T: + # Docstring inherited. 
+ return visitor.visit_unary_expression(self) + + +@final +class BinaryExpression(ColumnExpressionBase): + """A binary operation on column expressions that returns a non-boolean + value. + """ + + expression_type: Literal["binary"] = "binary" + + a: ColumnExpression + """Left-hand side expression this one operates on.""" + + b: ColumnExpression + """Right-hand side expression this one operates on.""" + + operator: BinaryOperator + """Operator this expression applies. + + Integer '/' and '%' are defined as in SQL, not Python (though the + definitions are the same for positive arguments). + """ + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + self.a.gather_required_columns(columns) + self.b.gather_required_columns(columns) + + @property + def precedence(self) -> int: + # Docstring inherited. + match self.operator: + case "*" | "/" | "%": + return 2 + case "+" | "-": + return 3 + + @property + def column_type(self) -> ColumnType: + # Docstring inherited. + return self.a.column_type + + def __str__(self) -> str: + a = str(self.a) + b = str(self.b) + match self.operator: + case "*" | "+": + if self.a.precedence > self.precedence: + a = f"({a})" + if self.b.precedence > self.precedence: + b = f"({b})" + case _: + if self.a.precedence >= self.precedence: + a = f"({a})" + if self.b.precedence >= self.precedence: + b = f"({b})" + return f"{a} {self.operator} {b}" + + @pydantic.model_validator(mode="after") + def _validate_types(self) -> BinaryExpression: + if self.a.column_type != self.b.column_type: + raise InvalidQueryError( + f"Column types for operator {self.operator} do not agree " + f"({self.a.column_type}, {self.b.column_type})." + ) + match (self.operator, self.a.column_type): + case ("+" | "-" | "*" | "/", "int" | "float"): + pass + case ("%", "int"): + pass + case _: + raise InvalidQueryError( + f"Invalid column type {self.a.column_type} for operator {self.operator!r}." + ) + return self + + def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T: + # Docstring inherited. + return visitor.visit_binary_expression(self) + + +# Union without Pydantic annotation for the discriminator, for use in nesting +# in other unions that will add that annotation. It's not clear whether it +# would work to just nest the annotated ones, but it seems safest not to rely +# on undocumented behavior. +_ColumnExpression: TypeAlias = Union[ + ColumnLiteral, + ColumnReference, + UnaryExpression, + BinaryExpression, +] + + +ColumnExpression: TypeAlias = Annotated[_ColumnExpression, pydantic.Field(discriminator="expression_type")] + + +@final +class Reversed(ColumnExpressionBase): + """A tag wrapper for `ColumnExpression` that indicates sorting in + reverse order. + """ + + expression_type: Literal["reversed"] = "reversed" + + operand: ColumnExpression + """Expression to sort on in reverse.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + self.operand.gather_required_columns(columns) + + @property + def precedence(self) -> int: + # Docstring inherited. + raise AssertionError("Order-reversed expressions can never be nested in other column expressions.") + + @property + def column_type(self) -> ColumnType: + # Docstring inherited. + return self.operand.column_type + + def __str__(self) -> str: + return f"{self.operand} DESC" + + def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T: + # Docstring inherited. 
+ return visitor.visit_reversed(self) + + +def validate_order_expression(expression: _ColumnExpression | Reversed) -> _ColumnExpression | Reversed: + """Check that a column expression can be used for sorting. + + Parameters + ---------- + expression : `OrderExpression` + Expression to check. + + Returns + ------- + expression : `OrderExpression` + The checked expression; returned to make this usable as a Pydantic + validator. + + Raises + ------ + InvalidQueryError + Raised if this expression is not one that can be used for sorting. + """ + if expression.column_type not in ("int", "string", "float", "datetime"): + raise InvalidQueryError(f"Column type {expression.column_type} of {expression} is not ordered.") + return expression + + +OrderExpression: TypeAlias = Annotated[ + Union[_ColumnExpression, Reversed], + pydantic.Field(discriminator="expression_type"), + pydantic.AfterValidator(validate_order_expression), +] diff --git a/python/lsst/daf/butler/queries/tree/_column_literal.py b/python/lsst/daf/butler/queries/tree/_column_literal.py new file mode 100644 index 0000000000..5f65b83511 --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/_column_literal.py @@ -0,0 +1,372 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ( + "ColumnLiteral", + "make_column_literal", +) + +import uuid +import warnings +from base64 import b64decode, b64encode +from functools import cached_property +from typing import Literal, TypeAlias, Union, final + +import astropy.time +import erfa +from lsst.sphgeom import Region + +from ..._timespan import Timespan +from ...time_utils import TimeConverter +from ._base import ColumnLiteralBase + +LiteralValue: TypeAlias = Union[int, str, float, bytes, uuid.UUID, astropy.time.Time, Timespan, Region] + + +@final +class IntColumnLiteral(ColumnLiteralBase): + """A literal `int` value in a column expression.""" + + expression_type: Literal["int"] = "int" + + value: int + """The wrapped value after base64 encoding.""" + + @classmethod + def from_value(cls, value: int) -> IntColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `int` + Value to wrap. + + Returns + ------- + expression : `IntColumnLiteral` + Literal expression object. 
+ """ + return cls.model_construct(value=value) + + def __str__(self) -> str: + return repr(self.value) + + +@final +class StringColumnLiteral(ColumnLiteralBase): + """A literal `str` value in a column expression.""" + + expression_type: Literal["string"] = "string" + + value: str + """The wrapped value after base64 encoding.""" + + @classmethod + def from_value(cls, value: str) -> StringColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `str` + Value to wrap. + + Returns + ------- + expression : `StrColumnLiteral` + Literal expression object. + """ + return cls.model_construct(value=value) + + def __str__(self) -> str: + return repr(self.value) + + +@final +class FloatColumnLiteral(ColumnLiteralBase): + """A literal `float` value in a column expression.""" + + expression_type: Literal["float"] = "float" + + value: float + """The wrapped value after base64 encoding.""" + + @classmethod + def from_value(cls, value: float) -> FloatColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `float` + Value to wrap. + + Returns + ------- + expression : `FloatColumnLiteral` + Literal expression object. + """ + return cls.model_construct(value=value) + + def __str__(self) -> str: + return repr(self.value) + + +@final +class HashColumnLiteral(ColumnLiteralBase): + """A literal `bytes` value representing a hash in a column expression. + + The original value is base64-encoded when serialized and decoded on first + use. + """ + + expression_type: Literal["hash"] = "hash" + + encoded: bytes + """The wrapped value after base64 encoding.""" + + @cached_property + def value(self) -> bytes: + """The wrapped value.""" + return b64decode(self.encoded) + + @classmethod + def from_value(cls, value: bytes) -> HashColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `bytes` + Value to wrap. + + Returns + ------- + expression : `HashColumnLiteral` + Literal expression object. + """ + return cls.model_construct(encoded=b64encode(value)) + + def __str__(self) -> str: + return "(bytes)" + + +@final +class UUIDColumnLiteral(ColumnLiteralBase): + """A literal `uuid.UUID` value in a column expression.""" + + expression_type: Literal["uuid"] = "uuid" + + value: uuid.UUID + + @classmethod + def from_value(cls, value: uuid.UUID) -> UUIDColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `uuid.UUID` + Value to wrap. + + Returns + ------- + expression : `UUIDColumnLiteral` + Literal expression object. + """ + return cls.model_construct(value=value) + + def __str__(self) -> str: + return str(self.value) + + +@final +class DateTimeColumnLiteral(ColumnLiteralBase): + """A literal `astropy.time.Time` value in a column expression. + + The time is converted into TAI nanoseconds since 1970-01-01 when serialized + and restored from that on first use. + """ + + expression_type: Literal["datetime"] = "datetime" + + nsec: int + """TAI nanoseconds since 1970-01-01.""" + + @cached_property + def value(self) -> astropy.time.Time: + """The wrapped value.""" + return TimeConverter().nsec_to_astropy(self.nsec) + + @classmethod + def from_value(cls, value: astropy.time.Time) -> DateTimeColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `astropy.time.Time` + Value to wrap. + + Returns + ------- + expression : `DateTimeColumnLiteral` + Literal expression object. 
+ """ + return cls.model_construct(nsec=TimeConverter().astropy_to_nsec(value)) + + def __str__(self) -> str: + # Trap dubious year warnings in case we have timespans from + # simulated data in the future + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=erfa.ErfaWarning) + return self.value.tai.strftime("%Y-%m-%dT%H:%M:%S") + + +@final +class TimespanColumnLiteral(ColumnLiteralBase): + """A literal `Timespan` value in a column expression. + + The timespan bounds are converted into TAI nanoseconds since 1970-01-01 + when serialized and the timespan is restored from that on first use. + """ + + expression_type: Literal["timespan"] = "timespan" + + begin_nsec: int + """TAI nanoseconds since 1970-01-01 for the lower bound of the timespan + (inclusive). + """ + + end_nsec: int + """TAI nanoseconds since 1970-01-01 for the upper bound of the timespan + (exclusive). + """ + + @cached_property + def value(self) -> Timespan: + """The wrapped value.""" + return Timespan(None, None, _nsec=(self.begin_nsec, self.end_nsec)) + + @classmethod + def from_value(cls, value: Timespan) -> TimespanColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `..Timespan` + Value to wrap. + + Returns + ------- + expression : `TimespanColumnLiteral` + Literal expression object. + """ + return cls.model_construct(begin_nsec=value._nsec[0], end_nsec=value._nsec[1]) + + def __str__(self) -> str: + return str(self.value) + + +@final +class RegionColumnLiteral(ColumnLiteralBase): + """A literal `lsst.sphgeom.Region` value in a column expression. + + The region is encoded to base64 `bytes` when serialized, and decoded on + first use. + """ + + expression_type: Literal["region"] = "region" + + encoded: bytes + """The wrapped value after base64 encoding.""" + + @cached_property + def value(self) -> bytes: + """The wrapped value.""" + return Region.decode(b64decode(self.encoded)) + + @classmethod + def from_value(cls, value: Region) -> RegionColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `..Region` + Value to wrap. + + Returns + ------- + expression : `RegionColumnLiteral` + Literal expression object. + """ + return cls.model_construct(encoded=b64encode(value.encode())) + + def __str__(self) -> str: + return "(region)" + + +ColumnLiteral: TypeAlias = Union[ + IntColumnLiteral, + StringColumnLiteral, + FloatColumnLiteral, + HashColumnLiteral, + UUIDColumnLiteral, + DateTimeColumnLiteral, + TimespanColumnLiteral, + RegionColumnLiteral, +] + + +def make_column_literal(value: LiteralValue) -> ColumnLiteral: + """Construct a `ColumnLiteral` from the value it will wrap. + + Parameters + ---------- + value : `LiteralValue` + Value to wrap. + + Returns + ------- + expression : `ColumnLiteral` + Literal expression object. 
+ """ + match value: + case int(): + return IntColumnLiteral.from_value(value) + case str(): + return StringColumnLiteral.from_value(value) + case float(): + return FloatColumnLiteral.from_value(value) + case uuid.UUID(): + return UUIDColumnLiteral.from_value(value) + case bytes(): + return HashColumnLiteral.from_value(value) + case astropy.time.Time(): + return DateTimeColumnLiteral.from_value(value) + case Timespan(): + return TimespanColumnLiteral.from_value(value) + case Region(): + return RegionColumnLiteral.from_value(value) + raise TypeError(f"Invalid type {type(value).__name__!r} of value {value!r} for column literal.") diff --git a/python/lsst/daf/butler/queries/tree/_column_reference.py b/python/lsst/daf/butler/queries/tree/_column_reference.py new file mode 100644 index 0000000000..44b29b505a --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/_column_reference.py @@ -0,0 +1,173 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("ColumnReference", "DimensionKeyReference", "DimensionFieldReference", "DatasetFieldReference") + +from typing import TYPE_CHECKING, Literal, TypeAlias, TypeVar, Union, final + +import pydantic + +from ...column_spec import ColumnType +from ...dimensions import Dimension, DimensionElement +from ._base import ColumnExpressionBase, DatasetFieldName, InvalidQueryError + +if TYPE_CHECKING: + from ..visitors import ColumnExpressionVisitor + from ._column_set import ColumnSet + + +_T = TypeVar("_T") + + +@final +class DimensionKeyReference(ColumnExpressionBase): + """A column expression that references a dimension primary key column.""" + + expression_type: Literal["dimension_key"] = "dimension_key" + + dimension: Dimension + """Definition and name of this dimension.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + columns.update_dimensions(self.dimension.minimal_group) + + @property + def precedence(self) -> int: + # Docstring inherited. + return 0 + + @property + def column_type(self) -> ColumnType: + # Docstring inherited. + return self.dimension.primary_key.type + + def __str__(self) -> str: + return self.dimension.name + + def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T: + # Docstring inherited. 
+ return visitor.visit_dimension_key_reference(self) + + +@final +class DimensionFieldReference(ColumnExpressionBase): + """A column expression that references a dimension record column that is + not a primary key. + """ + + expression_type: Literal["dimension_field"] = "dimension_field" + + element: DimensionElement + """Definition and name of the dimension element.""" + + field: str + """Name of the field (i.e. column) in the element's logical table.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + columns.update_dimensions(self.element.minimal_group) + columns.dimension_fields[self.element.name].add(self.field) + + @property + def precedence(self) -> int: + # Docstring inherited. + return 0 + + @property + def column_type(self) -> ColumnType: + # Docstring inherited. + return self.element.schema.remainder[self.field].type + + def __str__(self) -> str: + return f"{self.element}.{self.field}" + + def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T: + # Docstring inherited. + return visitor.visit_dimension_field_reference(self) + + @pydantic.model_validator(mode="after") + def _validate_field(self) -> DimensionFieldReference: + if self.field not in self.element.schema.remainder.names: + raise InvalidQueryError(f"Dimension field {self.element.name}.{self.field} does not exist.") + return self + + +@final +class DatasetFieldReference(ColumnExpressionBase): + """A column expression that references a column associated with a dataset + type. + """ + + expression_type: Literal["dataset_field"] = "dataset_field" + + dataset_type: str + """Name of the dataset type.""" + + field: DatasetFieldName + """Name of the field (i.e. column) in the dataset's logical table.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + columns.dataset_fields[self.dataset_type].add(self.field) + + @property + def precedence(self) -> int: + # Docstring inherited. + return 0 + + @property + def column_type(self) -> ColumnType: + # Docstring inherited. + match self.field: + case "dataset_id": + return "uuid" + case "ingest_date": + return "datetime" + case "run": + return "string" + case "collection": + return "string" + case "timespan": + return "timespan" + raise AssertionError(f"Invalid field {self.field!r} for dataset.") + + def __str__(self) -> str: + return f"{self.dataset_type}.{self.field}" + + def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T: + # Docstring inherited. + return visitor.visit_dataset_field_reference(self) + + +ColumnReference: TypeAlias = Union[ + DimensionKeyReference, + DimensionFieldReference, + DatasetFieldReference, +] diff --git a/python/lsst/daf/butler/queries/tree/_column_set.py b/python/lsst/daf/butler/queries/tree/_column_set.py new file mode 100644 index 0000000000..dc38b2ffcc --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/_column_set.py @@ -0,0 +1,357 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. 
If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("ColumnSet",) + +from collections.abc import Iterable, Iterator, Mapping, Set + +from ... import column_spec +from ...dimensions import DimensionGroup +from ...nonempty_mapping import NonemptyMapping + + +class ColumnSet: + """A set-like hierarchical container for the columns in a query. + + Parameters + ---------- + dimensions : `DimensionGroup` + The dimensions that bound the set of columns, and by default specify + the set of dimension key columns present. + + Notes + ----- + This class does not inherit from `collections.abc.Set` because that brings + in a lot of requirements we don't need (particularly interoperability with + other set-like objects). + + This class is iterable over tuples of ``(logical_table, field)``, where + ``logical_table`` is a dimension element name or dataset type name, and + ``field`` is a column associated with one of those, or `None` for dimension + key columns. Iteration order is guaranteed to be deterministic and to + start with all included dimension keys in `DimensionGroup.dimension_ + """ + + def __init__(self, dimensions: DimensionGroup) -> None: + self._dimensions = dimensions + self._removed_dimension_keys: set[str] = set() + self._dimension_fields: dict[str, set[str]] = {name: set() for name in dimensions.elements} + self._dataset_fields = NonemptyMapping[str, set[str]](set) + + @property + def dimensions(self) -> DimensionGroup: + """The dimensions that bound all columns in the set.""" + return self._dimensions + + @property + def dimension_fields(self) -> Mapping[str, set[str]]: + """Dimension record fields included in the set, grouped by dimension + element name. + + The keys of this mapping are always ``self.dimensions.elements``, and + nested sets may be empty. + """ + return self._dimension_fields + + @property + def dataset_fields(self) -> NonemptyMapping[str, set[str]]: + """Dataset fields included in the set, grouped by dataset type name. + + The keys of this mapping are just those that actually have nonempty + nested sets. + """ + return self._dataset_fields + + def __bool__(self) -> bool: + return bool(self._dimensions) or any(self._dataset_fields.values()) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ColumnSet): + return False + return ( + self._dimensions == other._dimensions + and self._removed_dimension_keys == other._removed_dimension_keys + and self._dimension_fields == other._dimension_fields + and self._dataset_fields == other._dataset_fields + ) + + def __str__(self) -> str: + return f"{{{', '.join(self.get_qualified_name(k, v) for k, v in self)}}}" + + def issubset(self, other: ColumnSet) -> bool: + """Test whether all columns in this set are also in another. 
+ + Parameters + ---------- + other : `ColumnSet` + Set of columns to compare to. + + Returns + ------- + issubset : `bool` + Whether all columns in ``self`` are also in ``other``. + """ + return ( + (self._get_dimension_keys() <= other._get_dimension_keys()) + and all( + fields.issubset(other._dimension_fields.get(element_name, frozenset())) + for element_name, fields in self._dimension_fields.items() + ) + and all( + fields.issubset(other._dataset_fields.get(dataset_type, frozenset())) + for dataset_type, fields in self._dataset_fields.items() + ) + ) + + def issuperset(self, other: ColumnSet) -> bool: + """Test whether all columns another set are also in this one. + + Parameters + ---------- + other : `ColumnSet` + Set of columns to compare to. + + Returns + ------- + issuperset : `bool` + Whether all columns in ``other`` are also in ``self``. + """ + return other.issubset(self) + + def isdisjoint(self, other: ColumnSet) -> bool: + """Test whether there are no columns in both this set and another. + + Parameters + ---------- + other : `ColumnSet` + Set of columns to compare to. + + Returns + ------- + isdisjoint : `bool` + Whether there are any columns in both ``self`` and ``other``. + """ + return ( + self._get_dimension_keys().isdisjoint(other._get_dimension_keys()) + and all( + fields.isdisjoint(other._dimension_fields.get(element, frozenset())) + for element, fields in self._dimension_fields.items() + ) + and all( + fields.isdisjoint(other._dataset_fields.get(dataset_type, frozenset())) + for dataset_type, fields in self._dataset_fields.items() + ) + ) + + def copy(self) -> ColumnSet: + """Return a copy of this set. + + Returns + ------- + copy : `ColumnSet` + New column set that can be modified without changing the original. + """ + result = ColumnSet(self._dimensions) + for element_name, element_fields in self._dimension_fields.items(): + result._dimension_fields[element_name].update(element_fields) + for dataset_type, dataset_fields in self._dataset_fields.items(): + result._dataset_fields[dataset_type].update(dataset_fields) + return result + + def update_dimensions(self, dimensions: DimensionGroup) -> None: + """Add new dimensions to the set. + + Parameters + ---------- + dimensions : `DimensionGroup` + Dimensions to be included. + """ + if not dimensions.issubset(self._dimensions): + self._dimensions = dimensions.union(self._dimensions) + self._dimension_fields = { + name: self._dimension_fields.get(name, set()) for name in self._dimensions.elements + } + self._removed_dimension_keys.intersection_update(dimensions.names) + + def update(self, other: ColumnSet) -> None: + """Add columns from another set to this one. + + Parameters + ---------- + other : `ColumnSet` + Column set whose columns should be included in this one. + """ + self.update_dimensions(other.dimensions) + self._removed_dimension_keys.intersection_update(other._removed_dimension_keys) + for element_name, element_fields in other._dimension_fields.items(): + self._dimension_fields[element_name].update(element_fields) + for dataset_type, dataset_fields in other._dataset_fields.items(): + self._dataset_fields[dataset_type].update(dataset_fields) + + def drop_dimension_keys(self, names: Iterable[str]) -> ColumnSet: + """Remove the given dimension key columns from the set. + + Parameters + ---------- + names : `~collections.abc.Iterable` [ `str` ] + Names of the dimensions to remove. + + Returns + ------- + self : `ColumnSet` + This column set, modified in place. 
+ """ + self._removed_dimension_keys.update(names) + return self + + def drop_implied_dimension_keys(self) -> ColumnSet: + """Remove dimension key columns that are implied by others. + + Returns + ------- + self : `ColumnSet` + This column set, modified in place. + """ + return self.drop_dimension_keys(self._dimensions.implied) + + def restore_dimension_keys(self) -> None: + """Restore all removed dimension key columns.""" + self._removed_dimension_keys.clear() + + def __iter__(self) -> Iterator[tuple[str, str | None]]: + for dimension_name in self._dimensions.data_coordinate_keys: + if dimension_name not in self._removed_dimension_keys: + yield dimension_name, None + # We iterate over DimensionElements and their DimensionRecord columns + # in order to make sure that's predictable. We might want to extract + # these query results positionally in some contexts. + for element_name in self._dimensions.elements: + element = self._dimensions.universe[element_name] + fields_for_element = self._dimension_fields[element_name] + for spec in element.schema.remainder: + if spec.name in fields_for_element: + yield element_name, spec.name + # We sort dataset types and their fields lexicographically just to keep + # our queries from having any dependence on set-iteration order. + for dataset_type in sorted(self._dataset_fields): + for field in sorted(self._dataset_fields[dataset_type]): + yield dataset_type, field + + def is_timespan(self, logical_table: str, field: str | None) -> bool: + """Test whether the given column is a timespan. + + Parameters + ---------- + logical_table : `str` + Name of the dimension element or dataset type the column belongs + to. + field : `str` or `None` + Column within the logical table, or `None` for dimension key + columns. + + Returns + ------- + is_timespan : `bool` + Whether this column is a timespan. + """ + return field == "timespan" + + @staticmethod + def get_qualified_name(logical_table: str, field: str | None) -> str: + """Return string that should be used to fully identify a column. + + Parameters + ---------- + logical_table : `str` + Name of the dimension element or dataset type the column belongs + to. + field : `str` or `None` + Column within the logical table, or `None` for dimension key + columns. + + Returns + ------- + name : `str` + Fully-qualified name. + """ + return logical_table if field is None else f"{logical_table}:{field}" + + def get_column_spec(self, logical_table: str, field: str | None) -> column_spec.ColumnSpec: + """Return a complete description of a column. + + Parameters + ---------- + logical_table : `str` + Name of the dimension element or dataset type the column belongs + to. + field : `str` or `None` + Column within the logical table, or `None` for dimension key + columns. + + Returns + ------- + spec : `.column_spec.ColumnSpec` + Description of the column. 
+ """ + qualified_name = self.get_qualified_name(logical_table, field) + if field is None: + return self._dimensions.universe.dimensions[logical_table].primary_key.model_copy( + update=dict(name=qualified_name) + ) + if logical_table in self._dimension_fields: + return ( + self._dimensions.universe[logical_table] + .schema.all[field] + .model_copy(update=dict(name=qualified_name)) + ) + match field: + case "dataset_id": + return column_spec.UUIDColumnSpec.model_construct(name=qualified_name, nullable=False) + case "ingest_date": + return column_spec.DateTimeColumnSpec.model_construct(name=qualified_name) + case "run": + # TODO: string length matches the one defined in the + # CollectionManager implementations; we need to find a way to + # avoid hard-coding the value in multiple places. + return column_spec.StringColumnSpec.model_construct( + name=qualified_name, nullable=False, length=128 + ) + case "collection": + return column_spec.StringColumnSpec.model_construct( + name=qualified_name, nullable=False, length=128 + ) + case "timespan": + return column_spec.TimespanColumnSpec.model_construct(name=qualified_name, nullable=False) + raise AssertionError(f"Unrecognized column identifiers: {logical_table}, {field}.") + + def _get_dimension_keys(self) -> Set[str]: + if not self._removed_dimension_keys: + return self._dimensions.names + else: + return self._dimensions.names - self._removed_dimension_keys diff --git a/python/lsst/daf/butler/queries/tree/_predicate.py b/python/lsst/daf/butler/queries/tree/_predicate.py new file mode 100644 index 0000000000..ea5f6bb878 --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/_predicate.py @@ -0,0 +1,629 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
+ +from __future__ import annotations + +__all__ = ( + "Predicate", + "PredicateLeaf", + "LogicalNotOperand", + "PredicateOperands", + "ComparisonOperator", +) + +import itertools +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Annotated, Iterable, Literal, TypeAlias, TypeVar, Union, cast, final + +import pydantic + +from ._base import InvalidQueryError, QueryTreeBase +from ._column_expression import ColumnExpression + +if TYPE_CHECKING: + from ..visitors import PredicateVisitFlags, PredicateVisitor + from ._column_set import ColumnSet + from ._query_tree import QueryTree + +ComparisonOperator: TypeAlias = Literal["==", "!=", "<", ">", ">=", "<=", "overlaps"] + + +_L = TypeVar("_L") +_A = TypeVar("_A") +_O = TypeVar("_O") + + +class PredicateLeafBase(QueryTreeBase, ABC): + """Base class for leaf nodes of the `Predicate` tree. + + This is a closed hierarchy whose concrete, `~typing.final` derived classes + are members of the `PredicateLeaf` union. That union should generally + be used in type annotations rather than the technically-open base class. + """ + + @abstractmethod + def gather_required_columns(self, columns: ColumnSet) -> None: + """Add any columns required to evaluate this predicate leaf to the + given column set. + + Parameters + ---------- + columns : `ColumnSet` + Set of columns to modify in place. + """ + raise NotImplementedError() + + def invert(self) -> PredicateLeaf: + """Return a new leaf that is the logical not of this one.""" + return LogicalNot.model_construct(operand=cast("LogicalNotOperand", self)) + + @abstractmethod + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + """Invoke the visitor interface. + + Parameters + ---------- + visitor : `PredicateVisitor` + Visitor to invoke a method on. + flags : `PredicateVisitFlags` + Flags that provide information about where this leaf appears in the + larger predicate tree. + + Returns + ------- + result : `object` + Forwarded result from the visitor. + """ + raise NotImplementedError() + + +@final +class Predicate(QueryTreeBase): + """A boolean column expression. + + Notes + ----- + Predicate is the only class representing a boolean column expression that + should be used outside of this module (though the objects it nests appear + in its serialized form and hence are not fully private). It provides + several `classmethod` factories for constructing those nested types inside + a `Predicate` instance, and `PredicateVisitor` subclasses should be used + to process them. + """ + + operands: PredicateOperands + """Nested tuple of operands, with outer items combined via AND and inner + items combined via OR. + """ + + @property + def column_type(self) -> Literal["bool"]: + """A string enumeration value representing the type of the column + expression. + """ + return "bool" + + @classmethod + def from_bool(cls, value: bool) -> Predicate: + """Construct a predicate that always evaluates to `True` or `False`. + + Parameters + ---------- + value : `bool` + Value the predicate should evaluate to. + + Returns + ------- + predicate : `Predicate` + Predicate that evaluates to the given boolean value. + """ + return cls.model_construct(operands=() if value else ((),)) + + @classmethod + def compare(cls, a: ColumnExpression, operator: ComparisonOperator, b: ColumnExpression) -> Predicate: + """Construct a predicate representing a binary comparison between + two non-boolean column expressions. 
+ + Parameters + ---------- + a : `ColumnExpression` + First column expression in the comparison. + operator : `str` + Enumerated string representing the comparison operator to apply. + May be and of "==", "!=", "<", ">", "<=", ">=", or "overlaps". + b : `ColumnExpression` + Second column expression in the comparison. + + Returns + ------- + predicate : `Predicate` + Predicate representing the comparison. + """ + return cls._from_leaf(Comparison(a=a, operator=operator, b=b)) + + @classmethod + def is_null(cls, operand: ColumnExpression) -> Predicate: + """Construct a predicate that tests whether a column expression is + NULL. + + Parameters + ---------- + operand : `ColumnExpression` + Column expression to test. + + Returns + ------- + predicate : `Predicate` + Predicate representing the NULL check. + """ + return cls._from_leaf(IsNull(operand=operand)) + + @classmethod + def in_container(cls, member: ColumnExpression, container: Iterable[ColumnExpression]) -> Predicate: + """Construct a predicate that tests whether one column expression is + a member of a container of other column expressions. + + Parameters + ---------- + member : `ColumnExpression` + Column expression that may be a member of the container. + container : `~collections.abc.Iterable` [ `ColumnExpression` ] + Container of column expressions to test for membership in. + + Returns + ------- + predicate : `Predicate` + Predicate representing the membership test. + """ + return cls._from_leaf(InContainer(member=member, container=tuple(container))) + + @classmethod + def in_range( + cls, member: ColumnExpression, start: int = 0, stop: int | None = None, step: int = 1 + ) -> Predicate: + """Construct a predicate that tests whether an integer column + expression is part of a strided range. + + Parameters + ---------- + member : `ColumnExpression` + Column expression that may be a member of the range. + start : `int`, optional + Beginning of the range, inclusive. + stop : `int` or `None`, optional + End of the range, exclusive. + step : `int`, optional + Offset between values in the range. + + Returns + ------- + predicate : `Predicate` + Predicate representing the membership test. + """ + return cls._from_leaf(InRange(member=member, start=start, stop=stop, step=step)) + + @classmethod + def in_query(cls, member: ColumnExpression, column: ColumnExpression, query_tree: QueryTree) -> Predicate: + """Construct a predicate that tests whether a column expression is + present in a single-column projection of a query tree. + + Parameters + ---------- + member : `ColumnExpression` + Column expression that may be present in the query. + column : `ColumnExpression` + Column to project from the query. + query_tree : `QueryTree` + Query tree to select from. + + Returns + ------- + predicate : `Predicate` + Predicate representing the membership test. + """ + return cls._from_leaf(InQuery(member=member, column=column, query_tree=query_tree)) + + def gather_required_columns(self, columns: ColumnSet) -> None: + """Add any columns required to evaluate this predicate to the given + column set. + + Parameters + ---------- + columns : `ColumnSet` + Set of columns to modify in place. + """ + for or_group in self.operands: + for operand in or_group: + operand.gather_required_columns(columns) + + def logical_and(self, *args: Predicate) -> Predicate: + """Construct a predicate representing the logical AND of this predicate + and one or more others. + + Parameters + ---------- + *args : `Predicate` + Other predicates. 
+ + Returns + ------- + predicate : `Predicate` + Predicate representing the logical AND. + """ + operands = self.operands + for arg in args: + operands = self._impl_and(operands, arg.operands) + if not all(operands): + # If any item in operands is an empty tuple (i.e. False), simplify. + operands = ((),) + return Predicate.model_construct(operands=operands) + + def logical_or(self, *args: Predicate) -> Predicate: + """Construct a predicate representing the logical OR of this predicate + and one or more others. + + Parameters + ---------- + *args : `Predicate` + Other predicates. + + Returns + ------- + predicate : `Predicate` + Predicate representing the logical OR. + """ + operands = self.operands + for arg in args: + operands = self._impl_or(operands, arg.operands) + return Predicate.model_construct(operands=operands) + + def logical_not(self) -> Predicate: + """Construct a predicate representing the logical NOT of this + predicate. + + Returns + ------- + predicate : `Predicate` + Predicate representing the logical NOT. + """ + new_operands: PredicateOperands = ((),) + for or_group in self.operands: + new_group: PredicateOperands = () + for leaf in or_group: + new_group = self._impl_and(new_group, ((leaf.invert(),),)) + new_operands = self._impl_or(new_operands, new_group) + return Predicate.model_construct(operands=new_operands) + + def __str__(self) -> str: + and_terms = [] + for or_group in self.operands: + match len(or_group): + case 0: + and_terms.append("False") + case 1: + and_terms.append(str(or_group[0])) + case _: + or_str = " OR ".join(str(operand) for operand in or_group) + if len(self.operands) > 1: + and_terms.append(f"({or_str})") + else: + and_terms.append(or_str) + if not and_terms: + return "True" + return " AND ".join(and_terms) + + def visit(self, visitor: PredicateVisitor[_A, _O, _L]) -> _A: + """Invoke the visitor interface. + + Parameters + ---------- + visitor : `PredicateVisitor` + Visitor to invoke a method on. + + Returns + ------- + result : `object` + Forwarded result from the visitor. + """ + return visitor._visit_logical_and(self.operands) + + @classmethod + def _from_leaf(cls, leaf: PredicateLeaf) -> Predicate: + return cls._from_or_group((leaf,)) + + @classmethod + def _from_or_group(cls, or_group: tuple[PredicateLeaf, ...]) -> Predicate: + return Predicate.model_construct(operands=(or_group,)) + + @classmethod + def _impl_and(cls, a: PredicateOperands, b: PredicateOperands) -> PredicateOperands: + # We could simplify cases where both sides have some of the same leaf + # expressions; even using 'is' tests would simplify some cases where + # converting to conjunctive normal form twice leads to a lot of + # duplication, e.g. NOT ((A AND B) OR (C AND D)) or any kind of + # double-negation. Right now those cases seem pathological enough to + # be not worth our time. + return a + b if a is not b else a + + @classmethod + def _impl_or(cls, a: PredicateOperands, b: PredicateOperands) -> PredicateOperands: + # Same comment re simplification as in _impl_and applies here. + return tuple([a_operand + b_operand for a_operand, b_operand in itertools.product(a, b)]) + + +@final +class LogicalNot(PredicateLeafBase): + """A boolean column expression that inverts its operand.""" + + predicate_type: Literal["not"] = "not" + + operand: LogicalNotOperand + """Upstream boolean expression to invert.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. 
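+        # A logical NOT needs exactly the columns its operand needs, so just
+        # forward the request to the operand.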
+ self.operand.gather_required_columns(columns) + + def __str__(self) -> str: + return f"NOT {self.operand}" + + def invert(self) -> LogicalNotOperand: + # Docstring inherited. + return self.operand + + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + # Docstring inherited. + return visitor._visit_logical_not(self.operand, flags) + + +@final +class IsNull(PredicateLeafBase): + """A boolean column expression that tests whether its operand is NULL.""" + + predicate_type: Literal["is_null"] = "is_null" + + operand: ColumnExpression + """Upstream expression to test.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + self.operand.gather_required_columns(columns) + + def __str__(self) -> str: + return f"{self.operand} IS NULL" + + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + # Docstring inherited. + return visitor.visit_is_null(self.operand, flags) + + +@final +class Comparison(PredicateLeafBase): + """A boolean columns expression formed by comparing two non-boolean + expressions. + """ + + predicate_type: Literal["comparison"] = "comparison" + + a: ColumnExpression + """Left-hand side expression for the comparison.""" + + b: ColumnExpression + """Right-hand side expression for the comparison.""" + + operator: ComparisonOperator + """Comparison operator.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + self.a.gather_required_columns(columns) + self.b.gather_required_columns(columns) + + def __str__(self) -> str: + return f"{self.a} {self.operator.upper()} {self.b}" + + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + # Docstring inherited. + return visitor.visit_comparison(self.a, self.operator, self.b, flags) + + @pydantic.model_validator(mode="after") + def _validate_column_types(self) -> Comparison: + if self.a.column_type != self.b.column_type: + raise InvalidQueryError( + f"Column types for comparison {self} do not agree " + f"({self.a.column_type}, {self.b.column_type})." + ) + match (self.operator, self.a.column_type): + case ("==" | "!=", _): + pass + case ("<" | ">" | ">=" | "<=", "int" | "string" | "float" | "datetime"): + pass + case ("overlaps", "region" | "timespan"): + pass + case _: + raise InvalidQueryError( + f"Invalid column type {self.a.column_type} for operator {self.operator!r}." + ) + return self + + +@final +class InContainer(PredicateLeafBase): + """A boolean column expression that tests whether one expression is a + member of an explicit sequence of other expressions. + """ + + predicate_type: Literal["in_container"] = "in_container" + + member: ColumnExpression + """Expression to test for membership.""" + + container: tuple[ColumnExpression, ...] + """Expressions representing the elements of the container.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + self.member.gather_required_columns(columns) + for item in self.container: + item.gather_required_columns(columns) + + def __str__(self) -> str: + return f"{self.member} IN [{', '.join(str(item) for item in self.container)}]" + + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + # Docstring inherited. 
+ return visitor.visit_in_container(self.member, self.container, flags) + + @pydantic.model_validator(mode="after") + def _validate(self) -> InContainer: + if self.member.column_type == "timespan" or self.member.column_type == "region": + raise InvalidQueryError( + f"Timespan or region column {self.member} may not be used in IN expressions." + ) + if not all(item.column_type == self.member.column_type for item in self.container): + raise InvalidQueryError(f"Column types for membership test {self} do not agree.") + return self + + +@final +class InRange(PredicateLeafBase): + """A boolean column expression that tests whether its expression is + included in an integer range. + """ + + predicate_type: Literal["in_range"] = "in_range" + + member: ColumnExpression + """Expression to test for membership.""" + + start: int = 0 + """Inclusive lower bound for the range.""" + + stop: int | None = None + """Exclusive upper bound for the range.""" + + step: int = 1 + """Difference between values in the range.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + self.member.gather_required_columns(columns) + + def __str__(self) -> str: + s = f"{self.start if self.start else ''}:{self.stop if self.stop is not None else ''}" + if self.step != 1: + s = f"{s}:{self.step}" + return f"{self.member} IN {s}" + + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + return visitor.visit_in_range(self.member, self.start, self.stop, self.step, flags) + + @pydantic.model_validator(mode="after") + def _validate(self) -> InRange: + if self.member.column_type != "int": + raise InvalidQueryError(f"Column {self.member} is not an integer.") + if self.step < 1: + raise InvalidQueryError("Range step must be >= 1.") + if self.stop is not None and self.stop < self.start: + raise InvalidQueryError("Range stop must be >= start.") + return self + + +@final +class InQuery(PredicateLeafBase): + """A boolean column expression that tests whether its expression is + included single-column projection of a relation. + + This is primarily intended to be used on dataset ID columns, but it may + be useful for other columns as well. + """ + + predicate_type: Literal["in_query"] = "in_query" + + member: ColumnExpression + """Expression to test for membership.""" + + column: ColumnExpression + """Expression to extract from `query_tree`.""" + + query_tree: QueryTree + """Relation whose rows from `column` represent the container.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + # We're only gathering columns from the query_tree this predicate is + # attached to, not `self.column`, which belongs to `self.query_tree`. + self.member.gather_required_columns(columns) + + def __str__(self) -> str: + return f"{self.member} IN (query).{self.column}" + + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + # Docstring inherited. + return visitor.visit_in_query_tree(self.member, self.column, self.query_tree, flags) + + @pydantic.model_validator(mode="after") + def _validate_column_types(self) -> InQuery: + if self.member.column_type == "timespan" or self.member.column_type == "region": + raise InvalidQueryError( + f"Timespan or region column {self.member} may not be used in IN expressions." 
+ ) + if self.member.column_type != self.column.column_type: + raise InvalidQueryError( + f"Column types for membership test {self} do not agree " + f"({self.member.column_type}, {self.column.column_type})." + ) + + from ._column_set import ColumnSet + + columns_required_in_tree = ColumnSet(self.query_tree.dimensions) + self.column.gather_required_columns(columns_required_in_tree) + if columns_required_in_tree.dimensions != self.query_tree.dimensions: + raise InvalidQueryError( + f"Column {self.column} requires dimensions {columns_required_in_tree.dimensions}, " + f"but query tree only has {self.query_tree.dimensions}." + ) + if not columns_required_in_tree.dataset_fields.keys() <= self.query_tree.datasets.keys(): + raise InvalidQueryError( + f"Column {self.column} requires dataset types " + f"{set(columns_required_in_tree.dataset_fields.keys())} that are not present in query tree." + ) + return self + + +LogicalNotOperand: TypeAlias = Union[ + IsNull, + Comparison, + InContainer, + InRange, + InQuery, +] +PredicateLeaf: TypeAlias = Annotated[ + Union[LogicalNotOperand, LogicalNot], pydantic.Field(discriminator="predicate_type") +] + +PredicateOperands: TypeAlias = tuple[tuple[PredicateLeaf, ...], ...] diff --git a/python/lsst/daf/butler/queries/tree/_query_tree.py b/python/lsst/daf/butler/queries/tree/_query_tree.py new file mode 100644 index 0000000000..ee715e629b --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/_query_tree.py @@ -0,0 +1,314 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
+ +from __future__ import annotations + +__all__ = ( + "QueryTree", + "make_identity_query_tree", + "DataCoordinateUploadKey", + "MaterializationKey", + "DatasetSearch", + "SerializedQueryTree", +) + +import uuid +from collections.abc import Mapping +from typing import TypeAlias, final + +import pydantic + +from ...dimensions import DimensionGroup, DimensionUniverse +from ...pydantic_utils import DeferredValidation +from ._base import InvalidQueryError, QueryTreeBase +from ._column_set import ColumnSet +from ._predicate import Predicate + +DataCoordinateUploadKey: TypeAlias = uuid.UUID + +MaterializationKey: TypeAlias = uuid.UUID + + +def make_identity_query_tree(universe: DimensionUniverse) -> QueryTree: + """Make an initial query tree with empty dimensions and a single logical + row. + + This method should be used by `Butler._query` to construct the initial + query tree. This tree is a useful initial state because it is the + identity for joins, in that joining any other query tree to the identity + yields that query tree. + + Parameters + ---------- + universe : `..DimensionUniverse` + Definitions for all dimensions. + + Returns + ------- + tree : `QueryTree` + A tree with empty dimensions. + """ + return QueryTree(dimensions=universe.empty.as_group()) + + +@final +class DatasetSearch(QueryTreeBase): + """Information about a dataset search joined into a query tree. + + The dataset type name is the key of the dictionary (in `QueryTree`) where + this type is used as a value. + """ + + collections: tuple[str, ...] + """The collections to search. + + Order matters if this dataset type is later referenced by a `FindFirst` + operation. Collection wildcards are always resolved before being included + in a dataset search. + """ + + dimensions: DimensionGroup + """The dimensions of the dataset type. + + This must match the dimensions of the dataset type as already defined in + the butler database, but this cannot generally be verified when a relation + tree is validated (since it requires a database query) and hence must be + checked later. + """ + + storage_class_name: str | None + """Name of the storage class to use when returning `DatasetRef` results. + + May be `None` if the dataset is only used as a constraint or to return + columns that do not include a full dataset type. + """ + + +@final +class QueryTree(QueryTreeBase): + """A declarative, serializable description of the row constraints and joins + in a butler query. + + Notes + ----- + A `QueryTree` is the struct that represents the serializable form of a + `Query` object, or one piece (with `ResultSpec` the other) of the + serializable form of a query results object. + + This class's attributes describe the columns that are "available" to be + returned or used in ``where`` or ``order_by`` expressions, but it does not + carry information about the columns that are actually included in result + rows, or what kind of butler primitive (e.g. `DataCoordinate` or + `DatasetRef`) those rows might be transformed into. + """ + + dimensions: DimensionGroup + """The dimensions whose keys are joined into the query. + """ + + datasets: Mapping[str, DatasetSearch] = pydantic.Field(default_factory=dict) + """Dataset searches that have been joined into the query.""" + + data_coordinate_uploads: Mapping[DataCoordinateUploadKey, DimensionGroup] = pydantic.Field( + default_factory=dict + ) + """Uploaded tables of data ID values that have been joined into the query. 
+    """
+
+    materializations: Mapping[MaterializationKey, DimensionGroup] = pydantic.Field(default_factory=dict)
+    """Tables of result rows from other queries that have been stored
+    temporarily on the server.
+    """
+
+    predicate: Predicate = Predicate.from_bool(True)
+    """Boolean expression trees whose logical AND defines a row filter."""
+
+    def get_joined_dimension_groups(self) -> frozenset[DimensionGroup]:
+        """Return a set of the dimension groups of all data coordinate uploads,
+        dataset searches, and materializations.
+        """
+        result: set[DimensionGroup] = set(self.data_coordinate_uploads.values())
+        result.update(self.materializations.values())
+        for dataset_spec in self.datasets.values():
+            result.add(dataset_spec.dimensions)
+        return frozenset(result)
+
+    def join_dimensions(self, dimensions: DimensionGroup) -> QueryTree:
+        """Return a new tree that includes additional dimensions.
+
+        Parameters
+        ----------
+        dimensions : `DimensionGroup`
+            Dimensions to include.
+
+        Returns
+        -------
+        result : `QueryTree`
+            A new tree with the additional dimensions.
+        """
+        return self.model_copy(update=dict(dimensions=self.dimensions | dimensions))
+
+    def join_data_coordinate_upload(
+        self, key: DataCoordinateUploadKey, dimensions: DimensionGroup
+    ) -> QueryTree:
+        """Return a new tree that joins in an uploaded table of data ID values.
+
+        Parameters
+        ----------
+        key : `DataCoordinateUploadKey`
+            Unique identifier for this upload, as assigned by a `QueryDriver`.
+        dimensions : `DimensionGroup`
+            Dimensions of the data IDs.
+
+        Returns
+        -------
+        result : `QueryTree`
+            A new tree that joins in the data ID table.
+        """
+        assert key not in self.data_coordinate_uploads, "Query should prevent doing the same upload twice."
+        data_coordinate_uploads = dict(self.data_coordinate_uploads)
+        data_coordinate_uploads[key] = dimensions
+        return self.model_copy(
+            update=dict(
+                dimensions=self.dimensions | dimensions, data_coordinate_uploads=data_coordinate_uploads
+            )
+        )
+
+    def join_materialization(self, key: MaterializationKey, dimensions: DimensionGroup) -> QueryTree:
+        """Return a new tree that joins in temporarily stored results from
+        another query.
+
+        Parameters
+        ----------
+        key : `MaterializationKey`
+            Unique identifier for this materialization, as assigned by a
+            `QueryDriver`.
+        dimensions : `DimensionGroup`
+            The dimensions stored in the materialization.
+
+        Returns
+        -------
+        result : `QueryTree`
+            A new tree that joins in the materialization.
+        """
+        assert key not in self.materializations, "Query should prevent duplicate materialization."
+        materializations = dict(self.materializations)
+        materializations[key] = dimensions
+        return self.model_copy(
+            update=dict(dimensions=self.dimensions | dimensions, materializations=materializations)
+        )
+
+    def join_dataset(self, dataset_type: str, search: DatasetSearch) -> QueryTree:
+        """Return a new tree that joins in a search for a dataset.
+
+        Parameters
+        ----------
+        dataset_type : `str`
+            Name of dataset type to join in.
+        search : `DatasetSearch`
+            Struct containing the collection search path and dataset type
+            dimensions.
+
+        Returns
+        -------
+        result : `QueryTree`
+            A new tree that joins in the dataset search.
+
+        Notes
+        -----
+        If this dataset type was already joined in, the new `DatasetSearch`
+        replaces the old one.
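+
+        Examples
+        --------
+        A minimal sketch with illustrative names, assuming ``raw_dims`` is
+        the `DimensionGroup` of the dataset type being joined::
+
+            search = DatasetSearch(
+                collections=("HSC/raw/all",),
+                dimensions=raw_dims,
+                storage_class_name=None,
+            )
+            tree = tree.join_dataset("raw", search)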
+        """
+        datasets = dict(self.datasets)
+        datasets[dataset_type] = search
+        return self.model_copy(update=dict(dimensions=self.dimensions | search.dimensions, datasets=datasets))
+
+    def where(self, *terms: Predicate) -> QueryTree:
+        """Return a new tree that adds row filtering via a boolean column
+        expression.
+
+        Parameters
+        ----------
+        *terms : `Predicate`
+            Boolean column expressions that filter rows. Arguments are
+            combined with logical AND.
+
+        Returns
+        -------
+        result : `QueryTree`
+            A new tree with row filtering.
+
+        Raises
+        ------
+        InvalidQueryError
+            Raised if a column expression requires a dataset column that is not
+            already present in the query tree.
+
+        Notes
+        -----
+        If an expression references a dimension or dimension element that is
+        not already present in the query tree, it will be joined in, but
+        datasets must already be joined into a query tree in order to reference
+        their fields in expressions.
+        """
+        predicate = self.predicate
+        columns = ColumnSet(self.dimensions)
+        for where_term in terms:
+            where_term.gather_required_columns(columns)
+            predicate = predicate.logical_and(where_term)
+        if not (columns.dataset_fields.keys() <= self.datasets.keys()):
+            raise InvalidQueryError(
+                f"Cannot reference dataset type(s) {columns.dataset_fields.keys() - self.datasets.keys()} "
+                "that have not been joined."
+            )
+        return self.model_copy(update=dict(dimensions=columns.dimensions, predicate=predicate))
+
+    @pydantic.model_validator(mode="after")
+    def _validate_join_operands(self) -> QueryTree:
+        for dimensions in self.get_joined_dimension_groups():
+            if not dimensions.issubset(self.dimensions):
+                raise InvalidQueryError(
+                    f"Dimensions {dimensions} of join operand are not a "
+                    f"subset of the query tree's dimensions {self.dimensions}."
+                )
+        return self
+
+    @pydantic.model_validator(mode="after")
+    def _validate_required_columns(self) -> QueryTree:
+        columns = ColumnSet(self.dimensions)
+        self.predicate.gather_required_columns(columns)
+        if not columns.dimensions.issubset(self.dimensions):
+            raise InvalidQueryError("Predicate requires dimensions beyond those in the query tree.")
+        if not columns.dataset_fields.keys() <= self.datasets.keys():
+            raise InvalidQueryError("Predicate requires dataset columns that are not in the query tree.")
+        return self
+
+
+class SerializedQueryTree(DeferredValidation[QueryTree]):
+    """A Pydantic-serializable wrapper for `QueryTree` that defers validation
+    to the `validated` method, allowing a `.DimensionUniverse` to be provided.
+    """
diff --git a/python/lsst/daf/butler/queries/visitors.py b/python/lsst/daf/butler/queries/visitors.py
new file mode 100644
index 0000000000..8340389e19
--- /dev/null
+++ b/python/lsst/daf/butler/queries/visitors.py
@@ -0,0 +1,540 @@
+# This file is part of daf_butler.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively.
If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ( + "ColumnExpressionVisitor", + "PredicateVisitor", + "SimplePredicateVisitor", + "PredicateVisitFlags", +) + +import enum +from abc import abstractmethod +from typing import Generic, TypeVar, final + +from . import tree + +_T = TypeVar("_T") +_L = TypeVar("_L") +_A = TypeVar("_A") +_O = TypeVar("_O") + + +class PredicateVisitFlags(enum.Flag): + """Flags that provide information about the location of a predicate term + in the larger tree. + """ + + HAS_AND_SIBLINGS = enum.auto() + HAS_OR_SIBLINGS = enum.auto() + INVERTED = enum.auto() + + +class ColumnExpressionVisitor(Generic[_T]): + """A visitor interface for traversing a `ColumnExpression` tree. + + Notes + ----- + Unlike `Predicate`, the concrete column expression types need to be + public for various reasons, and hence the visitor interface uses them + directly in its arguments. + + This interface includes `Reversed` (which is part of the `OrderExpression` + union but not the `ColumnExpression` union) because it is simpler to have + just one visitor interface and disable support for it at runtime as + appropriate. + """ + + @abstractmethod + def visit_literal(self, expression: tree.ColumnLiteral) -> _T: + """Visit a column expression that wraps a literal value. + + Parameters + ---------- + expression : `tree.ColumnLiteral` + Expression to visit. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_dimension_key_reference(self, expression: tree.DimensionKeyReference) -> _T: + """Visit a column expression that represents a dimension column. + + Parameters + ---------- + expression : `tree.DimensionKeyReference` + Expression to visit. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_dimension_field_reference(self, expression: tree.DimensionFieldReference) -> _T: + """Visit a column expression that represents a dimension record field. + + Parameters + ---------- + expression : `tree.DimensionFieldReference` + Expression to visit. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_dataset_field_reference(self, expression: tree.DatasetFieldReference) -> _T: + """Visit a column expression that represents a dataset field. + + Parameters + ---------- + expression : `tree.DatasetFieldReference` + Expression to visit. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_unary_expression(self, expression: tree.UnaryExpression) -> _T: + """Visit a column expression that represents a unary operation. 
+ + Parameters + ---------- + expression : `tree.UnaryExpression` + Expression to visit. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_binary_expression(self, expression: tree.BinaryExpression) -> _T: + """Visit a column expression that wraps a binary operation. + + Parameters + ---------- + expression : `tree.BinaryExpression` + Expression to visit. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_reversed(self, expression: tree.Reversed) -> _T: + """Visit a column expression that switches sort order from ascending + to descending. + + Parameters + ---------- + expression : `tree.Reversed` + Expression to visit. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + +class PredicateVisitor(Generic[_A, _O, _L]): + """A visitor interface for traversing a `Predicate`. + + Notes + ----- + The concrete `PredicateLeaf` types are only semi-public (they appear in + the serialized form of a `Predicate`, but their types should not generally + be referenced directly outside of the module in which they are defined. + As a result, visiting these objects unpacks their attributes into the + visit method arguments. + """ + + @abstractmethod + def visit_comparison( + self, + a: tree.ColumnExpression, + operator: tree.ComparisonOperator, + b: tree.ColumnExpression, + flags: PredicateVisitFlags, + ) -> _L: + """Visit a binary comparison between column expressions. + + Parameters + ---------- + a : `tree.ColumnExpression` + First column expression in the comparison. + operator : `str` + Enumerated string representing the comparison operator to apply. + May be and of "==", "!=", "<", ">", "<=", ">=", or "overlaps". + b : `tree.ColumnExpression` + Second column expression in the comparison. + flags : `PredicateVisitFlags` + Information about where this leaf appears in the larger predicate + tree. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_is_null(self, operand: tree.ColumnExpression, flags: PredicateVisitFlags) -> _L: + """Visit a predicate leaf that tests whether a column expression is + NULL. + + Parameters + ---------- + operand : `tree.ColumnExpression` + Column expression to test. + flags : `PredicateVisitFlags` + Information about where this leaf appears in the larger predicate + tree. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_in_container( + self, + member: tree.ColumnExpression, + container: tuple[tree.ColumnExpression, ...], + flags: PredicateVisitFlags, + ) -> _L: + """Visit a predicate leaf that tests whether a column expression is + a member of a container. + + Parameters + ---------- + member : `tree.ColumnExpression` + Column expression that may be a member of the container. + container : `~collections.abc.Iterable` [ `tree.ColumnExpression` ] + Container of column expressions to test for membership in. + flags : `PredicateVisitFlags` + Information about where this leaf appears in the larger predicate + tree. + + Returns + ------- + result : `object` + Implementation-defined. 
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def visit_in_range(
+        self,
+        member: tree.ColumnExpression,
+        start: int,
+        stop: int | None,
+        step: int,
+        flags: PredicateVisitFlags,
+    ) -> _L:
+        """Visit a predicate leaf that tests whether a column expression is
+        a member of an integer range.
+
+        Parameters
+        ----------
+        member : `tree.ColumnExpression`
+            Column expression that may be a member of the range.
+        start : `int`, optional
+            Beginning of the range, inclusive.
+        stop : `int` or `None`, optional
+            End of the range, exclusive.
+        step : `int`, optional
+            Offset between values in the range.
+        flags : `PredicateVisitFlags`
+            Information about where this leaf appears in the larger predicate
+            tree.
+
+        Returns
+        -------
+        result : `object`
+            Implementation-defined.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def visit_in_query_tree(
+        self,
+        member: tree.ColumnExpression,
+        column: tree.ColumnExpression,
+        query_tree: tree.QueryTree,
+        flags: PredicateVisitFlags,
+    ) -> _L:
+        """Visit a predicate leaf that tests whether a column expression is
+        present in a single-column projection of a query tree.
+
+        Parameters
+        ----------
+        member : `tree.ColumnExpression`
+            Column expression that may be present in the query.
+        column : `tree.ColumnExpression`
+            Column to project from the query.
+        query_tree : `QueryTree`
+            Query tree to select from.
+        flags : `PredicateVisitFlags`
+            Information about where this leaf appears in the larger predicate
+            tree.
+
+        Returns
+        -------
+        result : `object`
+            Implementation-defined.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def apply_logical_not(self, original: tree.PredicateLeaf, result: _L, flags: PredicateVisitFlags) -> _L:
+        """Apply a logical NOT to the result of visiting an inverted predicate
+        leaf.
+
+        Parameters
+        ----------
+        original : `PredicateLeaf`
+            The original operand of the logical NOT operation.
+        result : `object`
+            Implementation-defined result of visiting the operand.
+        flags : `PredicateVisitFlags`
+            Information about where this leaf appears in the larger predicate
+            tree. Never has `PredicateVisitFlags.INVERTED` set.
+
+        Returns
+        -------
+        result : `object`
+            Implementation-defined.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def apply_logical_or(
+        self,
+        originals: tuple[tree.PredicateLeaf, ...],
+        results: tuple[_L, ...],
+        flags: PredicateVisitFlags,
+    ) -> _O:
+        """Apply a logical OR operation to the result of visiting a `tuple` of
+        predicate leaf objects.
+
+        Parameters
+        ----------
+        originals : `tuple` [ `PredicateLeaf`, ... ]
+            Original leaf objects in the logical OR.
+        results : `tuple` [ `object`, ... ]
+            Result of visiting the leaf objects.
+        flags : `PredicateVisitFlags`
+            Information about where this leaf appears in the larger predicate
+            tree. Never has `PredicateVisitFlags.INVERTED` or
+            `PredicateVisitFlags.HAS_OR_SIBLINGS` set.
+
+        Returns
+        -------
+        result : `object`
+            Implementation-defined.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def apply_logical_and(self, originals: tree.PredicateOperands, results: tuple[_O, ...]) -> _A:
+        """Apply a logical AND operation to the result of visiting a nested
+        `tuple` of predicate leaf objects.
+
+        Parameters
+        ----------
+        originals : `tuple` [ `tuple` [ `PredicateLeaf`, ... ], ... ]
+            Nested tuple of predicate leaf objects, with inner tuples
+            corresponding to groups that should be combined with logical OR.
+        results : `tuple` [ `object`, ... ]
+            Result of visiting the leaf objects.
+ + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @final + def _visit_logical_not(self, operand: tree.LogicalNotOperand, flags: PredicateVisitFlags) -> _L: + return self.apply_logical_not( + operand, operand.visit(self, flags | PredicateVisitFlags.INVERTED), flags + ) + + @final + def _visit_logical_or(self, operands: tuple[tree.PredicateLeaf, ...], flags: PredicateVisitFlags) -> _O: + nested_flags = flags + if len(operands) > 1: + nested_flags |= PredicateVisitFlags.HAS_OR_SIBLINGS + return self.apply_logical_or( + operands, tuple([operand.visit(self, nested_flags) for operand in operands]), flags + ) + + @final + def _visit_logical_and(self, operands: tree.PredicateOperands) -> _A: + if len(operands) > 1: + nested_flags = PredicateVisitFlags.HAS_AND_SIBLINGS + else: + nested_flags = PredicateVisitFlags(0) + return self.apply_logical_and( + operands, tuple([self._visit_logical_or(or_group, nested_flags) for or_group in operands]) + ) + + +class SimplePredicateVisitor( + PredicateVisitor[tree.Predicate | None, tree.Predicate | None, tree.Predicate | None] +): + """An intermediate base class for predicate visitor implementations that + either return `None` or a new `Predicate`. + + Notes + ----- + This class implements all leaf-node visitation methods to return `None`, + which is interpreted by the ``apply*`` method implementations as indicating + that the leaf is unmodified. Subclasses can thus override only certain + visitation methods and either return `None` if there is no result, or + return a replacement `Predicate` to construct a new tree. + """ + + def visit_comparison( + self, + a: tree.ColumnExpression, + operator: tree.ComparisonOperator, + b: tree.ColumnExpression, + flags: PredicateVisitFlags, + ) -> tree.Predicate | None: + # Docstring inherited. + return None + + def visit_is_null( + self, operand: tree.ColumnExpression, flags: PredicateVisitFlags + ) -> tree.Predicate | None: + # Docstring inherited. + return None + + def visit_in_container( + self, + member: tree.ColumnExpression, + container: tuple[tree.ColumnExpression, ...], + flags: PredicateVisitFlags, + ) -> tree.Predicate | None: + # Docstring inherited. + return None + + def visit_in_range( + self, + member: tree.ColumnExpression, + start: int, + stop: int | None, + step: int, + flags: PredicateVisitFlags, + ) -> tree.Predicate | None: + # Docstring inherited. + return None + + def visit_in_query_tree( + self, + member: tree.ColumnExpression, + column: tree.ColumnExpression, + query_tree: tree.QueryTree, + flags: PredicateVisitFlags, + ) -> tree.Predicate | None: + # Docstring inherited. + return None + + def apply_logical_not( + self, original: tree.PredicateLeaf, result: tree.Predicate | None, flags: PredicateVisitFlags + ) -> tree.Predicate | None: + # Docstring inherited. + if result is None: + return None + from . import tree + + return tree.Predicate._from_leaf(original).logical_not() + + def apply_logical_or( + self, + originals: tuple[tree.PredicateLeaf, ...], + results: tuple[tree.Predicate | None, ...], + flags: PredicateVisitFlags, + ) -> tree.Predicate | None: + # Docstring inherited. + if all(result is None for result in results): + return None + from . 
import tree + + return tree.Predicate.from_bool(False).logical_or( + *[ + tree.Predicate._from_leaf(original) if result is None else result + for original, result in zip(originals, results) + ] + ) + + def apply_logical_and( + self, + originals: tree.PredicateOperands, + results: tuple[tree.Predicate | None, ...], + ) -> tree.Predicate | None: + # Docstring inherited. + if all(result is None for result in results): + return None + from . import tree + + return tree.Predicate.from_bool(True).logical_and( + *[ + tree.Predicate._from_or_group(original) if result is None else result + for original, result in zip(originals, results) + ] + ) diff --git a/python/lsst/daf/butler/remote_butler/_remote_butler.py b/python/lsst/daf/butler/remote_butler/_remote_butler.py index ed8b8afe29..c6ada028ec 100644 --- a/python/lsst/daf/butler/remote_butler/_remote_butler.py +++ b/python/lsst/daf/butler/remote_butler/_remote_butler.py @@ -75,10 +75,9 @@ if TYPE_CHECKING: from .._file_dataset import FileDataset from .._limited_butler import LimitedButler - from .._query import Query from .._timespan import Timespan - from ..dimensions import DataId, DimensionGroup, DimensionRecord - from ..registry import CollectionArgType + from ..dimensions import DataId + from ..queries import Query from ..transfers import RepoExportContext @@ -585,55 +584,6 @@ def _query(self) -> AbstractContextManager[Query]: # Docstring inherited. raise NotImplementedError() - def _query_data_ids( - self, - dimensions: DimensionGroup | Iterable[str] | str, - *, - data_id: DataId | None = None, - where: str = "", - bind: Mapping[str, Any] | None = None, - expanded: bool = False, - order_by: Iterable[str] | str | None = None, - limit: int | None = None, - offset: int | None = None, - explain: bool = True, - **kwargs: Any, - ) -> list[DataCoordinate]: - # Docstring inherited. - raise NotImplementedError() - - def _query_datasets( - self, - dataset_type: Any, - collections: CollectionArgType | None = None, - *, - find_first: bool = True, - data_id: DataId | None = None, - where: str = "", - bind: Mapping[str, Any] | None = None, - expanded: bool = False, - explain: bool = True, - **kwargs: Any, - ) -> list[DatasetRef]: - # Docstring inherited. - raise NotImplementedError() - - def _query_dimension_records( - self, - element: str, - *, - data_id: DataId | None = None, - where: str = "", - bind: Mapping[str, Any] | None = None, - order_by: Iterable[str] | str | None = None, - limit: int | None = None, - offset: int | None = None, - explain: bool = True, - **kwargs: Any, - ) -> list[DimensionRecord]: - # Docstring inherited. 
- raise NotImplementedError() - def pruneDatasets( self, refs: Iterable[DatasetRef], diff --git a/python/lsst/daf/butler/tests/hybrid_butler.py b/python/lsst/daf/butler/tests/hybrid_butler.py index 55980c806c..22367b725e 100644 --- a/python/lsst/daf/butler/tests/hybrid_butler.py +++ b/python/lsst/daf/butler/tests/hybrid_butler.py @@ -40,13 +40,13 @@ from .._deferredDatasetHandle import DeferredDatasetHandle from .._file_dataset import FileDataset from .._limited_butler import LimitedButler -from .._query import Query from .._storage_class import StorageClass from .._timespan import Timespan from ..datastore import DatasetRefURIs from ..dimensions import DataCoordinate, DataId, DimensionGroup, DimensionRecord, DimensionUniverse from ..direct_butler import DirectButler -from ..registry import CollectionArgType, Registry +from ..queries import Query +from ..registry import Registry from ..remote_butler import RemoteButler from ..transfers import RepoExportContext from .hybrid_butler_registry import HybridButlerRegistry @@ -356,7 +356,7 @@ def _query_data_ids( def _query_datasets( self, dataset_type: Any, - collections: CollectionArgType | None = None, + collections: str | Iterable[str] | None = None, *, find_first: bool = True, data_id: DataId | None = None, diff --git a/tests/test_query_interface.py b/tests/test_query_interface.py new file mode 100644 index 0000000000..ab83da94e6 --- /dev/null +++ b/tests/test_query_interface.py @@ -0,0 +1,2004 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +"""Tests for the public Butler._query interface and the Pydantic models that +back it using a mock column-expression visitor and a mock QueryDriver +implementation. + +These tests are entirely independent of which kind of butler or database +backend we're using. + +This is a very large test file because a lot of tests make use of those mocks, +but they're not so generally useful that I think they're worth putting in the +library proper. 
+""" + +from __future__ import annotations + +import dataclasses +import itertools +import unittest +import uuid +from collections.abc import Iterable, Iterator, Mapping, Sequence, Set +from typing import Any, cast + +import astropy.time +from lsst.daf.butler import ( + CollectionType, + DataCoordinate, + DataIdValue, + DatasetRef, + DatasetType, + DimensionGroup, + DimensionRecord, + DimensionRecordSet, + DimensionUniverse, + MissingDatasetTypeError, + NamedValueSet, + NoDefaultCollectionError, + Timespan, +) +from lsst.daf.butler.queries import ( + DataCoordinateQueryResults, + DatasetQueryResults, + DimensionRecordQueryResults, + Query, + SingleTypeDatasetQueryResults, +) +from lsst.daf.butler.queries import driver as qd +from lsst.daf.butler.queries import result_specs as qrs +from lsst.daf.butler.queries import tree as qt +from lsst.daf.butler.queries.expression_factory import ExpressionFactory +from lsst.daf.butler.queries.tree._column_expression import UnaryExpression +from lsst.daf.butler.queries.tree._predicate import PredicateLeaf, PredicateOperands +from lsst.daf.butler.queries.visitors import ColumnExpressionVisitor, PredicateVisitFlags, PredicateVisitor +from lsst.daf.butler.registry import CollectionSummary, DatasetTypeError +from lsst.daf.butler.registry.interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord +from lsst.sphgeom import DISJOINT, Mq3cPixelization + + +class _TestVisitor(PredicateVisitor[bool, bool, bool], ColumnExpressionVisitor[Any]): + """Test visitor for column expressions. + + This visitor evaluates column expressions using regular Python logic. + + Parameters + ---------- + dimension_keys : `~collections.abc.Mapping`, optional + Mapping from dimension name to the value it should be assigned by the + visitor. + dimension_fields : `~collections.abc.Mapping`, optional + Mapping from ``(dimension element name, field)`` tuple to the value it + should be assigned by the visitor. + dataset_fields : `~collections.abc.Mapping`, optional + Mapping from ``(dataset type name, field)`` tuple to the value it + should be assigned by the visitor. + query_tree_items : `~collections.abc.Set`, optional + Set that should be used as the right-hand side of element-in-query + predicates. 
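+
+    Examples
+    --------
+    A minimal sketch of how the tests use this visitor, assuming
+    ``predicate`` is a `Predicate` that only references the ``visit``
+    dimension key::
+
+        visitor = _TestVisitor(dimension_keys={"visit": 7})
+        result = predicate.visit(visitor)  # evaluates to a bool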
+ """ + + def __init__( + self, + dimension_keys: Mapping[str, Any] | None = None, + dimension_fields: Mapping[tuple[str, str], Any] | None = None, + dataset_fields: Mapping[tuple[str, str], Any] | None = None, + query_tree_items: Set[Any] = frozenset(), + ): + self.dimension_keys = dimension_keys or {} + self.dimension_fields = dimension_fields or {} + self.dataset_fields = dataset_fields or {} + self.query_tree_items = query_tree_items + + def visit_binary_expression(self, expression: qt.BinaryExpression) -> Any: + match expression.operator: + case "+": + return expression.a.visit(self) + expression.b.visit(self) + case "-": + return expression.a.visit(self) - expression.b.visit(self) + case "*": + return expression.a.visit(self) * expression.b.visit(self) + case "/": + match expression.column_type: + case "int": + return expression.a.visit(self) // expression.b.visit(self) + case "float": + return expression.a.visit(self) / expression.b.visit(self) + case "%": + return expression.a.visit(self) % expression.b.visit(self) + + def visit_comparison( + self, + a: qt.ColumnExpression, + operator: qt.ComparisonOperator, + b: qt.ColumnExpression, + flags: PredicateVisitFlags, + ) -> bool: + match operator: + case "==": + return a.visit(self) == b.visit(self) + case "!=": + return a.visit(self) != b.visit(self) + case "<": + return a.visit(self) < b.visit(self) + case ">": + return a.visit(self) > b.visit(self) + case "<=": + return a.visit(self) <= b.visit(self) + case ">=": + return a.visit(self) >= b.visit(self) + case "overlaps": + return not (a.visit(self).relate(b.visit(self)) & DISJOINT) + + def visit_dataset_field_reference(self, expression: qt.DatasetFieldReference) -> Any: + return self.dataset_fields[expression.dataset_type, expression.field] + + def visit_dimension_field_reference(self, expression: qt.DimensionFieldReference) -> Any: + return self.dimension_fields[expression.element.name, expression.field] + + def visit_dimension_key_reference(self, expression: qt.DimensionKeyReference) -> Any: + return self.dimension_keys[expression.dimension.name] + + def visit_in_container( + self, + member: qt.ColumnExpression, + container: tuple[qt.ColumnExpression, ...], + flags: PredicateVisitFlags, + ) -> bool: + return member.visit(self) in [item.visit(self) for item in container] + + def visit_in_range( + self, member: qt.ColumnExpression, start: int, stop: int | None, step: int, flags: PredicateVisitFlags + ) -> bool: + return member.visit(self) in range(start, stop, step) + + def visit_in_query_tree( + self, + member: qt.ColumnExpression, + column: qt.ColumnExpression, + query_tree: qt.QueryTree, + flags: PredicateVisitFlags, + ) -> bool: + return member.visit(self) in self.query_tree_items + + def visit_is_null(self, operand: qt.ColumnExpression, flags: PredicateVisitFlags) -> bool: + return operand.visit(self) is None + + def visit_literal(self, expression: qt.ColumnLiteral) -> Any: + return expression.get_literal_value() + + def visit_reversed(self, expression: qt.Reversed) -> Any: + return _TestReversed(expression.operand.visit(self)) + + def visit_unary_expression(self, expression: UnaryExpression) -> Any: + match expression.operator: + case "-": + return -expression.operand.visit(self) + case "begin_of": + return expression.operand.visit(self).begin + case "end_of": + return expression.operand.visit(self).end + + def apply_logical_and(self, originals: PredicateOperands, results: tuple[bool, ...]) -> bool: + return all(results) + + def apply_logical_not(self, original: PredicateLeaf, 
result: bool, flags: PredicateVisitFlags) -> bool:
+        return not result
+
+    def apply_logical_or(
+        self, originals: tuple[PredicateLeaf, ...], results: tuple[bool, ...], flags: PredicateVisitFlags
+    ) -> bool:
+        return any(results)
+
+
+@dataclasses.dataclass
+class _TestReversed:
+    """Struct used by _TestVisitor to mark an expression as reversed in sort
+    order.
+    """
+
+    operand: Any
+
+
+class _TestQueryExecution(BaseException):
+    """Exception raised by _TestQueryDriver.execute to communicate its args
+    back to the caller.
+    """
+
+    def __init__(self, result_spec: qrs.ResultSpec, tree: qt.QueryTree, driver: _TestQueryDriver) -> None:
+        self.result_spec = result_spec
+        self.tree = tree
+        self.driver = driver
+
+
+class _TestQueryCount(BaseException):
+    """Exception raised by _TestQueryDriver.count to communicate its args
+    back to the caller.
+    """
+
+    def __init__(
+        self,
+        result_spec: qrs.ResultSpec,
+        tree: qt.QueryTree,
+        driver: _TestQueryDriver,
+        exact: bool,
+        discard: bool,
+    ) -> None:
+        self.result_spec = result_spec
+        self.tree = tree
+        self.driver = driver
+        self.exact = exact
+        self.discard = discard
+
+
+class _TestQueryAny(BaseException):
+    """Exception raised by _TestQueryDriver.any to communicate its args
+    back to the caller.
+    """
+
+    def __init__(
+        self,
+        tree: qt.QueryTree,
+        driver: _TestQueryDriver,
+        exact: bool,
+        execute: bool,
+    ) -> None:
+        self.tree = tree
+        self.driver = driver
+        self.exact = exact
+        self.execute = execute
+
+
+class _TestQueryExplainNoResults(BaseException):
+    """Exception raised by _TestQueryDriver.explain_no_results to communicate
+    its args back to the caller.
+    """
+
+    def __init__(
+        self,
+        tree: qt.QueryTree,
+        driver: _TestQueryDriver,
+        execute: bool,
+    ) -> None:
+        self.tree = tree
+        self.driver = driver
+        self.execute = execute
+
+
+class _TestQueryDriver(qd.QueryDriver):
+    """Mock implementation of `QueryDriver` that mostly raises exceptions that
+    communicate the arguments its methods were called with.
+
+    Parameters
+    ----------
+    default_collections : `tuple` [ `str`, ... ], optional
+        Default collections the query or parent butler is imagined to have
+        been constructed with.
+    collection_info : `~collections.abc.Mapping`, optional
+        Mapping from collection name to its record and summary, simulating the
+        collections present in the data repository.
+    dataset_types : `~collections.abc.Mapping`, optional
+        Mapping from dataset type name to its definition, simulating the
+        dataset types registered in the data repository.
+    result_rows : `tuple` [ `~collections.abc.Iterable`, ... ], optional
+        A tuple of iterables of arbitrary type to use as result rows any time
+        `execute` is called, with each nested iterable considered a separate
+        page. The result type is not checked for consistency with the result
+        spec. If this is not provided, `execute` will instead raise
+        `_TestQueryExecution`, and `fetch_next_page` will not do anything
+        useful.
+    """
+
+    def __init__(
+        self,
+        default_collections: tuple[str, ...] | None = None,
+        collection_info: Mapping[str, tuple[CollectionRecord, CollectionSummary]] | None = None,
+        dataset_types: Mapping[str, DatasetType] | None = None,
+        result_rows: tuple[Iterable[Any], ...] | None = None,
+    ) -> None:
+        self._universe = DimensionUniverse()
+        # Mapping of the arguments passed to materialize, keyed by the UUID
+        # that each call returned.
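+        # For example, a driver constructed with result_rows=([r1, r2], [r3])
+        # (placeholder rows) returns a first page of [r1, r2] from execute()
+        # with a non-None next_key, then [r3] from fetch_next_page() with
+        # next_key=None; without result_rows, execute() raises
+        # _TestQueryExecution so tests can inspect the tree and result spec
+        # that would have been sent to a real driver.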
+ self.materializations: dict[ + qd.MaterializationKey, tuple[qt.QueryTree, DimensionGroup, frozenset[str]] + ] = {} + # Mapping of the arguments passed to upload_data_coordinates, keyed by + # the UUID that that each call returned. + self.data_coordinate_uploads: dict[ + qd.DataCoordinateUploadKey, tuple[DimensionGroup, list[tuple[DataIdValue, ...]]] + ] = {} + self._default_collections = default_collections + self._collection_info = collection_info or {} + self._dataset_types = dataset_types or {} + self._executions: list[tuple[qrs.ResultSpec, qt.QueryTree]] = [] + self._result_rows = result_rows + self._result_iters: dict[qd.PageKey, tuple[Iterable[Any], Iterator[Iterable[Any]]]] = {} + + @property + def universe(self) -> DimensionUniverse: + return self._universe + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any, **kwargs: Any) -> None: + pass + + def execute(self, result_spec: qrs.ResultSpec, tree: qt.QueryTree) -> qd.ResultPage: + if self._result_rows is not None: + iterator = iter(self._result_rows) + current_rows = next(iterator, ()) + return self._make_next_page(result_spec, current_rows, iterator) + raise _TestQueryExecution(result_spec, tree, self) + + def fetch_next_page(self, result_spec: qrs.ResultSpec, key: qd.PageKey) -> qd.ResultPage: + if self._result_rows is not None: + return self._make_next_page(result_spec, *self._result_iters.pop(key)) + raise AssertionError("Test query driver not initialized for actual results.") + + def _make_next_page( + self, result_spec: qrs.ResultSpec, current_rows: Iterable[Any], iterator: Iterator[Iterable[Any]] + ) -> qd.ResultPage: + next_rows = list(next(iterator, ())) + if not next_rows: + next_key = None + else: + next_key = uuid.uuid4() + self._result_iters[next_key] = (next_rows, iterator) + match result_spec: + case qrs.DataCoordinateResultSpec(): + return qd.DataCoordinateResultPage(spec=result_spec, next_key=next_key, rows=current_rows) + case qrs.DimensionRecordResultSpec(): + return qd.DimensionRecordResultPage(spec=result_spec, next_key=next_key, rows=current_rows) + case qrs.DatasetRefResultSpec(): + return qd.DatasetRefResultPage(spec=result_spec, next_key=next_key, rows=current_rows) + case _: + raise NotImplementedError("Other query types not yet supported.") + + def materialize( + self, + tree: qt.QueryTree, + dimensions: DimensionGroup, + datasets: frozenset[str], + ) -> qd.MaterializationKey: + key = uuid.uuid4() + self.materializations[key] = (tree, dimensions, datasets) + return key + + def upload_data_coordinates( + self, dimensions: DimensionGroup, rows: Iterable[tuple[DataIdValue, ...]] + ) -> qd.DataCoordinateUploadKey: + key = uuid.uuid4() + self.data_coordinate_uploads[key] = (dimensions, frozenset(rows)) + return key + + def count( + self, + tree: qt.QueryTree, + result_spec: qrs.ResultSpec, + *, + exact: bool, + discard: bool, + ) -> int: + raise _TestQueryCount(result_spec, tree, self, exact, discard) + + def any(self, tree: qt.QueryTree, *, execute: bool, exact: bool) -> bool: + raise _TestQueryAny(tree, self, exact, execute) + + def explain_no_results(self, tree: qt.QueryTree, execute: bool) -> Iterable[str]: + raise _TestQueryExplainNoResults(tree, self, execute) + + def get_default_collections(self) -> tuple[str, ...]: + if self._default_collections is None: + raise NoDefaultCollectionError() + return self._default_collections + + def resolve_collection_path( + self, collections: Sequence[str], _done: set[str] | None = None + ) -> list[tuple[CollectionRecord, CollectionSummary]]: + if 
_done is None: + _done = set() + result: list[tuple[CollectionRecord, CollectionSummary]] = [] + for name in collections: + if name in _done: + continue + _done.add(name) + record, summary = self._collection_info[name] + if record.type is CollectionType.CHAINED: + result.extend( + self.resolve_collection_path(cast(ChainedCollectionRecord, record).children, _done=_done) + ) + else: + result.append((record, summary)) + return result + + def get_dataset_type(self, name: str) -> DatasetType: + try: + return self._dataset_types[name] + except KeyError: + raise MissingDatasetTypeError(name) + + +class ColumnExpressionsTestCase(unittest.TestCase): + """Tests for column expression objects in lsst.daf.butler.queries.tree.""" + + def setUp(self) -> None: + self.universe = DimensionUniverse() + self.x = ExpressionFactory(self.universe) + + def query(self, **kwargs: Any) -> Query: + """Make an initial Query object with the given kwargs used to + initialize the _TestQueryDriver. + """ + return Query(_TestQueryDriver(**kwargs), qt.make_identity_query_tree(self.universe)) + + def test_int_literals(self) -> None: + expr = self.x.unwrap(self.x.literal(5)) + self.assertEqual(expr.value, 5) + self.assertEqual(expr.get_literal_value(), 5) + self.assertEqual(expr.expression_type, "int") + self.assertEqual(expr.column_type, "int") + self.assertEqual(str(expr), "5") + self.assertTrue(expr.is_literal) + columns = qt.ColumnSet(self.universe.empty.as_group()) + expr.gather_required_columns(columns) + self.assertFalse(columns) + self.assertEqual(expr.visit(_TestVisitor()), 5) + + def test_string_literals(self) -> None: + expr = self.x.unwrap(self.x.literal("five")) + self.assertEqual(expr.value, "five") + self.assertEqual(expr.get_literal_value(), "five") + self.assertEqual(expr.expression_type, "string") + self.assertEqual(expr.column_type, "string") + self.assertEqual(str(expr), "'five'") + self.assertTrue(expr.is_literal) + columns = qt.ColumnSet(self.universe.empty.as_group()) + expr.gather_required_columns(columns) + self.assertFalse(columns) + self.assertEqual(expr.visit(_TestVisitor()), "five") + + def test_float_literals(self) -> None: + expr = self.x.unwrap(self.x.literal(0.5)) + self.assertEqual(expr.value, 0.5) + self.assertEqual(expr.get_literal_value(), 0.5) + self.assertEqual(expr.expression_type, "float") + self.assertEqual(expr.column_type, "float") + self.assertEqual(str(expr), "0.5") + self.assertTrue(expr.is_literal) + columns = qt.ColumnSet(self.universe.empty.as_group()) + expr.gather_required_columns(columns) + self.assertFalse(columns) + self.assertEqual(expr.visit(_TestVisitor()), 0.5) + + def test_hash_literals(self) -> None: + expr = self.x.unwrap(self.x.literal(b"eleven")) + self.assertEqual(expr.value, b"eleven") + self.assertEqual(expr.get_literal_value(), b"eleven") + self.assertEqual(expr.expression_type, "hash") + self.assertEqual(expr.column_type, "hash") + self.assertEqual(str(expr), "(bytes)") + self.assertTrue(expr.is_literal) + columns = qt.ColumnSet(self.universe.empty.as_group()) + expr.gather_required_columns(columns) + self.assertFalse(columns) + self.assertEqual(expr.visit(_TestVisitor()), b"eleven") + + def test_uuid_literals(self) -> None: + value = uuid.uuid4() + expr = self.x.unwrap(self.x.literal(value)) + self.assertEqual(expr.value, value) + self.assertEqual(expr.get_literal_value(), value) + self.assertEqual(expr.expression_type, "uuid") + self.assertEqual(expr.column_type, "uuid") + self.assertEqual(str(expr), str(value)) + self.assertTrue(expr.is_literal) + columns 
= qt.ColumnSet(self.universe.empty.as_group()) + expr.gather_required_columns(columns) + self.assertFalse(columns) + self.assertEqual(expr.visit(_TestVisitor()), value) + + def test_datetime_literals(self) -> None: + value = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai") + expr = self.x.unwrap(self.x.literal(value)) + self.assertEqual(expr.value, value) + self.assertEqual(expr.get_literal_value(), value) + self.assertEqual(expr.expression_type, "datetime") + self.assertEqual(expr.column_type, "datetime") + self.assertEqual(str(expr), "2020-01-01T00:00:00") + self.assertTrue(expr.is_literal) + columns = qt.ColumnSet(self.universe.empty.as_group()) + expr.gather_required_columns(columns) + self.assertFalse(columns) + self.assertEqual(expr.visit(_TestVisitor()), value) + + def test_timespan_literals(self) -> None: + begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai") + end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai") + value = Timespan(begin, end) + expr = self.x.unwrap(self.x.literal(value)) + self.assertEqual(expr.value, value) + self.assertEqual(expr.get_literal_value(), value) + self.assertEqual(expr.expression_type, "timespan") + self.assertEqual(expr.column_type, "timespan") + self.assertEqual(str(expr), "[2020-01-01T00:00:00, 2020-01-01T00:01:00)") + self.assertTrue(expr.is_literal) + columns = qt.ColumnSet(self.universe.empty.as_group()) + expr.gather_required_columns(columns) + self.assertFalse(columns) + self.assertEqual(expr.visit(_TestVisitor()), value) + + def test_region_literals(self) -> None: + pixelization = Mq3cPixelization(10) + value = pixelization.quad(12058870) + expr = self.x.unwrap(self.x.literal(value)) + self.assertEqual(expr.value, value) + self.assertEqual(expr.get_literal_value(), value) + self.assertEqual(expr.expression_type, "region") + self.assertEqual(expr.column_type, "region") + self.assertEqual(str(expr), "(region)") + self.assertTrue(expr.is_literal) + columns = qt.ColumnSet(self.universe.empty.as_group()) + expr.gather_required_columns(columns) + self.assertFalse(columns) + self.assertEqual(expr.visit(_TestVisitor()), value) + + def test_invalid_literal(self) -> None: + with self.assertRaisesRegex(TypeError, "Invalid type 'complex' of value 5j for column literal."): + self.x.literal(5j) + + def test_dimension_key_reference(self) -> None: + expr = self.x.unwrap(self.x.detector) + self.assertIsNone(expr.get_literal_value()) + self.assertEqual(expr.expression_type, "dimension_key") + self.assertEqual(expr.column_type, "int") + self.assertEqual(str(expr), "detector") + self.assertFalse(expr.is_literal) + columns = qt.ColumnSet(self.universe.empty.as_group()) + expr.gather_required_columns(columns) + self.assertEqual(columns.dimensions, self.universe.conform(["detector"])) + self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"detector": 3})), 3) + + def test_dimension_field_reference(self) -> None: + expr = self.x.unwrap(self.x.detector.purpose) + self.assertIsNone(expr.get_literal_value()) + self.assertEqual(expr.expression_type, "dimension_field") + self.assertEqual(expr.column_type, "string") + self.assertEqual(str(expr), "detector.purpose") + self.assertFalse(expr.is_literal) + columns = qt.ColumnSet(self.universe.empty.as_group()) + expr.gather_required_columns(columns) + self.assertEqual(columns.dimensions, self.universe.conform(["detector"])) + self.assertEqual(columns.dimension_fields["detector"], {"purpose"}) + with self.assertRaises(qt.InvalidQueryError): + 
qt.DimensionFieldReference(element=self.universe.dimensions["detector"], field="region") + self.assertEqual( + expr.visit(_TestVisitor(dimension_fields={("detector", "purpose"): "science"})), "science" + ) + + def test_dataset_field_reference(self) -> None: + expr = self.x.unwrap(self.x["raw"].ingest_date) + self.assertIsNone(expr.get_literal_value()) + self.assertEqual(expr.expression_type, "dataset_field") + self.assertEqual(str(expr), "raw.ingest_date") + self.assertFalse(expr.is_literal) + columns = qt.ColumnSet(self.universe.empty.as_group()) + expr.gather_required_columns(columns) + self.assertEqual(columns.dimensions, self.universe.empty.as_group()) + self.assertEqual(columns.dataset_fields["raw"], {"ingest_date"}) + self.assertEqual(qt.DatasetFieldReference(dataset_type="raw", field="dataset_id").column_type, "uuid") + self.assertEqual( + qt.DatasetFieldReference(dataset_type="raw", field="collection").column_type, "string" + ) + self.assertEqual(qt.DatasetFieldReference(dataset_type="raw", field="run").column_type, "string") + self.assertEqual( + qt.DatasetFieldReference(dataset_type="raw", field="ingest_date").column_type, "datetime" + ) + self.assertEqual( + qt.DatasetFieldReference(dataset_type="raw", field="timespan").column_type, "timespan" + ) + value = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai") + self.assertEqual(expr.visit(_TestVisitor(dataset_fields={("raw", "ingest_date"): value})), value) + + def test_unary_negation(self) -> None: + expr = self.x.unwrap(-self.x.visit.exposure_time) + self.assertIsNone(expr.get_literal_value()) + self.assertEqual(expr.expression_type, "unary") + self.assertEqual(expr.column_type, "float") + self.assertEqual(str(expr), "-visit.exposure_time") + self.assertFalse(expr.is_literal) + columns = qt.ColumnSet(self.universe.empty.as_group()) + expr.gather_required_columns(columns) + self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) + self.assertEqual(columns.dimension_fields["visit"], {"exposure_time"}) + self.assertEqual(expr.visit(_TestVisitor(dimension_fields={("visit", "exposure_time"): 2.0})), -2.0) + with self.assertRaises(qt.InvalidQueryError): + qt.UnaryExpression( + operand=qt.DimensionFieldReference( + element=self.universe.dimensions["detector"], field="purpose" + ), + operator="-", + ) + + def test_unary_timespan_begin(self) -> None: + expr = self.x.unwrap(self.x.visit.timespan.begin) + self.assertIsNone(expr.get_literal_value()) + self.assertEqual(expr.expression_type, "unary") + self.assertEqual(expr.column_type, "datetime") + self.assertEqual(str(expr), "visit.timespan.begin") + self.assertFalse(expr.is_literal) + columns = qt.ColumnSet(self.universe.empty.as_group()) + expr.gather_required_columns(columns) + self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) + self.assertEqual(columns.dimension_fields["visit"], {"timespan"}) + begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai") + end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai") + value = Timespan(begin, end) + self.assertEqual( + expr.visit(_TestVisitor(dimension_fields={("visit", "timespan"): value})), value.begin + ) + with self.assertRaises(qt.InvalidQueryError): + qt.UnaryExpression( + operand=qt.DimensionFieldReference( + element=self.universe.dimensions["detector"], field="purpose" + ), + operator="begin_of", + ) + + def test_unary_timespan_end(self) -> None: + expr = self.x.unwrap(self.x.visit.timespan.end) + self.assertIsNone(expr.get_literal_value()) + 
self.assertEqual(expr.expression_type, "unary") + self.assertEqual(expr.column_type, "datetime") + self.assertEqual(str(expr), "visit.timespan.end") + self.assertFalse(expr.is_literal) + columns = qt.ColumnSet(self.universe.empty.as_group()) + expr.gather_required_columns(columns) + self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) + self.assertEqual(columns.dimension_fields["visit"], {"timespan"}) + begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai") + end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai") + value = Timespan(begin, end) + self.assertEqual(expr.visit(_TestVisitor(dimension_fields={("visit", "timespan"): value})), value.end) + with self.assertRaises(qt.InvalidQueryError): + qt.UnaryExpression( + operand=qt.DimensionFieldReference( + element=self.universe.dimensions["detector"], field="purpose" + ), + operator="end_of", + ) + + def test_binary_expression_float(self) -> None: + for proxy, string, value in [ + (self.x.visit.exposure_time + 15.0, "visit.exposure_time + 15.0", 45.0), + (self.x.visit.exposure_time - 10.0, "visit.exposure_time - 10.0", 20.0), + (self.x.visit.exposure_time * 6.0, "visit.exposure_time * 6.0", 180.0), + (self.x.visit.exposure_time / 30.0, "visit.exposure_time / 30.0", 1.0), + (15.0 + -self.x.visit.exposure_time, "15.0 + -visit.exposure_time", -15.0), + (10.0 - -self.x.visit.exposure_time, "10.0 - -visit.exposure_time", 40.0), + (6.0 * -self.x.visit.exposure_time, "6.0 * -visit.exposure_time", -180.0), + (30.0 / -self.x.visit.exposure_time, "30.0 / -visit.exposure_time", -1.0), + ((self.x.visit.exposure_time + 15.0) * 6.0, "(visit.exposure_time + 15.0) * 6.0", 270.0), + ((self.x.visit.exposure_time + 15.0) + 45.0, "visit.exposure_time + 15.0 + 45.0", 90.0), + ((self.x.visit.exposure_time + 15.0) / 5.0, "(visit.exposure_time + 15.0) / 5.0", 9.0), + # We don't need the parentheses we generate in the next one, but + # they're not a problem either. 
+ ((self.x.visit.exposure_time + 15.0) - 60.0, "(visit.exposure_time + 15.0) - 60.0", -15.0), + (6.0 * (-self.x.visit.exposure_time - 15.0), "6.0 * (-visit.exposure_time - 15.0)", -270.0), + (60.0 + (-self.x.visit.exposure_time - 15.0), "60.0 + -visit.exposure_time - 15.0", 15.0), + (90.0 / (-self.x.visit.exposure_time - 15.0), "90.0 / (-visit.exposure_time - 15.0)", -2.0), + (60.0 - (-self.x.visit.exposure_time - 15.0), "60.0 - (-visit.exposure_time - 15.0)", 105.0), + ]: + with self.subTest(string=string): + expr = self.x.unwrap(proxy) + self.assertIsNone(expr.get_literal_value()) + self.assertEqual(expr.expression_type, "binary") + self.assertEqual(expr.column_type, "float") + self.assertEqual(str(expr), string) + self.assertFalse(expr.is_literal) + columns = qt.ColumnSet(self.universe.empty.as_group()) + expr.gather_required_columns(columns) + self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) + self.assertEqual(columns.dimension_fields["visit"], {"exposure_time"}) + self.assertEqual( + expr.visit(_TestVisitor(dimension_fields={("visit", "exposure_time"): 30.0})), value + ) + + def test_binary_modulus(self) -> None: + for proxy, string, value in [ + (self.x.visit.id % 2, "visit % 2", 1), + (52 % self.x.visit, "52 % visit", 2), + ]: + with self.subTest(string=string): + expr = self.x.unwrap(proxy) + self.assertIsNone(expr.get_literal_value()) + self.assertEqual(expr.expression_type, "binary") + self.assertEqual(expr.column_type, "int") + self.assertEqual(str(expr), string) + self.assertFalse(expr.is_literal) + columns = qt.ColumnSet(self.universe.empty.as_group()) + expr.gather_required_columns(columns) + self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) + self.assertFalse(columns.dimension_fields["visit"]) + self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"visit": 5})), value) + + def test_binary_expression_validation(self) -> None: + with self.assertRaises(qt.InvalidQueryError): + # No arithmetic operators on strings (we do not interpret + as + # concatenation). + self.x.instrument + "suffix" + with self.assertRaises(qt.InvalidQueryError): + # Mixed types are not supported, even when they both support the + # operator. + self.x.visit.exposure_time + self.x.detector + with self.assertRaises(qt.InvalidQueryError): + # No modulus for floats. 
+ self.x.visit.exposure_time % 5.0 + + def test_reversed(self) -> None: + expr = self.x.detector.desc + self.assertIsNone(expr.get_literal_value()) + self.assertEqual(expr.expression_type, "reversed") + self.assertEqual(expr.column_type, "int") + self.assertEqual(str(expr), "detector DESC") + self.assertFalse(expr.is_literal) + columns = qt.ColumnSet(self.universe.empty.as_group()) + expr.gather_required_columns(columns) + self.assertEqual(columns.dimensions, self.universe.conform(["detector"])) + self.assertFalse(columns.dimension_fields["detector"]) + self.assertEqual(expr.visit(_TestVisitor(dimension_keys={"detector": 5})), _TestReversed(5)) + + def test_trivial_predicate(self) -> None: + """Test logical operations on trivial True/False predicates.""" + yes = qt.Predicate.from_bool(True) + no = qt.Predicate.from_bool(False) + maybe: qt.Predicate = self.x.detector == 5 + for predicate in [ + yes, + yes.logical_or(no), + no.logical_or(yes), + yes.logical_and(yes), + no.logical_not(), + yes.logical_or(maybe), + maybe.logical_or(yes), + ]: + self.assertEqual(predicate.column_type, "bool") + self.assertEqual(str(predicate), "True") + self.assertTrue(predicate.visit(_TestVisitor())) + self.assertEqual(predicate.operands, ()) + for predicate in [ + no, + yes.logical_and(no), + no.logical_and(yes), + no.logical_or(no), + yes.logical_not(), + no.logical_and(maybe), + maybe.logical_and(no), + ]: + self.assertEqual(predicate.column_type, "bool") + self.assertEqual(str(predicate), "False") + self.assertFalse(predicate.visit(_TestVisitor())) + self.assertEqual(predicate.operands, ((),)) + for predicate in [ + maybe, + yes.logical_and(maybe), + no.logical_or(maybe), + maybe.logical_not().logical_not(), + ]: + self.assertEqual(predicate.column_type, "bool") + self.assertEqual(str(predicate), "detector == 5") + self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"detector": 5}))) + self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"detector": 4}))) + self.assertEqual(len(predicate.operands), 1) + self.assertEqual(len(predicate.operands[0]), 1) + self.assertIs(predicate.operands[0][0], maybe.operands[0][0]) + + def test_comparison(self) -> None: + predicate: qt.Predicate + string: str + value: bool + for detector in (4, 5, 6): + for predicate, string, value in [ + (self.x.detector == 5, "detector == 5", detector == 5), + (self.x.detector != 5, "detector != 5", detector != 5), + (self.x.detector < 5, "detector < 5", detector < 5), + (self.x.detector > 5, "detector > 5", detector > 5), + (self.x.detector <= 5, "detector <= 5", detector <= 5), + (self.x.detector >= 5, "detector >= 5", detector >= 5), + (self.x.detector == 5, "detector == 5", detector == 5), + (self.x.detector != 5, "detector != 5", detector != 5), + (self.x.detector < 5, "detector < 5", detector < 5), + (self.x.detector > 5, "detector > 5", detector > 5), + (self.x.detector <= 5, "detector <= 5", detector <= 5), + (self.x.detector >= 5, "detector >= 5", detector >= 5), + ]: + with self.subTest(string=string, detector=detector): + self.assertEqual(predicate.column_type, "bool") + self.assertEqual(str(predicate), string) + columns = qt.ColumnSet(self.universe.empty.as_group()) + predicate.gather_required_columns(columns) + self.assertEqual(columns.dimensions, self.universe.conform(["detector"])) + self.assertFalse(columns.dimension_fields["detector"]) + self.assertEqual( + predicate.visit(_TestVisitor(dimension_keys={"detector": detector})), value + ) + inverted = predicate.logical_not() + 
self.assertEqual(inverted.column_type, "bool") + self.assertEqual(str(inverted), f"NOT {string}") + self.assertEqual( + inverted.visit(_TestVisitor(dimension_keys={"detector": detector})), not value + ) + columns = qt.ColumnSet(self.universe.empty.as_group()) + inverted.gather_required_columns(columns) + self.assertEqual(columns.dimensions, self.universe.conform(["detector"])) + self.assertFalse(columns.dimension_fields["detector"]) + + def test_overlap_comparison(self) -> None: + pixelization = Mq3cPixelization(10) + region1 = pixelization.quad(12058870) + predicate = self.x.visit.region.overlaps(region1) + self.assertEqual(predicate.column_type, "bool") + self.assertEqual(str(predicate), "visit.region OVERLAPS (region)") + columns = qt.ColumnSet(self.universe.empty.as_group()) + predicate.gather_required_columns(columns) + self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) + self.assertEqual(columns.dimension_fields["visit"], {"region"}) + region2 = pixelization.quad(12058857) + self.assertFalse(predicate.visit(_TestVisitor(dimension_fields={("visit", "region"): region2}))) + inverted = predicate.logical_not() + self.assertEqual(inverted.column_type, "bool") + self.assertEqual(str(inverted), "NOT visit.region OVERLAPS (region)") + self.assertTrue(inverted.visit(_TestVisitor(dimension_fields={("visit", "region"): region2}))) + columns = qt.ColumnSet(self.universe.empty.as_group()) + inverted.gather_required_columns(columns) + self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) + self.assertEqual(columns.dimension_fields["visit"], {"region"}) + + def test_invalid_comparison(self) -> None: + # Mixed type comparisons. + with self.assertRaises(qt.InvalidQueryError): + self.x.visit > "three" + with self.assertRaises(qt.InvalidQueryError): + self.x.visit > 3.0 + # Invalid operator for type. 
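+        # (dataset_id columns have the 'uuid' column type, for which the test
+        # below expects ordering comparisons such as "<" to be rejected.)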
+ with self.assertRaises(qt.InvalidQueryError): + self.x["raw"].dataset_id < uuid.uuid4() + + def test_is_null(self) -> None: + predicate = self.x.visit.region.is_null + self.assertEqual(predicate.column_type, "bool") + self.assertEqual(str(predicate), "visit.region IS NULL") + columns = qt.ColumnSet(self.universe.empty.as_group()) + predicate.gather_required_columns(columns) + self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) + self.assertEqual(columns.dimension_fields["visit"], {"region"}) + self.assertTrue(predicate.visit(_TestVisitor(dimension_fields={("visit", "region"): None}))) + inverted = predicate.logical_not() + self.assertEqual(inverted.column_type, "bool") + self.assertEqual(str(inverted), "NOT visit.region IS NULL") + self.assertFalse(inverted.visit(_TestVisitor(dimension_fields={("visit", "region"): None}))) + inverted.gather_required_columns(columns) + self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) + self.assertEqual(columns.dimension_fields["visit"], {"region"}) + + def test_in_container(self) -> None: + predicate: qt.Predicate = self.x.visit.in_iterable([3, 4, self.x.exposure.id]) + self.assertEqual(predicate.column_type, "bool") + self.assertEqual(str(predicate), "visit IN [3, 4, exposure]") + columns = qt.ColumnSet(self.universe.empty.as_group()) + predicate.gather_required_columns(columns) + self.assertEqual(columns.dimensions, self.universe.conform(["visit", "exposure"])) + self.assertFalse(columns.dimension_fields["visit"]) + self.assertFalse(columns.dimension_fields["exposure"]) + self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 2}))) + self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 5}))) + inverted = predicate.logical_not() + self.assertEqual(inverted.column_type, "bool") + self.assertEqual(str(inverted), "NOT visit IN [3, 4, exposure]") + self.assertFalse(inverted.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 2}))) + self.assertTrue(inverted.visit(_TestVisitor(dimension_keys={"visit": 2, "exposure": 5}))) + columns = qt.ColumnSet(self.universe.empty.as_group()) + inverted.gather_required_columns(columns) + self.assertEqual(columns.dimensions, self.universe.conform(["visit", "exposure"])) + self.assertFalse(columns.dimension_fields["visit"]) + self.assertFalse(columns.dimension_fields["exposure"]) + with self.assertRaises(qt.InvalidQueryError): + # Regions (and timespans) not allowed in IN expressions, since that + # suggests topological logic we're not actually doing. We can't + # use ExpressionFactory because it prohibits this case with typing. + pixelization = Mq3cPixelization(10) + region = pixelization.quad(12058870) + qt.Predicate.in_container(self.x.unwrap(self.x.visit.region), [qt.make_column_literal(region)]) + with self.assertRaises(qt.InvalidQueryError): + # Mismatched types. 
+ self.x.visit.in_iterable([3.5, 2.1]) + + def test_in_range(self) -> None: + predicate: qt.Predicate = self.x.visit.in_range(2, 8, 2) + self.assertEqual(predicate.column_type, "bool") + self.assertEqual(str(predicate), "visit IN 2:8:2") + columns = qt.ColumnSet(self.universe.empty.as_group()) + predicate.gather_required_columns(columns) + self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) + self.assertFalse(columns.dimension_fields["visit"]) + self.assertTrue(predicate.visit(_TestVisitor(dimension_keys={"visit": 2}))) + self.assertFalse(predicate.visit(_TestVisitor(dimension_keys={"visit": 8}))) + inverted = predicate.logical_not() + self.assertEqual(inverted.column_type, "bool") + self.assertEqual(str(inverted), "NOT visit IN 2:8:2") + self.assertFalse(inverted.visit(_TestVisitor(dimension_keys={"visit": 2}))) + self.assertTrue(inverted.visit(_TestVisitor(dimension_keys={"visit": 8}))) + columns = qt.ColumnSet(self.universe.empty.as_group()) + inverted.gather_required_columns(columns) + self.assertEqual(columns.dimensions, self.universe.conform(["visit"])) + self.assertFalse(columns.dimension_fields["visit"]) + with self.assertRaises(qt.InvalidQueryError): + # Only integer fields allowed. + self.x.visit.exposure_time.in_range(2, 4) + with self.assertRaises(qt.InvalidQueryError): + # Step must be positive. + self.x.visit.in_range(2, 4, -1) + with self.assertRaises(qt.InvalidQueryError): + # Stop must be >= start. + self.x.visit.in_range(2, 0) + + def test_in_query(self) -> None: + query = self.query().join_dimensions(["visit", "tract"]).where(skymap="s", tract=3) + predicate: qt.Predicate = self.x.exposure.in_query(self.x.visit, query) + self.assertEqual(predicate.column_type, "bool") + self.assertEqual(str(predicate), "exposure IN (query).visit") + columns = qt.ColumnSet(self.universe.empty.as_group()) + predicate.gather_required_columns(columns) + self.assertEqual(columns.dimensions, self.universe.conform(["exposure"])) + self.assertFalse(columns.dimension_fields["exposure"]) + self.assertTrue( + predicate.visit(_TestVisitor(dimension_keys={"exposure": 2}, query_tree_items={1, 2, 3})) + ) + self.assertFalse( + predicate.visit(_TestVisitor(dimension_keys={"exposure": 8}, query_tree_items={1, 2, 3})) + ) + inverted = predicate.logical_not() + self.assertEqual(inverted.column_type, "bool") + self.assertEqual(str(inverted), "NOT exposure IN (query).visit") + self.assertFalse( + inverted.visit(_TestVisitor(dimension_keys={"exposure": 2}, query_tree_items={1, 2, 3})) + ) + self.assertTrue( + inverted.visit(_TestVisitor(dimension_keys={"exposure": 8}, query_tree_items={1, 2, 3})) + ) + columns = qt.ColumnSet(self.universe.empty.as_group()) + inverted.gather_required_columns(columns) + self.assertEqual(columns.dimensions, self.universe.conform(["exposure"])) + self.assertFalse(columns.dimension_fields["exposure"]) + with self.assertRaises(qt.InvalidQueryError): + # Regions (and timespans) not allowed in IN expressions, since that + # suggests topological logic we're not actually doing. We can't + # use ExpressionFactory because it prohibits this case with typing. + qt.Predicate.in_query( + self.x.unwrap(self.x.visit.region), self.x.unwrap(self.x.tract.region), query._tree + ) + with self.assertRaises(qt.InvalidQueryError): + # Mismatched types. + self.x.exposure.in_query(self.x.visit.exposure_time, query) + with self.assertRaises(qt.InvalidQueryError): + # Query column requires dimensions that are not in the query. 
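+        # ("patch" is not among the dimensions of `query`, which was built by
+        # joining only "visit" and "tract" above.)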
+ self.x.exposure.in_query(self.x.patch, query) + with self.assertRaises(qt.InvalidQueryError): + # Query column requires dataset type that is not in the query. + self.x["raw"].dataset_id.in_query(self.x["raw"].dataset_id, query) + + def test_complex_predicate(self) -> None: + """Test that predicates are converted to conjunctive normal form and + get parentheses in the right places when stringified. + """ + visitor = _TestVisitor(dimension_keys={"instrument": "i", "detector": 3, "visit": 6, "band": "r"}) + a: qt.Predicate = self.x.visit > 5 # will evaluate to True + b: qt.Predicate = self.x.detector != 3 # will evaluate to False + c: qt.Predicate = self.x.instrument == "i" # will evaluate to True + d: qt.Predicate = self.x.band == "g" # will evaluate to False + predicate: qt.Predicate + for predicate, string, value in [ + (a.logical_or(b), f"{a} OR {b}", True), + (a.logical_or(c), f"{a} OR {c}", True), + (b.logical_or(d), f"{b} OR {d}", False), + (a.logical_and(b), f"{a} AND {b}", False), + (a.logical_and(c), f"{a} AND {c}", True), + (b.logical_and(d), f"{b} AND {d}", False), + (self.x.any(a, b, c, d), f"{a} OR {b} OR {c} OR {d}", True), + (self.x.all(a, b, c, d), f"{a} AND {b} AND {c} AND {d}", False), + (a.logical_or(b).logical_and(c), f"({a} OR {b}) AND {c}", True), + (a.logical_and(b.logical_or(d)), f"{a} AND ({b} OR {d})", False), + (a.logical_and(b).logical_or(c), f"({a} OR {c}) AND ({b} OR {c})", True), + ( + a.logical_and(b).logical_or(c.logical_and(d)), + f"({a} OR {c}) AND ({a} OR {d}) AND ({b} OR {c}) AND ({b} OR {d})", + False, + ), + (a.logical_or(b).logical_not(), f"NOT {a} AND NOT {b}", False), + (a.logical_or(c).logical_not(), f"NOT {a} AND NOT {c}", False), + (b.logical_or(d).logical_not(), f"NOT {b} AND NOT {d}", True), + (a.logical_and(b).logical_not(), f"NOT {a} OR NOT {b}", True), + (a.logical_and(c).logical_not(), f"NOT {a} OR NOT {c}", False), + (b.logical_and(d).logical_not(), f"NOT {b} OR NOT {d}", True), + ( + self.x.not_(a.logical_or(b).logical_and(c)), + f"(NOT {a} OR NOT {c}) AND (NOT {b} OR NOT {c})", + False, + ), + ( + a.logical_and(b.logical_or(d)).logical_not(), + f"(NOT {a} OR NOT {b}) AND (NOT {a} OR NOT {d})", + True, + ), + ]: + with self.subTest(string=string): + self.assertEqual(str(predicate), string) + self.assertEqual(predicate.visit(visitor), value) + + def test_proxy_misc(self) -> None: + """Test miscellaneous things on various ExpressionFactory proxies.""" + self.assertEqual(str(self.x.visit_detector_region), "visit_detector_region") + self.assertEqual(str(self.x.visit.instrument), "instrument") + self.assertEqual(str(self.x["raw"]), "raw") + self.assertEqual(str(self.x["raw.ingest_date"]), "raw.ingest_date") + self.assertEqual( + str(self.x.visit.timespan.overlaps(self.x["raw"].timespan)), + "visit.timespan OVERLAPS raw.timespan", + ) + self.assertGreater( + set(dir(self.x["raw"])), {"dataset_id", "ingest_date", "collection", "run", "timespan"} + ) + self.assertGreater(set(dir(self.x.exposure)), {"seq_num", "science_program", "timespan"}) + with self.assertRaises(AttributeError): + self.x["raw"].seq_num + with self.assertRaises(AttributeError): + self.x.visit.horse + + +class QueryTestCase(unittest.TestCase): + """Tests for Query and *QueryResults objects in lsst.daf.butler.queries.""" + + def setUp(self) -> None: + self.maxDiff = None + self.universe = DimensionUniverse() + # We use ArrowTable as the storage class for all dataset types because + # it's got conversions that only require third-party packages we + # already require. 
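+        # The simulated repository (see the _TestQueryDriver arguments built
+        # below) holds three dataset types: "raw" (detector+exposure),
+        # "refcat" (htm7), and "bias" (a detector-level calibration), in two
+        # RUN collections and one CALIBRATION collection that are tied
+        # together by the CHAINED collection "DummyCam/defaults".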
+        self.raw = DatasetType(
+            "raw", dimensions=self.universe.conform(["detector", "exposure"]), storageClass="ArrowTable"
+        )
+        self.refcat = DatasetType(
+            "refcat", dimensions=self.universe.conform(["htm7"]), storageClass="ArrowTable"
+        )
+        self.bias = DatasetType(
+            "bias",
+            dimensions=self.universe.conform(["detector"]),
+            storageClass="ArrowTable",
+            isCalibration=True,
+        )
+        self.default_collections: list[str] | None = ["DummyCam/defaults"]
+        self.collection_info: dict[str, tuple[CollectionRecord, CollectionSummary]] = {
+            "DummyCam/raw/all": (
+                RunRecord[int](1, name="DummyCam/raw/all"),
+                CollectionSummary(NamedValueSet({self.raw}), governors={"instrument": {"DummyCam"}}),
+            ),
+            "DummyCam/calib": (
+                CollectionRecord[int](2, name="DummyCam/calib", type=CollectionType.CALIBRATION),
+                CollectionSummary(NamedValueSet({self.bias}), governors={"instrument": {"DummyCam"}}),
+            ),
+            "refcats": (
+                RunRecord[int](3, name="refcats"),
+                CollectionSummary(NamedValueSet({self.refcat}), governors={}),
+            ),
+            "DummyCam/defaults": (
+                ChainedCollectionRecord[int](
+                    4, name="DummyCam/defaults", children=("DummyCam/raw/all", "DummyCam/calib", "refcats")
+                ),
+                CollectionSummary(
+                    NamedValueSet({self.raw, self.refcat, self.bias}), governors={"instrument": {"DummyCam"}}
+                ),
+            ),
+        }
+        self.dataset_types = {"raw": self.raw, "refcat": self.refcat, "bias": self.bias}
+
+    def query(self, **kwargs: Any) -> Query:
+        """Make an initial Query object with the given kwargs used to
+        initialize the _TestQueryDriver.
+
+        The given kwargs override the test-case-attribute defaults.
+        """
+        kwargs.setdefault("default_collections", self.default_collections)
+        kwargs.setdefault("collection_info", self.collection_info)
+        kwargs.setdefault("dataset_types", self.dataset_types)
+        return Query(_TestQueryDriver(**kwargs), qt.make_identity_query_tree(self.universe))
+
+    def test_dataset_join(self) -> None:
+        """Test queries that have had a dataset search explicitly joined in via
+        Query.join_dataset_search.
+
+        Since this kind of query has a moderate amount of complexity, this is
+        where we get a lot of basic coverage that applies to all kinds of
+        queries, including:
+
+        - getting data ID and dataset results (but not iterating over them);
+        - the 'any' and 'explain_no_results' methods;
+        - adding 'where' filters (but not expanding dimensions accordingly);
+        - materializations.
+        """
+
+        def check(
+            query: Query,
+            dimensions: DimensionGroup = self.raw.dimensions.as_group(),
+            has_storage_class: bool = True,
+            dataset_type_registered: bool = True,
+        ) -> None:
+            """Run a battery of tests on one of a set of very similar queries
+            constructed in different ways (see below).
+            """
+
+            def check_query_tree(
+                tree: qt.QueryTree,
+                dimensions: DimensionGroup = dimensions,
+                storage_class_name: str | None = self.raw.storageClass_name if has_storage_class else None,
+            ) -> None:
+                """Check the state of the QueryTree object that backs the Query
+                or a derived QueryResults object.
+
+                Parameters
+                ----------
+                tree : `lsst.daf.butler.queries.tree.QueryTree`
+                    Object to test.
+                dimensions : `DimensionGroup`
+                    Dimensions to expect in the `QueryTree`, not necessarily
+                    including those in the test 'raw' dataset type.
+                storage_class_name : `str` or `None`, optional
+                    The storage class name the query is expected to have for
+                    the test 'raw' dataset type.
+ """ + self.assertEqual(tree.dimensions, dimensions | self.raw.dimensions.as_group()) + self.assertEqual(str(tree.predicate), "raw.run == 'DummyCam/raw/all'") + self.assertFalse(tree.materializations) + self.assertFalse(tree.data_coordinate_uploads) + self.assertEqual(tree.datasets.keys(), {"raw"}) + self.assertEqual(tree.datasets["raw"].dimensions, self.raw.dimensions.as_group()) + self.assertEqual(tree.datasets["raw"].collections, ("DummyCam/defaults",)) + self.assertEqual(tree.datasets["raw"].storage_class_name, storage_class_name) + self.assertEqual( + tree.get_joined_dimension_groups(), frozenset({self.raw.dimensions.as_group()}) + ) + + def check_data_id_results(*args, query: Query, dimensions: DimensionGroup = dimensions) -> None: + """Construct a DataCoordinateQueryResults object from the query + with the given arguments and run a battery of tests on it. + + Parameters + ---------- + *args + Forwarded to `Query.data_ids`. + query : `Query` + Query to start from. + dimensions : `DimensionGroup`, optional + Dimensions the result data IDs should have. + """ + with self.assertRaises(_TestQueryExecution) as cm: + list(query.data_ids(*args)) + self.assertEqual( + cm.exception.result_spec, + qrs.DataCoordinateResultSpec(dimensions=dimensions), + ) + check_query_tree(cm.exception.tree, dimensions=dimensions) + + def check_dataset_results( + *args: Any, + query: Query, + find_first: bool = True, + storage_class_name: str = self.raw.storageClass_name, + ) -> None: + """Construct a DatasetQueryResults object from the query with + the given arguments and run a battery of tests on it. + + Parameters + ---------- + *args + Forwarded to `Query.datasets`. + query : `Query` + Query to start from. + find_first : `bool`, optional + Whether to do find-first resolution on the results. + storage_class_name : `str`, optional + Expected name of the storage class for the results. + """ + with self.assertRaises(_TestQueryExecution) as cm: + list(query.datasets(*args, find_first=find_first)) + self.assertEqual( + cm.exception.result_spec, + qrs.DatasetRefResultSpec( + dataset_type_name="raw", + dimensions=self.raw.dimensions.as_group(), + storage_class_name=storage_class_name, + find_first=find_first, + ), + ) + check_query_tree(cm.exception.tree, storage_class_name=storage_class_name) + + def check_materialization( + kwargs: Mapping[str, Any], + query: Query, + dimensions: DimensionGroup = dimensions, + has_dataset: bool = True, + ) -> None: + """Materialize the query with the given arguments and run a + battery of tests on the result. + + Parameters + ---------- + kwargs + Forwarded as keyword arguments to `Query.materialize`. + query : `Query` + Query to start from. + dimensions : `DimensionGroup`, optional + Dimensions to expect in the materialization and its derived + query. + has_dataset : `bool`, optional + Whether the query backed by the materialization should + still have the test 'raw' dataset joined in. + """ + # Materialize the query and check the query tree sent to the + # driver and the one in the materialized query. + with self.assertRaises(_TestQueryExecution) as cm: + list(query.materialize(**kwargs).data_ids()) + derived_tree = cm.exception.tree + self.assertEqual(derived_tree.dimensions, dimensions) + # Predicate should be materialized away; it no longer appears + # in the derived query. 
+ self.assertEqual(str(derived_tree.predicate), "True") + self.assertFalse(derived_tree.data_coordinate_uploads) + if has_dataset: + # Dataset search is still there, even though its existence + # constraint is included in the materialization, because we + # might need to re-join for some result columns in a + # derived query. + self.assertTrue(derived_tree.datasets.keys(), {"raw"}) + self.assertEqual(derived_tree.datasets["raw"].dimensions, self.raw.dimensions.as_group()) + self.assertEqual(derived_tree.datasets["raw"].collections, ("DummyCam/defaults",)) + else: + self.assertFalse(derived_tree.datasets) + ((key, derived_tree_materialized_dimensions),) = derived_tree.materializations.items() + self.assertEqual(derived_tree_materialized_dimensions, dimensions) + ( + materialized_tree, + materialized_dimensions, + materialized_datasets, + ) = cm.exception.driver.materializations[key] + self.assertEqual(derived_tree_materialized_dimensions, materialized_dimensions) + if has_dataset: + self.assertEqual(materialized_datasets, {"raw"}) + else: + self.assertFalse(materialized_datasets) + check_query_tree(materialized_tree) + + # Actual logic for the check() function begins here. + + self.assertEqual(query.constraint_dataset_types, {"raw"}) + self.assertEqual(query.constraint_dimensions, self.raw.dimensions.as_group()) + + # Adding a constraint on a field for this dataset type should work + # (this constraint will be present in all downstream tests). + query = query.where(query.expression_factory["raw"].run == "DummyCam/raw/all") + with self.assertRaises(qt.InvalidQueryError): + # Adding constraint on a different dataset should not work. + query.where(query.expression_factory["refcat"].run == "refcats") + + # Data IDs, with dimensions defaulted. + check_data_id_results(query=query) + # Dimensions for data IDs the same as defaults. + check_data_id_results(["exposure", "detector"], query=query) + # Dimensions are a subset of the query dimensions. + check_data_id_results(["exposure"], query=query, dimensions=self.universe.conform(["exposure"])) + # Dimensions are a superset of the query dimensions. + check_data_id_results( + ["exposure", "detector", "visit"], + query=query, + dimensions=self.universe.conform(["exposure", "detector", "visit"]), + ) + # Dimensions are neither a superset nor a subset of the query + # dimensions. + check_data_id_results( + ["detector", "visit"], query=query, dimensions=self.universe.conform(["visit", "detector"]) + ) + # Dimensions are empty. + check_data_id_results([], query=query, dimensions=self.universe.conform([])) + + # Get DatasetRef results, with various arguments and defaulting. 
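+            # When the query knows the dataset type's storage class, results
+            # can be requested by name alone; otherwise the name-only form is
+            # expected to raise MissingDatasetTypeError and a full DatasetType
+            # instance has to be passed instead, as exercised below.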
+ if has_storage_class: + check_dataset_results("raw", query=query) + check_dataset_results("raw", query=query, find_first=True) + check_dataset_results("raw", ["DummyCam/defaults"], query=query) + check_dataset_results("raw", ["DummyCam/defaults"], query=query, find_first=True) + else: + with self.assertRaises(MissingDatasetTypeError): + query.datasets("raw") + with self.assertRaises(MissingDatasetTypeError): + query.datasets("raw", find_first=True) + with self.assertRaises(MissingDatasetTypeError): + query.datasets("raw", ["DummyCam/defaults"]) + with self.assertRaises(MissingDatasetTypeError): + query.datasets("raw", ["DummyCam/defaults"], find_first=True) + check_dataset_results(self.raw, query=query) + check_dataset_results(self.raw, query=query, find_first=True) + check_dataset_results(self.raw, ["DummyCam/defaults"], query=query) + check_dataset_results(self.raw, ["DummyCam/defaults"], query=query, find_first=True) + + # Changing collections at this stage is not allowed. + with self.assertRaises(qt.InvalidQueryError): + query.datasets("raw", collections=["DummyCam/calib"]) + + # Changing storage classes is allowed, if they're compatible. + check_dataset_results( + self.raw.overrideStorageClass("ArrowNumpy"), query=query, storage_class_name="ArrowNumpy" + ) + if dataset_type_registered: + with self.assertRaises(DatasetTypeError): + # Can't use overrideStorageClass, because it'll raise + # before the code we want to test can. + query.datasets(DatasetType("raw", self.raw.dimensions, "int")) + + # Check the 'any' and 'explain_no_results' methods on Query itself. + for execute, exact in itertools.permutations([False, True], 2): + with self.assertRaises(_TestQueryAny) as cm: + query.any(execute=execute, exact=exact) + self.assertEqual(cm.exception.execute, execute) + self.assertEqual(cm.exception.exact, exact) + check_query_tree(cm.exception.tree, dimensions) + with self.assertRaises(_TestQueryExplainNoResults): + query.explain_no_results() + check_query_tree(cm.exception.tree, dimensions) + + # Materialize the query with defaults. + check_materialization({}, query=query) + # Materialize the query with args that match defaults. + check_materialization({"dimensions": ["exposure", "detector"], "datasets": {"raw"}}, query=query) + # Materialize the query with a superset of the original dimensions. + check_materialization( + {"dimensions": ["exposure", "detector", "visit"]}, + query=query, + dimensions=self.universe.conform(["exposure", "visit", "detector"]), + ) + # Materialize the query with no datasets. + check_materialization( + {"dimensions": ["exposure", "detector"], "datasets": frozenset()}, + query=query, + has_dataset=False, + ) + # Materialize the query with no datasets and a subset of the + # dimensions. + check_materialization( + {"dimensions": ["exposure"], "datasets": frozenset()}, + query=query, + has_dataset=False, + dimensions=self.universe.conform(["exposure"]), + ) + # Materializing the query with a dataset that is not in the query + # is an error. + with self.assertRaises(qt.InvalidQueryError): + query.materialize(datasets={"refcat"}) + # Materializing the query with dimensions that are not a superset + # of any materialized dataset dimensions is an error. + with self.assertRaises(qt.InvalidQueryError): + query.materialize(dimensions=["exposure"], datasets={"raw"}) + + # Actual logic for test_dataset_joins starts here. + + # Default collections and existing dataset type name. 
+ check(self.query().join_dataset_search("raw")) + # Default collections and existing DatasetType instance. + check(self.query().join_dataset_search(self.raw)) + # Manual collections and existing dataset type. + check( + self.query(default_collections=None).join_dataset_search("raw", collections=["DummyCam/defaults"]) + ) + check( + self.query(default_collections=None).join_dataset_search( + self.raw, collections=["DummyCam/defaults"] + ) + ) + # Dataset type does not exist, but dimensions provided. This will + # prohibit getting results without providing the dataset type + # later. + check( + self.query(dataset_types={}).join_dataset_search( + "raw", dimensions=self.universe.conform(["detector", "exposure"]) + ), + has_storage_class=False, + dataset_type_registered=False, + ) + # Dataset type does not exist, but a full dataset type was + # provided up front. + check(self.query(dataset_types={}).join_dataset_search(self.raw), dataset_type_registered=False) + + with self.assertRaises(MissingDatasetTypeError): + # Dataset type does not exist and no dimensions passed. + self.query(dataset_types={}).join_dataset_search("raw", collections=["DummyCam/raw/all"]) + with self.assertRaises(DatasetTypeError): + # Dataset type does exist and bad dimensions passed. + self.query().join_dataset_search( + "raw", collections=["DummyCam/raw/all"], dimensions=self.universe.conform(["visit"]) + ) + with self.assertRaises(TypeError): + # Dataset type object and dimensions were passed (illegal even if + # they agree) + self.query().join_dataset_search( + self.raw, + dimensions=self.raw.dimensions.as_group(), + ) + with self.assertRaises(TypeError): + # Bad type for dataset type argument. + self.query().join_dataset_search(3) + with self.assertRaises(DatasetTypeError): + # Changing dimensions is an error. + self.query(dataset_types={}).join_dataset_search( + "raw", dimensions=self.universe.conform(["patch"]) + ).datasets(self.raw) + + def test_dimension_record_results(self) -> None: + """Test queries that return dimension records. + + This includes tests for: + + - joining against uploaded data coordinates; + - counting result rows; + - expanding dimensions as needed for 'where' conditions; + - order_by, limit, and offset. + + It does not include the iteration methods of + DimensionRecordQueryResults, since those require a different mock + driver setup (see test_dimension_record_iteration). + """ + # Set up the base query-results object to test. + query = self.query() + x = query.expression_factory + self.assertFalse(query.constraint_dimensions) + query = query.where(x.skymap == "m") + self.assertEqual(query.constraint_dimensions, self.universe.conform(["skymap"])) + upload_rows = [ + DataCoordinate.standardize(instrument="DummyCam", visit=3, universe=self.universe), + DataCoordinate.standardize(instrument="DummyCam", visit=4, universe=self.universe), + ] + raw_rows = frozenset([data_id.required_values for data_id in upload_rows]) + query = query.join_data_coordinates(upload_rows) + self.assertEqual(query.constraint_dimensions, self.universe.conform(["skymap", "visit"])) + results = query.dimension_records("patch") + results = results.where(x.tract == 4) + + # Define a closure to run tests on variants of the base query. 
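+        # It intercepts the _TestQueryExecution raised by the mock driver and
+        # checks the predicate, joined dimensions, data-coordinate upload, and
+        # result spec (element, limit, offset, order_by) that would have been
+        # sent to a real driver; count() is exercised the same way via
+        # _TestQueryCount.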
+ def check( + results: DimensionRecordQueryResults, + order_by: Any = (), + limit: int | None = None, + offset: int = 0, + ) -> list[str]: + results = results.order_by(*order_by).limit(limit, offset=offset) + self.assertEqual(results.element.name, "patch") + with self.assertRaises(_TestQueryExecution) as cm: + list(results) + tree = cm.exception.tree + self.assertEqual(str(tree.predicate), "skymap == 'm' AND tract == 4") + self.assertEqual(tree.dimensions, self.universe.conform(["visit", "patch"])) + self.assertFalse(tree.materializations) + self.assertFalse(tree.datasets) + ((key, upload_dimensions),) = tree.data_coordinate_uploads.items() + self.assertEqual(upload_dimensions, self.universe.conform(["visit"])) + self.assertEqual(cm.exception.driver.data_coordinate_uploads[key], (upload_dimensions, raw_rows)) + result_spec = cm.exception.result_spec + self.assertEqual(result_spec.result_type, "dimension_record") + self.assertEqual(result_spec.element, self.universe["patch"]) + self.assertEqual(result_spec.limit, limit) + self.assertEqual(result_spec.offset, offset) + for exact, discard in itertools.permutations([False, True], r=2): + with self.assertRaises(_TestQueryCount) as cm: + results.count(exact=exact, discard=discard) + self.assertEqual(cm.exception.result_spec, result_spec) + self.assertEqual(cm.exception.exact, exact) + self.assertEqual(cm.exception.discard, discard) + return [str(term) for term in result_spec.order_by] + + # Run the closure's tests on variants of the base query. + self.assertEqual(check(results), []) + self.assertEqual(check(results, limit=2), []) + self.assertEqual(check(results, offset=1), []) + self.assertEqual(check(results, limit=3, offset=3), []) + self.assertEqual(check(results, order_by=[x.patch.cell_x]), ["patch.cell_x"]) + self.assertEqual( + check(results, order_by=[x.patch.cell_x, x.patch.cell_y.desc], offset=2), + ["patch.cell_x", "patch.cell_y DESC"], + ) + with self.assertRaises(qt.InvalidQueryError): + # Cannot upload empty list of data IDs. + query.join_data_coordinates([]) + with self.assertRaises(qt.InvalidQueryError): + # Cannot upload heterogeneous list of data IDs. + query.join_data_coordinates( + [ + DataCoordinate.make_empty(self.universe), + DataCoordinate.standardize(instrument="DummyCam", universe=self.universe), + ] + ) + + def test_dimension_record_iteration(self) -> None: + """Tests for DimensionRecordQueryResult iteration.""" + + def make_record(n: int) -> DimensionRecord: + return self.universe["patch"].RecordClass(skymap="m", tract=4, patch=n) + + result_rows = ( + [make_record(n) for n in range(3)], + [make_record(n) for n in range(3, 6)], + [make_record(10)], + ) + results = self.query(result_rows=result_rows).dimension_records("patch") + self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows))) + self.assertEqual( + list(results.iter_set_pages()), + [DimensionRecordSet(self.universe["patch"], rows) for rows in result_rows], + ) + self.assertEqual( + [table.column("id").to_pylist() for table in results.iter_table_pages()], + [list(range(3)), list(range(3, 6)), [10]], + ) + + def test_data_coordinate_results(self) -> None: + """Test queries that return data coordinates. + + This includes tests for: + + - counting result rows; + - expanding dimensions as needed for 'where' conditions; + - order_by, limit, and offset. + + It does not include the iteration methods of + DataCoordinateQueryResults, since those require a different mock + driver setup (see test_data_coordinate_iteration). 
More tests for + different inputs to DataCoordinateQueryResults construction are in + test_dataset_join. + """ + # Set up the base query-results object to test. + query = self.query() + x = query.expression_factory + self.assertFalse(query.constraint_dimensions) + query = query.where(x.skymap == "m") + results = query.data_ids(["patch", "band"]) + results = results.where(x.tract == 4) + + # Define a closure to run tests on variants of the base query. + def check( + results: DataCoordinateQueryResults, + order_by: Any = (), + limit: int | None = None, + offset: int = 0, + include_dimension_records: bool = False, + ) -> list[str]: + results = results.order_by(*order_by).limit(limit, offset=offset) + self.assertEqual(results.dimensions, self.universe.conform(["patch", "band"])) + with self.assertRaises(_TestQueryExecution) as cm: + list(results) + tree = cm.exception.tree + self.assertEqual(str(tree.predicate), "skymap == 'm' AND tract == 4") + self.assertEqual(tree.dimensions, self.universe.conform(["patch", "band"])) + self.assertFalse(tree.materializations) + self.assertFalse(tree.datasets) + self.assertFalse(tree.data_coordinate_uploads) + result_spec = cm.exception.result_spec + self.assertEqual(result_spec.result_type, "data_coordinate") + self.assertEqual(result_spec.dimensions, self.universe.conform(["patch", "band"])) + self.assertEqual(result_spec.include_dimension_records, include_dimension_records) + self.assertEqual(result_spec.limit, limit) + self.assertEqual(result_spec.offset, offset) + self.assertIsNone(result_spec.find_first_dataset) + for exact, discard in itertools.permutations([False, True], r=2): + with self.assertRaises(_TestQueryCount) as cm: + results.count(exact=exact, discard=discard) + self.assertEqual(cm.exception.result_spec, result_spec) + self.assertEqual(cm.exception.exact, exact) + self.assertEqual(cm.exception.discard, discard) + return [str(term) for term in result_spec.order_by] + + # Run the closure's tests on variants of the base query. + self.assertEqual(check(results), []) + self.assertEqual(check(results.with_dimension_records(), include_dimension_records=True), []) + self.assertEqual( + check(results.with_dimension_records().with_dimension_records(), include_dimension_records=True), + [], + ) + self.assertEqual(check(results, limit=2), []) + self.assertEqual(check(results, offset=1), []) + self.assertEqual(check(results, limit=3, offset=3), []) + self.assertEqual(check(results, order_by=[x.patch.cell_x]), ["patch.cell_x"]) + self.assertEqual( + check(results, order_by=[x.patch.cell_x, x.patch.cell_y.desc], offset=2), + ["patch.cell_x", "patch.cell_y DESC"], + ) + self.assertEqual( + check(results, order_by=["patch.cell_x", "-cell_y"], offset=2), + ["patch.cell_x", "patch.cell_y DESC"], + ) + + def test_data_coordinate_iteration(self) -> None: + """Tests for DataCoordinateQueryResult iteration.""" + + def make_data_id(n: int) -> DimensionRecord: + return DataCoordinate.standardize(skymap="m", tract=4, patch=n, universe=self.universe) + + result_rows = ( + [make_data_id(n) for n in range(3)], + [make_data_id(n) for n in range(3, 6)], + [make_data_id(10)], + ) + results = self.query(result_rows=result_rows).data_ids(["patch"]) + self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows))) + + def test_dataset_results(self) -> None: + """Test queries that return dataset refs. 
+
+ This includes tests for:
+
+ - counting result rows;
+ - expanding dimensions as needed for 'where' conditions;
+ - chained results for multiple dataset types;
+ - different ways of passing a data ID to 'where' methods;
+ - order_by, limit, and offset.
+
+ It does not include the iteration methods of the DatasetQueryResults
+ classes, since those require a different mock driver setup (see
+ test_dataset_iteration). More tests for different inputs to
+ SingleTypeDatasetQueryResults construction are in test_dataset_join.
+ """
+ # Set up a few equivalent base query-results objects to test.
+ query = self.query()
+ x = query.expression_factory
+ self.assertFalse(query.constraint_dimensions)
+ results1 = query.datasets(...).where(x.instrument == "DummyCam", visit=4)
+ results2 = query.datasets(..., collections=["DummyCam/defaults"]).where(
+ {"instrument": "DummyCam", "visit": 4}
+ )
+ results3 = query.datasets(["raw", "bias", "refcat"]).where(
+ DataCoordinate.standardize(instrument="DummyCam", visit=4, universe=self.universe)
+ )
+
+ # Define a closure to handle single-dataset-type results.
+ def check_single_type(
+ results: SingleTypeDatasetQueryResults,
+ order_by: Any = (),
+ limit: int | None = None,
+ offset: int = 0,
+ include_dimension_records: bool = False,
+ ) -> list[str]:
+ results = results.order_by(*order_by).limit(limit, offset=offset)
+ self.assertIs(list(results.by_dataset_type())[0], results)
+ with self.assertRaises(_TestQueryExecution) as cm:
+ list(results)
+ tree = cm.exception.tree
+ self.assertEqual(str(tree.predicate), "instrument == 'DummyCam' AND visit == 4")
+ self.assertEqual(
+ tree.dimensions,
+ self.universe.conform(["visit"]).union(results.dataset_type.dimensions.as_group()),
+ )
+ self.assertFalse(tree.materializations)
+ self.assertEqual(tree.datasets.keys(), {results.dataset_type.name})
+ self.assertEqual(tree.datasets[results.dataset_type.name].collections, ("DummyCam/defaults",))
+ self.assertEqual(
+ tree.datasets[results.dataset_type.name].dimensions,
+ results.dataset_type.dimensions.as_group(),
+ )
+ self.assertEqual(
+ tree.datasets[results.dataset_type.name].storage_class_name,
+ results.dataset_type.storageClass_name,
+ )
+ self.assertFalse(tree.data_coordinate_uploads)
+ result_spec = cm.exception.result_spec
+ self.assertEqual(result_spec.result_type, "dataset_ref")
+ self.assertEqual(result_spec.include_dimension_records, include_dimension_records)
+ self.assertEqual(result_spec.limit, limit)
+ self.assertEqual(result_spec.offset, offset)
+ self.assertEqual(result_spec.find_first_dataset, result_spec.dataset_type_name)
+ for exact, discard in itertools.permutations([False, True], r=2):
+ with self.assertRaises(_TestQueryCount) as cm:
+ results.count(exact=exact, discard=discard)
+ self.assertEqual(cm.exception.result_spec, result_spec)
+ self.assertEqual(cm.exception.exact, exact)
+ self.assertEqual(cm.exception.discard, discard)
+ with self.assertRaises(_TestQueryExecution) as cm:
+ list(results.data_ids)
+ self.assertEqual(
+ cm.exception.result_spec,
+ qrs.DataCoordinateResultSpec(
+ dimensions=results.dataset_type.dimensions.as_group(),
+ include_dimension_records=include_dimension_records,
+ ),
+ )
+ self.assertIs(cm.exception.tree, tree)
+ return [str(term) for term in result_spec.order_by]
+
+ # Define a closure to run tests on variants of the base query, which
+ # is a chain of multiple dataset types.
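+ # check_chained relies on by_dataset_type() yielding its per-type results
+ # in sorted dataset-type-name order; the types_seen assertion below is
+ # what actually verifies that.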
+ def check_chained(
+ results: DatasetQueryResults,
+ order_by: tuple[Any, Any, Any] = ((), (), ()),
+ limit: int | None = None,
+ offset: int = 0,
+ include_dimension_records: bool = False,
+ ) -> list[list[str]]:
+ self.assertEqual(results.has_dimension_records, include_dimension_records)
+ types_seen: list[str] = []
+ order_by_strings: list[list[str]] = []
+ for single_type_results, single_type_order_by in zip(results.by_dataset_type(), order_by):
+ order_by_strings.append(
+ check_single_type(
+ single_type_results,
+ order_by=single_type_order_by,
+ limit=limit,
+ offset=offset,
+ include_dimension_records=include_dimension_records,
+ )
+ )
+ types_seen.append(single_type_results.dataset_type.name)
+ self.assertEqual(types_seen, sorted(["raw", "bias", "refcat"]))
+ return order_by_strings
+
+ # Run the closure's tests on variants of the base query.
+ self.assertEqual(check_chained(results1), [[], [], []])
+ self.assertEqual(check_chained(results2), [[], [], []])
+ self.assertEqual(check_chained(results3), [[], [], []])
+ self.assertEqual(
+ check_chained(results1.with_dimension_records(), include_dimension_records=True), [[], [], []]
+ )
+ self.assertEqual(
+ check_chained(
+ results1.with_dimension_records().with_dimension_records(), include_dimension_records=True
+ ),
+ [[], [], []],
+ )
+ self.assertEqual(check_chained(results1, limit=2), [[], [], []])
+ self.assertEqual(check_chained(results1, offset=1), [[], [], []])
+ self.assertEqual(check_chained(results1, limit=3, offset=3), [[], [], []])
+ self.assertEqual(
+ check_chained(
+ results1,
+ order_by=[
+ ["bias.timespan.begin"],
+ ["ingest_date"],
+ ["htm7"],
+ ],
+ ),
+ [["bias.timespan.begin"], ["raw.ingest_date"], ["htm7"]],
+ )
+
+ def test_dataset_iteration(self) -> None:
+ """Tests for SingleTypeDatasetQueryResults iteration."""
+
+ def make_ref(n: int) -> DatasetRef:
+ return DatasetRef(
+ self.raw,
+ DataCoordinate.standardize(
+ instrument="DummyCam", exposure=4, detector=n, universe=self.universe
+ ),
+ run="DummyCam/raw/all",
+ id=uuid.uuid4(),
+ )
+
+ result_rows = (
+ [make_ref(n) for n in range(3)],
+ [make_ref(n) for n in range(3, 6)],
+ [make_ref(10)],
+ )
+ results = self.query(result_rows=result_rows).datasets("raw")
+ self.assertEqual(list(results), list(itertools.chain.from_iterable(result_rows)))
+
+ def test_identifiers(self) -> None:
+ """Test edge-cases of identifiers in order_by expressions."""
+
+ def extract_order_by(results: DataCoordinateQueryResults) -> list[str]:
+ with self.assertRaises(_TestQueryExecution) as cm:
+ list(results)
+ return [str(term) for term in cm.exception.result_spec.order_by]
+
+ self.assertEqual(
+ extract_order_by(self.query().data_ids(["visit"]).order_by("-timespan.begin")),
+ ["visit.timespan.begin DESC"],
+ )
+ self.assertEqual(
+ extract_order_by(self.query().data_ids(["visit"]).order_by("timespan.end")),
+ ["visit.timespan.end"],
+ )
+ self.assertEqual(
+ extract_order_by(self.query().data_ids(["visit"]).order_by("-visit.timespan.begin")),
+ ["visit.timespan.begin DESC"],
+ )
+ self.assertEqual(
+ extract_order_by(self.query().data_ids(["visit"]).order_by("visit.timespan.end")),
+ ["visit.timespan.end"],
+ )
+ self.assertEqual(
+ extract_order_by(self.query().data_ids(["visit"]).order_by("visit.science_program")),
+ ["visit.science_program"],
+ )
+ self.assertEqual(
+ extract_order_by(self.query().data_ids(["visit"]).order_by("visit.id")),
+ ["visit"],
+ )
+ self.assertEqual(
+ extract_order_by(self.query().data_ids(["visit"]).order_by("visit.physical_filter")),
["physical_filter"], + ) + with self.assertRaises(TypeError): + self.query().data_ids(["visit"]).order_by(3) + with self.assertRaises(qt.InvalidQueryError): + self.query().data_ids(["visit"]).order_by("visit.region") + with self.assertRaisesRegex(qt.InvalidQueryError, "Ambiguous"): + self.query().data_ids(["visit", "exposure"]).order_by("timespan.begin") + with self.assertRaisesRegex(qt.InvalidQueryError, "Unrecognized"): + self.query().data_ids(["visit", "exposure"]).order_by("blarg") + with self.assertRaisesRegex(qt.InvalidQueryError, "Unrecognized"): + self.query().data_ids(["visit", "exposure"]).order_by("visit.horse") + with self.assertRaisesRegex(qt.InvalidQueryError, "Unrecognized"): + self.query().data_ids(["visit", "exposure"]).order_by("visit.science_program.monkey") + with self.assertRaisesRegex(qt.InvalidQueryError, "not valid for datasets"): + self.query().datasets("raw").order_by("raw.seq_num") + + def test_invalid_models(self) -> None: + """Test invalid models and combinations of models that cannot be + constructed via the public Query and *QueryResults interfaces. + """ + x = ExpressionFactory(self.universe) + with self.assertRaises(qt.InvalidQueryError): + # QueryTree dimensions do not cover dataset dimensions. + qt.QueryTree( + dimensions=self.universe.conform(["visit"]), + datasets={ + "raw": qt.DatasetSearch( + collections=("DummyCam/raw/all",), + dimensions=self.raw.dimensions.as_group(), + storage_class_name=None, + ) + }, + ) + with self.assertRaises(qt.InvalidQueryError): + # QueryTree dimensions do no cover predicate dimensions. + qt.QueryTree( + dimensions=self.universe.conform(["visit"]), + predicate=(x.detector > 5), + ) + with self.assertRaises(qt.InvalidQueryError): + # Predicate references a dataset not in the QueryTree. + qt.QueryTree( + dimensions=self.universe.conform(["exposure", "detector"]), + predicate=(x["raw"].collection == "bird"), + ) + with self.assertRaises(qt.InvalidQueryError): + # ResultSpec's dimensions are not a subset of the query tree's. + DimensionRecordQueryResults( + _TestQueryDriver(), + qt.QueryTree(dimensions=self.universe.conform(["tract"])), + qrs.DimensionRecordResultSpec(element=self.universe["detector"]), + ) + with self.assertRaises(qt.InvalidQueryError): + # ResultSpec's datasets are not a subset of the query tree's. + SingleTypeDatasetQueryResults( + _TestQueryDriver(), + qt.QueryTree(dimensions=self.raw.dimensions.as_group()), + qrs.DatasetRefResultSpec( + dataset_type_name="raw", + dimensions=self.raw.dimensions.as_group(), + storage_class_name=self.raw.storageClass_name, + find_first=True, + ), + ) + with self.assertRaises(qt.InvalidQueryError): + # ResultSpec's order_by expression is not related to the dimensions + # we're returning. + x = ExpressionFactory(self.universe) + DimensionRecordQueryResults( + _TestQueryDriver(), + qt.QueryTree(dimensions=self.universe.conform(["detector", "visit"])), + qrs.DimensionRecordResultSpec( + element=self.universe["detector"], order_by=(x.unwrap(x.visit),) + ), + ) + with self.assertRaises(qt.InvalidQueryError): + # ResultSpec's order_by expression is not related to the datasets + # we're returning. + x = ExpressionFactory(self.universe) + DimensionRecordQueryResults( + _TestQueryDriver(), + qt.QueryTree(dimensions=self.universe.conform(["detector", "visit"])), + qrs.DimensionRecordResultSpec( + element=self.universe["detector"], order_by=(x.unwrap(x["raw"].ingest_date),) + ), + ) + + def test_general_result_spec(self) -> None: + """Tests for GeneralResultSpec. 
+ + Unlike the other ResultSpec objects, we don't have a *QueryResults + class for GeneralResultSpec yet, so we can't use the higher-level + interfaces to test it like we can the others. + """ + a = qrs.GeneralResultSpec( + dimensions=self.universe.conform(["detector"]), + dimension_fields={"detector": {"purpose"}}, + dataset_fields={}, + find_first=False, + ) + self.assertEqual(a.find_first_dataset, None) + a_columns = qt.ColumnSet(self.universe.conform(["detector"])) + a_columns.dimension_fields["detector"].add("purpose") + self.assertEqual(a.get_result_columns(), a_columns) + b = qrs.GeneralResultSpec( + dimensions=self.universe.conform(["detector"]), + dimension_fields={}, + dataset_fields={"bias": {"timespan", "dataset_id"}}, + find_first=True, + ) + self.assertEqual(b.find_first_dataset, "bias") + b_columns = qt.ColumnSet(self.universe.conform(["detector"])) + b_columns.dataset_fields["bias"].add("timespan") + b_columns.dataset_fields["bias"].add("dataset_id") + self.assertEqual(b.get_result_columns(), b_columns) + with self.assertRaises(qt.InvalidQueryError): + # More than one dataset type with find_first + qrs.GeneralResultSpec( + dimensions=self.universe.conform(["detector", "exposure"]), + dimension_fields={}, + dataset_fields={"bias": {"dataset_id"}, "raw": {"dataset_id"}}, + find_first=True, + ) + with self.assertRaises(qt.InvalidQueryError): + # Out-of-bounds dimension fields. + qrs.GeneralResultSpec( + dimensions=self.universe.conform(["detector"]), + dimension_fields={"visit": {"name"}}, + dataset_fields={}, + find_first=False, + ) + with self.assertRaises(qt.InvalidQueryError): + # No fields for dimension element. + qrs.GeneralResultSpec( + dimensions=self.universe.conform(["detector"]), + dimension_fields={"detector": set()}, + dataset_fields={}, + find_first=True, + ) + with self.assertRaises(qt.InvalidQueryError): + # No fields for dataset. + qrs.GeneralResultSpec( + dimensions=self.universe.conform(["detector"]), + dimension_fields={}, + dataset_fields={"bias": set()}, + find_first=True, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_query_utilities.py b/tests/test_query_utilities.py new file mode 100644 index 0000000000..ffaaae790d --- /dev/null +++ b/tests/test_query_utilities.py @@ -0,0 +1,470 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. 
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+"""Tests for non-public Butler._query functionality that is not specific to
+any Butler or QueryDriver implementation.
+"""
+
+from __future__ import annotations
+
+import unittest
+from collections.abc import Iterable
+
+import astropy.time
+from lsst.daf.butler import DimensionUniverse, Timespan
+from lsst.daf.butler.dimensions import DimensionElement, DimensionGroup
+from lsst.daf.butler.queries import tree as qt
+from lsst.daf.butler.queries.expression_factory import ExpressionFactory
+from lsst.daf.butler.queries.overlaps import OverlapsVisitor
+from lsst.daf.butler.queries.visitors import PredicateVisitFlags
+from lsst.sphgeom import Mq3cPixelization, Region
+
+
+class ColumnSetTestCase(unittest.TestCase):
+ """Tests for lsst.daf.butler.queries.ColumnSet."""
+
+ def setUp(self) -> None:
+ self.universe = DimensionUniverse()
+
+ def test_basics(self) -> None:
+ columns = qt.ColumnSet(self.universe.conform(["detector"]))
+ self.assertNotEqual(columns, columns.dimensions.names)  # intentionally not comparable to other sets
+ self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
+ self.assertFalse(columns.dataset_fields)
+ columns.dataset_fields["bias"].add("dataset_id")
+ self.assertEqual(dict(columns.dataset_fields), {"bias": {"dataset_id"}})
+ columns.dimension_fields["detector"].add("purpose")
+ self.assertEqual(columns.dimension_fields["detector"], {"purpose"})
+ self.assertTrue(columns)
+ self.assertEqual(
+ list(columns),
+ [(k, None) for k in columns.dimensions.data_coordinate_keys]
+ + [("detector", "purpose"), ("bias", "dataset_id")],
+ )
+ self.assertEqual(str(columns), "{instrument, detector, detector:purpose, bias:dataset_id}")
+ empty = qt.ColumnSet(self.universe.empty.as_group())
+ self.assertFalse(empty)
+ self.assertFalse(columns.issubset(empty))
+ self.assertTrue(columns.issuperset(empty))
+ self.assertTrue(columns.isdisjoint(empty))
+ copy = columns.copy()
+ self.assertEqual(columns, copy)
+ self.assertTrue(columns.issubset(copy))
+ self.assertTrue(columns.issuperset(copy))
+ self.assertFalse(columns.isdisjoint(copy))
+ copy.dataset_fields["bias"].add("timespan")
+ copy.dimension_fields["detector"].add("name")
+ copy.update_dimensions(self.universe.conform(["band"]))
+ self.assertEqual(copy.dataset_fields["bias"], {"dataset_id", "timespan"})
+ self.assertEqual(columns.dataset_fields["bias"], {"dataset_id"})
+ self.assertEqual(copy.dimension_fields["detector"], {"purpose", "name"})
+ self.assertEqual(columns.dimension_fields["detector"], {"purpose"})
+ self.assertTrue(columns.issubset(copy))
+ self.assertFalse(columns.issuperset(copy))
+ self.assertFalse(columns.isdisjoint(copy))
+ columns.update(copy)
+ self.assertEqual(columns, copy)
+ self.assertTrue(columns.is_timespan("visit", "timespan"))
+ self.assertFalse(columns.is_timespan("visit", None))
+ self.assertFalse(columns.is_timespan("detector", "purpose"))
+
+ def test_drop_dimension_keys(self) -> None:
+ columns = qt.ColumnSet(self.universe.conform(["physical_filter"]))
+ columns.drop_implied_dimension_keys()
+ self.assertEqual(list(columns), [("instrument", None), ("physical_filter", None)])
+ undropped = qt.ColumnSet(columns.dimensions)
+ self.assertTrue(columns.issubset(undropped))
+ self.assertFalse(columns.issuperset(undropped))
+ self.assertFalse(columns.isdisjoint(undropped))
+ band_only = qt.ColumnSet(self.universe.conform(["band"]))
+ self.assertFalse(columns.issubset(band_only))
+
self.assertFalse(columns.issuperset(band_only)) + self.assertTrue(columns.isdisjoint(band_only)) + copy = columns.copy() + copy.update(band_only) + self.assertEqual(copy, undropped) + columns.restore_dimension_keys() + self.assertEqual(columns, undropped) + + def test_get_column_spec(self) -> None: + columns = qt.ColumnSet(self.universe.conform(["detector"])) + columns.dimension_fields["detector"].add("purpose") + columns.dataset_fields["bias"].update(["dataset_id", "run", "collection", "timespan", "ingest_date"]) + self.assertEqual(columns.get_column_spec("instrument", None).name, "instrument") + self.assertEqual(columns.get_column_spec("instrument", None).type, "string") + self.assertEqual(columns.get_column_spec("instrument", None).nullable, False) + self.assertEqual(columns.get_column_spec("detector", None).name, "detector") + self.assertEqual(columns.get_column_spec("detector", None).type, "int") + self.assertEqual(columns.get_column_spec("detector", None).nullable, False) + self.assertEqual(columns.get_column_spec("detector", "purpose").name, "detector:purpose") + self.assertEqual(columns.get_column_spec("detector", "purpose").type, "string") + self.assertEqual(columns.get_column_spec("detector", "purpose").nullable, True) + self.assertEqual(columns.get_column_spec("bias", "dataset_id").name, "bias:dataset_id") + self.assertEqual(columns.get_column_spec("bias", "dataset_id").type, "uuid") + self.assertEqual(columns.get_column_spec("bias", "dataset_id").nullable, False) + self.assertEqual(columns.get_column_spec("bias", "run").name, "bias:run") + self.assertEqual(columns.get_column_spec("bias", "run").type, "string") + self.assertEqual(columns.get_column_spec("bias", "run").nullable, False) + self.assertEqual(columns.get_column_spec("bias", "collection").name, "bias:collection") + self.assertEqual(columns.get_column_spec("bias", "collection").type, "string") + self.assertEqual(columns.get_column_spec("bias", "collection").nullable, False) + self.assertEqual(columns.get_column_spec("bias", "timespan").name, "bias:timespan") + self.assertEqual(columns.get_column_spec("bias", "timespan").type, "timespan") + self.assertEqual(columns.get_column_spec("bias", "timespan").nullable, False) + self.assertEqual(columns.get_column_spec("bias", "ingest_date").name, "bias:ingest_date") + self.assertEqual(columns.get_column_spec("bias", "ingest_date").type, "datetime") + self.assertEqual(columns.get_column_spec("bias", "ingest_date").nullable, True) + + +class _RecordingOverlapsVisitor(OverlapsVisitor): + def __init__(self, dimensions: DimensionGroup): + super().__init__(dimensions) + self.spatial_constraints: list[tuple[str, PredicateVisitFlags]] = [] + self.spatial_joins: list[tuple[str, str, PredicateVisitFlags]] = [] + self.temporal_dimension_joins: list[tuple[str, str, PredicateVisitFlags]] = [] + + def visit_spatial_constraint( + self, element: DimensionElement, region: Region, flags: PredicateVisitFlags + ) -> qt.Predicate | None: + self.spatial_constraints.append((element.name, flags)) + return super().visit_spatial_constraint(element, region, flags) + + def visit_spatial_join( + self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags + ) -> qt.Predicate | None: + self.spatial_joins.append((a.name, b.name, flags)) + return super().visit_spatial_join(a, b, flags) + + def visit_temporal_dimension_join( + self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags + ) -> qt.Predicate | None: + self.temporal_dimension_joins.append((a.name, b.name, flags)) + 
return super().visit_temporal_dimension_join(a, b, flags) + + +class OverlapsVisitorTestCase(unittest.TestCase): + """Tests for lsst.daf.butler.queries.overlaps.OverlapsVisitor, which is + responsible for validating and inferring spatial and temporal joins and + constraints. + """ + + def setUp(self) -> None: + self.universe = DimensionUniverse() + + def run_visitor( + self, + dimensions: Iterable[str], + predicate: qt.Predicate, + expected: str | None = None, + join_operands: Iterable[DimensionGroup] = (), + ) -> _RecordingOverlapsVisitor: + visitor = _RecordingOverlapsVisitor(self.universe.conform(dimensions)) + if expected is None: + expected = str(predicate) + new_predicate = visitor.run(predicate, join_operands=join_operands) + self.assertEqual(str(new_predicate), expected) + return visitor + + def test_trivial(self) -> None: + """Test the overlaps visitor when there is nothing spatial or temporal + in the query at all. + """ + x = ExpressionFactory(self.universe) + # Trivial predicate. + visitor = self.run_visitor(["physical_filter"], qt.Predicate.from_bool(True)) + self.assertFalse(visitor.spatial_joins) + self.assertFalse(visitor.spatial_constraints) + self.assertFalse(visitor.temporal_dimension_joins) + # Non-overlap predicate. + visitor = self.run_visitor(["physical_filter"], x.any(x.band == "r", x.band == "i")) + self.assertFalse(visitor.spatial_joins) + self.assertFalse(visitor.spatial_constraints) + self.assertFalse(visitor.temporal_dimension_joins) + + def test_one_spatial_family(self) -> None: + """Test the overlaps visitor when there is one spatial family.""" + x = ExpressionFactory(self.universe) + pixelization = Mq3cPixelization(10) + region = pixelization.quad(12058870) + # Trivial predicate. + visitor = self.run_visitor(["visit"], qt.Predicate.from_bool(True)) + self.assertFalse(visitor.spatial_joins) + self.assertFalse(visitor.spatial_constraints) + self.assertFalse(visitor.temporal_dimension_joins) + # Non-overlap predicate. + visitor = self.run_visitor(["visit"], x.any(x.band == "r", x.visit > 2)) + self.assertFalse(visitor.spatial_joins) + self.assertFalse(visitor.spatial_constraints) + self.assertFalse(visitor.temporal_dimension_joins) + # Spatial constraint predicate, in various positions relative to other + # non-overlap predicates. 
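+ # The flags recorded alongside each constraint describe where the overlap
+ # term appeared in the predicate: HAS_AND_SIBLINGS when it was ANDed with
+ # other terms, HAS_OR_SIBLINGS when it was ORed with other terms, and
+ # INVERTED when it appeared under a NOT.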
+ visitor = self.run_visitor(["visit"], x.visit.region.overlaps(region))
+ self.assertEqual(visitor.spatial_constraints, [("visit", PredicateVisitFlags(0))])
+ visitor = self.run_visitor(["visit"], x.all(x.visit.region.overlaps(region), x.band == "r"))
+ self.assertEqual(
+ visitor.spatial_constraints, [("visit", PredicateVisitFlags.HAS_AND_SIBLINGS)]
+ )
+ visitor = self.run_visitor(["visit"], x.any(x.visit.region.overlaps(region), x.band == "r"))
+ self.assertEqual(
+ visitor.spatial_constraints, [("visit", PredicateVisitFlags.HAS_OR_SIBLINGS)]
+ )
+ visitor = self.run_visitor(
+ ["visit"],
+ x.all(
+ x.any(x.literal(region).overlaps(x.visit.region), x.band == "r"),
+ x.visit.observation_reason == "science",
+ ),
+ )
+ self.assertEqual(
+ visitor.spatial_constraints,
+ [
+ (
+ "visit",
+ PredicateVisitFlags.HAS_OR_SIBLINGS | PredicateVisitFlags.HAS_AND_SIBLINGS,
+ )
+ ],
+ )
+ visitor = self.run_visitor(
+ ["visit"],
+ x.any(
+ x.all(x.visit.region.overlaps(region), x.band == "r"),
+ x.visit.observation_reason == "science",
+ ),
+ )
+ self.assertEqual(
+ visitor.spatial_constraints,
+ [
+ (
+ "visit",
+ PredicateVisitFlags.HAS_OR_SIBLINGS | PredicateVisitFlags.HAS_AND_SIBLINGS,
+ )
+ ],
+ )
+ # A spatial join between dimensions in the same family is an error.
+ with self.assertRaises(qt.InvalidQueryError):
+ self.run_visitor(["patch", "tract"], x.patch.region.overlaps(x.tract.region))
+
+ def test_single_unambiguous_spatial_join(self) -> None:
+ """Test the overlaps visitor when there are two spatial families with
+ one dimension element in each, and hence exactly one join is needed.
+ """
+ x = ExpressionFactory(self.universe)
+ # Trivial predicate; an automatic join is added. Order of elements in
+ # automatic joins is lexicographical in order to be deterministic.
+ visitor = self.run_visitor(
+ ["visit", "tract"], qt.Predicate.from_bool(True), "tract.region OVERLAPS visit.region"
+ )
+ self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags.HAS_AND_SIBLINGS)])
+ self.assertFalse(visitor.spatial_constraints)
+ self.assertFalse(visitor.temporal_dimension_joins)
+ # Non-overlap predicate; an automatic join is added.
+ visitor = self.run_visitor(
+ ["visit", "tract"],
+ x.all(x.band == "r", x.visit > 2),
+ "band == 'r' AND visit > 2 AND tract.region OVERLAPS visit.region",
+ )
+ self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags.HAS_AND_SIBLINGS)])
+ self.assertFalse(visitor.spatial_constraints)
+ self.assertFalse(visitor.temporal_dimension_joins)
+ # The same overlap predicate that would be added automatically has been
+ # added manually.
+ visitor = self.run_visitor(
+ ["visit", "tract"],
+ x.tract.region.overlaps(x.visit.region),
+ "tract.region OVERLAPS visit.region",
+ )
+ self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags(0))])
+ self.assertFalse(visitor.spatial_constraints)
+ self.assertFalse(visitor.temporal_dimension_joins)
+ # Add the join overlap predicate in an OR expression, which is unusual
+ # but enough to block the addition of an automatic join; we assume the
+ # user knows what they're doing.
+ visitor = self.run_visitor( + ["visit", "tract"], + x.any(x.visit > 2, x.tract.region.overlaps(x.visit.region)), + "visit > 2 OR tract.region OVERLAPS visit.region", + ) + self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags.HAS_OR_SIBLINGS)]) + self.assertFalse(visitor.spatial_constraints) + self.assertFalse(visitor.temporal_dimension_joins) + # Add the join overlap predicate in a NOT expression, which is unusual + # but permitted in the same sense as OR expressions. + visitor = self.run_visitor( + ["visit", "tract"], + x.not_(x.tract.region.overlaps(x.visit.region)), + "NOT tract.region OVERLAPS visit.region", + ) + self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags.INVERTED)]) + self.assertFalse(visitor.spatial_constraints) + self.assertFalse(visitor.temporal_dimension_joins) + # Add a "join operand" whose dimensions include both spatial families. + # This blocks an automatic join from being created, because we assume + # that join operand (e.g. a materialization or dataset search) already + # encodes some spatial join. + visitor = self.run_visitor( + ["visit", "tract"], + qt.Predicate.from_bool(True), + "True", + join_operands=[self.universe.conform(["tract", "visit"])], + ) + self.assertFalse(visitor.spatial_joins) + self.assertFalse(visitor.spatial_constraints) + self.assertFalse(visitor.temporal_dimension_joins) + + def test_single_flexible_spatial_join(self) -> None: + """Test the overlaps visitor when there are two spatial families and + one has multiple dimension elements. + """ + x = ExpressionFactory(self.universe) + # Trivial predicate; an automatic join between the fine-grained + # elements is added. Order of elements in automatic joins is + # lexicographical in order to be deterministic. + visitor = self.run_visitor( + ["visit", "detector", "patch"], + qt.Predicate.from_bool(True), + "patch.region OVERLAPS visit_detector_region.region", + ) + self.assertEqual( + visitor.spatial_joins, [("patch", "visit_detector_region", PredicateVisitFlags.HAS_AND_SIBLINGS)] + ) + self.assertFalse(visitor.spatial_constraints) + self.assertFalse(visitor.temporal_dimension_joins) + # The same overlap predicate that would be added automatically has been + # added manually. + visitor = self.run_visitor( + ["visit", "detector", "patch"], + x.patch.region.overlaps(x.visit_detector_region.region), + "patch.region OVERLAPS visit_detector_region.region", + ) + self.assertEqual(visitor.spatial_joins, [("patch", "visit_detector_region", PredicateVisitFlags(0))]) + self.assertFalse(visitor.spatial_constraints) + self.assertFalse(visitor.temporal_dimension_joins) + # A coarse overlap join has been added; respect it and do not add an + # automatic one. + visitor = self.run_visitor( + ["visit", "detector", "patch"], + x.tract.region.overlaps(x.visit.region), + "tract.region OVERLAPS visit.region", + ) + self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags(0))]) + self.assertFalse(visitor.spatial_constraints) + self.assertFalse(visitor.temporal_dimension_joins) + # Add a "join operand" whose dimensions include both spatial families. + # This blocks an automatic join from being created, because we assume + # that join operand (e.g. a materialization or dataset search) already + # encodes some spatial join. 
+ visitor = self.run_visitor(
+ ["visit", "detector", "patch"],
+ qt.Predicate.from_bool(True),
+ "True",
+ join_operands=[self.universe.conform(["tract", "visit_detector_region"])],
+ )
+ self.assertFalse(visitor.spatial_joins)
+ self.assertFalse(visitor.spatial_constraints)
+ self.assertFalse(visitor.temporal_dimension_joins)
+
+ def test_multiple_spatial_joins(self) -> None:
+ """Test the overlaps visitor when there are >2 spatial families."""
+ x = ExpressionFactory(self.universe)
+ # Trivial predicate. This is an error, because we cannot generate
+ # automatic spatial joins when there are more than two families.
+ with self.assertRaises(qt.InvalidQueryError):
+ self.run_visitor(["visit", "patch", "htm7"], qt.Predicate.from_bool(True))
+ # Predicate that joins one pair of families but orphans the other;
+ # also an error.
+ with self.assertRaises(qt.InvalidQueryError):
+ self.run_visitor(["visit", "patch", "htm7"], x.visit.region.overlaps(x.htm7.region))
+ # A sufficient overlap join predicate has been added; each family is
+ # connected to at least one other.
+ visitor = self.run_visitor(
+ ["visit", "patch", "htm7"],
+ x.all(x.tract.region.overlaps(x.visit.region), x.tract.region.overlaps(x.htm7.region)),
+ "tract.region OVERLAPS visit.region AND tract.region OVERLAPS htm7.region",
+ )
+ self.assertEqual(
+ visitor.spatial_joins,
+ [
+ ("tract", "visit", PredicateVisitFlags.HAS_AND_SIBLINGS),
+ ("tract", "htm7", PredicateVisitFlags.HAS_AND_SIBLINGS),
+ ],
+ )
+ self.assertFalse(visitor.spatial_constraints)
+ self.assertFalse(visitor.temporal_dimension_joins)
+ # Add a "join operand" whose dimensions include two spatial families,
+ # with a predicate that joins the third in.
+ visitor = self.run_visitor(
+ ["visit", "patch", "htm7"],
+ x.tract.region.overlaps(x.htm7.region),
+ "tract.region OVERLAPS htm7.region",
+ join_operands=[self.universe.conform(["visit", "tract"])],
+ )
+ self.assertEqual(
+ visitor.spatial_joins,
+ [
+ ("tract", "htm7", PredicateVisitFlags(0)),
+ ],
+ )
+ self.assertFalse(visitor.spatial_constraints)
+ self.assertFalse(visitor.temporal_dimension_joins)
+
+ def test_one_temporal_family(self) -> None:
+ """Test the overlaps visitor when there is one temporal family."""
+ x = ExpressionFactory(self.universe)
+ begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
+ end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai")
+ timespan = Timespan(begin, end)
+ # Trivial predicate.
+ visitor = self.run_visitor(["exposure"], qt.Predicate.from_bool(True))
+ self.assertFalse(visitor.spatial_joins)
+ self.assertFalse(visitor.spatial_constraints)
+ self.assertFalse(visitor.temporal_dimension_joins)
+ # Non-overlap predicate.
+ visitor = self.run_visitor(["exposure"], x.any(x.band == "r", x.exposure > 2))
+ self.assertFalse(visitor.spatial_joins)
+ self.assertFalse(visitor.spatial_constraints)
+ self.assertFalse(visitor.temporal_dimension_joins)
+ # Temporal constraint predicate.
+ visitor = self.run_visitor(["exposure"], x.exposure.timespan.overlaps(timespan))
+ self.assertFalse(visitor.spatial_joins)
+ self.assertFalse(visitor.spatial_constraints)
+ self.assertFalse(visitor.temporal_dimension_joins)
+ # A temporal join between dimensions in the same family is an error.
+ with self.assertRaises(qt.InvalidQueryError):
+ self.run_visitor(["exposure", "visit"], x.exposure.timespan.overlaps(x.visit.timespan))
+ # Overlap join with a calibration dataset's validity ranges.
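+ # A dataset's validity-range timespan is not a dimension timespan, so no
+ # entry is expected in temporal_dimension_joins for this join.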
+ visitor = self.run_visitor(["exposure"], x.exposure.timespan.overlaps(x["bias"].timespan))
+ self.assertFalse(visitor.spatial_joins)
+ self.assertFalse(visitor.spatial_constraints)
+ self.assertFalse(visitor.temporal_dimension_joins)
+
+ # There are no tests for temporal dimension joins, because the default
+ # dimension universe only has one temporal family, and the untested logic
+ # trivially duplicates the spatial-join logic.
+
+
+if __name__ == "__main__":
+ unittest.main()
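For reference, a minimal sketch of how a QueryDriver implementation might invoke OverlapsVisitor directly, using only entry points exercised by the tests above. It assumes OverlapsVisitor is instantiable on its own (the recording subclass in test_query_utilities.py only adds bookkeeping), and the printed predicate is the string asserted in test_single_unambiguous_spatial_join:

from lsst.daf.butler import DimensionUniverse
from lsst.daf.butler.queries import tree as qt
from lsst.daf.butler.queries.overlaps import OverlapsVisitor

universe = DimensionUniverse()
# Two spatial families (tract/patch and visit/detector) are requested with no
# explicit overlap term in the predicate, so run() is expected to return a
# predicate augmented with the automatic join between the requested elements.
dimensions = universe.conform(["visit", "tract"])
predicate = qt.Predicate.from_bool(True)
augmented = OverlapsVisitor(dimensions).run(predicate, join_operands=[])
print(augmented)  # expected: tract.region OVERLAPS visit.region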