diff --git a/python/lsst/daf/butler/_query_results.py b/python/lsst/daf/butler/_query_results.py index 30c18327e9..fe35561a04 100644 --- a/python/lsst/daf/butler/_query_results.py +++ b/python/lsst/daf/butler/_query_results.py @@ -562,6 +562,11 @@ def dataset_type(self) -> DatasetType: """ raise NotImplementedError() + @property + def dimensions(self) -> DimensionGroup: + """The dimensions of the dataset type returned by this query.""" + return self.dataset_type.dimensions.as_group() + @property @abstractmethod def data_ids(self) -> DataCoordinateQueryResults: diff --git a/python/lsst/daf/butler/queries/__init__.py b/python/lsst/daf/butler/queries/__init__.py new file mode 100644 index 0000000000..15743f291f --- /dev/null +++ b/python/lsst/daf/butler/queries/__init__.py @@ -0,0 +1,32 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from ._base import * +from ._data_coordinate_query_results import * +from ._dataset_query_results import * +from ._dimension_record_query_results import * +from ._query import * diff --git a/python/lsst/daf/butler/queries/_base.py b/python/lsst/daf/butler/queries/_base.py new file mode 100644 index 0000000000..a32a1b2967 --- /dev/null +++ b/python/lsst/daf/butler/queries/_base.py @@ -0,0 +1,195 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
+# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("QueryBase", "HomogeneousQueryBase", "CountableQueryBase", "QueryResultsBase") + +from abc import ABC, abstractmethod +from collections.abc import Iterable, Set +from typing import Any, Self + +from ..dimensions import DimensionGroup +from .convert_args import convert_order_by_args +from .driver import QueryDriver +from .expression_factory import ExpressionProxy +from .tree import OrderExpression, QueryTree + + +class QueryBase(ABC): + @abstractmethod + def any(self, *, execute: bool = True, exact: bool = True) -> bool: + """Test whether the query would return any rows. + + Parameters + ---------- + execute : `bool`, optional + If `True`, execute at least a ``LIMIT 1`` query if it cannot be + determined prior to execution that the query would return no rows. + exact : `bool`, optional + If `True`, run the full query and perform post-query filtering if + needed, until at least one result row is found. If `False`, the + returned result does not account for post-query filtering, and + hence may be `True` even when all result rows would be filtered + out. + + Returns + ------- + any : `bool` + `True` if the query would (or might, depending on arguments) yield + result rows. `False` if it definitely would not. + """ + raise NotImplementedError() + + @abstractmethod + def explain_no_results(self, execute: bool = True) -> Iterable[str]: + """Return human-readable messages that may help explain why the query + yields no results. + + Parameters + ---------- + execute : `bool`, optional + If `True` (default) execute simplified versions (e.g. ``LIMIT 1``) + of aspects of the tree to more precisely determine where rows were + filtered out. + + Returns + ------- + messages : `~collections.abc.Iterable` [ `str` ] + String messages that describe reasons the query might not yield any + results. + """ + raise NotImplementedError() + + +class HomogeneousQueryBase(QueryBase): + def __init__(self, driver: QueryDriver, tree: QueryTree): + self._driver = driver + self._tree = tree + + @property + def dimensions(self) -> DimensionGroup: + """All dimensions included in the query's columns.""" + return self._tree.dimensions + + def any(self, *, execute: bool = True, exact: bool = True) -> bool: + # Docstring inherited. + return self._driver.any(self._tree, execute=execute, exact=exact) + + def explain_no_results(self, execute: bool = True) -> Iterable[str]: + # Docstring inherited. + return self._driver.explain_no_results(self._tree, execute=execute) + + +class CountableQueryBase(QueryBase): + @abstractmethod + def count(self, *, exact: bool = True, discard: bool = False) -> int: + """Count the number of rows this query would return. + + Parameters + ---------- + exact : `bool`, optional + If `True`, run the full query and perform post-query filtering if + needed to account for that filtering in the count. If `False`, the + result may be an upper bound. + discard : `bool`, optional + If `True`, compute the exact count even if it would require running + the full query and then throwing away the result rows after + counting them. 
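# Illustrative sketch of the diagnostic pattern that `any` and
# `explain_no_results` support (assumes a configured `Butler` exposing the
# `_query` entry point described in this patch; data ID values are
# placeholders):

with butler._query() as query:
    results = query.where(instrument="LSSTCam", detector=95).data_ids(["visit"])
    if not results.any():
        # Ask the driver why the joins/constraints filtered out every row.
        for message in results.explain_no_results():
            print(message)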
If `False`, this is an error, as the user would + usually be better off executing the query first to fetch its rows + into a new query (or passing ``exact=False``). Ignored if + ``exact=False``. + + Returns + ------- + count : `int` + The number of rows the query would return, or an upper bound if + ``exact=False``. + """ + raise NotImplementedError() + + +class QueryResultsBase(HomogeneousQueryBase, CountableQueryBase): + def order_by(self, *args: str | OrderExpression | ExpressionProxy) -> Self: + """Return a new query that yields ordered results. + + Parameters + ---------- + *args : `str` + Names of the columns/dimensions to use for ordering. Column name + can be prefixed with minus (``-``) to use descending ordering. + + Returns + ------- + result : `QueryResultsBase` + An ordered version of this query results object. + + Notes + ----- + If this method is called multiple times, the new sort terms replace + the old ones. + """ + return self._copy( + self._tree, order_by=convert_order_by_args(self.dimensions, self._get_datasets(), *args) + ) + + def limit(self, limit: int | None = None, offset: int = 0) -> Self: + """Return a new query that slices its result rows positionally. + + Parameters + ---------- + limit : `int` or `None`, optional + Upper limit on the number of returned records. + offset : `int`, optional + The number of records to skip before returning at most ``limit`` + records. + + Returns + ------- + result : `QueryResultsBase` + A sliced version of this query results object. + + Notes + ----- + If this method is called multiple times, the new slice parameters + replace the old ones. Slicing always occurs after sorting, even if + `limit` is called before `order_by`. + """ + return self._copy(self._tree, limit=limit, offset=offset) + + @abstractmethod + def _get_datasets(self) -> Set[str]: + """Return all dataset types included in the query's result rows.""" + raise NotImplementedError() + + @abstractmethod + def _copy(self, tree: QueryTree, **kwargs: Any) -> Self: + """Return a modified copy of ``self``. + + Modifications should be validated, not assumed to be correct. + """ + raise NotImplementedError() diff --git a/python/lsst/daf/butler/queries/_data_coordinate_query_results.py b/python/lsst/daf/butler/queries/_data_coordinate_query_results.py new file mode 100644 index 0000000000..5e39ccc9b2 --- /dev/null +++ b/python/lsst/daf/butler/queries/_data_coordinate_query_results.py @@ -0,0 +1,142 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
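# Sketch of the chaining semantics documented above: `order_by` and `limit`
# each return a new results object, and slicing is always applied after
# sorting regardless of call order (identifiers and values are placeholders):

with butler._query() as query:
    data_ids = query.where(instrument="LSSTCam").data_ids(["visit", "detector"])
    # "-visit" requests descending order on the visit ID; repeated calls to
    # order_by or limit replace the previous sort terms / slice.
    first_page = data_ids.order_by("-visit", "detector").limit(100, offset=0)
    print(first_page.count(exact=False))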
+# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("DataCoordinateQueryResults",) + +from collections.abc import Iterable, Iterator +from typing import TYPE_CHECKING, Any + +from ..dimensions import DataCoordinate, DimensionGroup +from ._base import QueryResultsBase +from .driver import QueryDriver +from .tree import InvalidQueryTreeError, QueryTree + +if TYPE_CHECKING: + from .result_specs import DataCoordinateResultSpec + + +class DataCoordinateQueryResults(QueryResultsBase): + """A method-chaining builder for butler queries that return data IDs. + + Parameters + ---------- + driver : `QueryDriver` + Implementation object that knows how to actually execute queries. + tree : `QueryTree` + Description of the query as a tree of joins and column expressions. + The instance returned directly by the `Butler._query` entry point + should be constructed via `make_unit_query_tree`. + spec : `DataCoordinateResultSpec` + Specification of the query result rows, including output columns, + ordering, and slicing. + + Notes + ----- + This refines the `DataCoordinateQueryResults` ABC defined in + `lsst.daf.butler._query_results`, but the intent is to replace that ABC + with this concrete class, rather than inherit from it. + """ + + def __init__(self, driver: QueryDriver, tree: QueryTree, spec: DataCoordinateResultSpec): + spec.validate_tree(tree) + super().__init__(driver, tree) + self._spec = spec + + def __iter__(self) -> Iterator[DataCoordinate]: + page = self._driver.execute(self._spec, self._tree) + yield from page.rows + while page.next_key is not None: + page = self._driver.fetch_next_page(self._spec, page.next_key) + yield from page.rows + + @property + def has_dimension_records(self) -> bool: + """Whether all data IDs in this iterable contain dimension records.""" + return self._spec.include_dimension_records + + def with_dimension_records(self) -> DataCoordinateQueryResults: + """Return a results object for which `has_dimension_records` is + `True`. + """ + if self.has_dimension_records: + return self + return self._copy(tree=self._tree, include_dimension_records=True) + + def subset( + self, + dimensions: DimensionGroup | Iterable[str] | None = None, + ) -> DataCoordinateQueryResults: + """Return a results object containing a subset of the dimensions of + this one. + + Parameters + ---------- + dimensions : `DimensionGroup` or \ + `~collections.abc.Iterable` [ `str`], optional + Dimensions to include in the new results object. If `None`, + ``self.dimensions`` is used. + + Returns + ------- + results : `DataCoordinateQueryResults` + A results object corresponding to the given criteria. May be + ``self`` if it already qualifies. + + Raises + ------ + InvalidQueryTreeError + Raised when ``dimensions`` is not a subset of the dimensions in + this result. + """ + if dimensions is None: + dimensions = self.dimensions + else: + dimensions = self._driver.universe.conform(dimensions) + if not dimensions <= self.dimensions: + raise InvalidQueryTreeError( + f"New dimensions {dimensions} are not a subset of the current " + f"dimensions {self.dimensions}." 
+ ) + return self._copy(tree=self._tree, dimensions=dimensions) + + def count(self, *, exact: bool = True, discard: bool = False) -> int: + # Docstring inherited. + return self._driver.count( + self._tree, + self._spec.get_result_columns(), + find_first_dataset=None, + exact=exact, + discard=discard, + ) + + def _copy(self, tree: QueryTree, **kwargs: Any) -> DataCoordinateQueryResults: + return DataCoordinateQueryResults(self._driver, tree, spec=self._spec.model_copy(update=kwargs)) + + def _get_datasets(self) -> frozenset[str]: + return frozenset() diff --git a/python/lsst/daf/butler/queries/_dataset_query_results.py b/python/lsst/daf/butler/queries/_dataset_query_results.py new file mode 100644 index 0000000000..ecd33e99e1 --- /dev/null +++ b/python/lsst/daf/butler/queries/_dataset_query_results.py @@ -0,0 +1,232 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ( + "DatasetQueryResults", + "ChainedDatasetQueryResults", + "SingleTypeDatasetQueryResults", +) + +import itertools +from abc import abstractmethod +from collections.abc import Iterable, Iterator +from typing import TYPE_CHECKING, Any + +from .._dataset_ref import DatasetRef +from .._dataset_type import DatasetType +from ._base import CountableQueryBase, QueryResultsBase +from .driver import QueryDriver +from .result_specs import DatasetRefResultSpec +from .tree import QueryTree + +if TYPE_CHECKING: + from ._data_coordinate_query_results import DataCoordinateQueryResults + + +class DatasetQueryResults(CountableQueryBase, Iterable[DatasetRef]): + """An interface for objects that represent the results of queries for + datasets. + """ + + @abstractmethod + def by_dataset_type(self) -> Iterator[SingleTypeDatasetQueryResults]: + """Group results by dataset type. + + Returns + ------- + iter : `~collections.abc.Iterator` [ `SingleTypeDatasetQueryResults` ] + An iterator over `DatasetQueryResults` instances that are each + responsible for a single dataset type. 
+ """ + raise NotImplementedError() + + @property + @abstractmethod + def has_dimension_records(self) -> bool: + """Whether all data IDs in this iterable contain dimension records.""" + raise NotImplementedError() + + @abstractmethod + def with_dimension_records(self) -> DatasetQueryResults: + """Return a results object for which `has_dimension_records` is + `True`. + """ + raise NotImplementedError() + + +class SingleTypeDatasetQueryResults(DatasetQueryResults, QueryResultsBase): + """A method-chaining builder for butler queries that return `DatasetRef` + objects. + + Parameters + ---------- + driver : `QueryDriver` + Implementation object that knows how to actually execute queries. + tree : `QueryTree` + Description of the query as a tree of joins and column expressions. + The instance returned directly by the `Butler._query` entry point + should be constructed via `make_unit_query_tree`. + spec : `DatasetRefResultSpec` + Specification of the query result rows, including output columns, + ordering, and slicing. + + Notes + ----- + This refines the `SingleTypeDatasetQueryResults` ABC defined in + `lsst.daf.butler._query_results`, but the intent is to replace that ABC + with this concrete class, rather than inherit from it. + """ + + def __init__(self, driver: QueryDriver, tree: QueryTree, spec: DatasetRefResultSpec): + spec.validate_tree(tree) + super().__init__(driver, tree) + self._spec = spec + + def __iter__(self) -> Iterator[DatasetRef]: + page = self._driver.execute(self._spec, self._tree) + yield from page.rows + while page.next_key is not None: + page = self._driver.fetch_next_page(self._spec, page.next_key) + yield from page.rows + + @property + def dataset_type(self) -> DatasetType: + # Docstring inherited. + return DatasetType(self._spec.dataset_type_name, self._spec.dimensions, self._spec.storage_class_name) + + @property + def data_ids(self) -> DataCoordinateQueryResults: + # Docstring inherited. + from ._data_coordinate_query_results import DataCoordinateQueryResults, DataCoordinateResultSpec + + return DataCoordinateQueryResults( + self._driver, + tree=self._tree, + spec=DataCoordinateResultSpec.model_construct( + dimensions=self.dataset_type.dimensions.as_group(), + include_dimension_records=self._spec.include_dimension_records, + ), + ) + + @property + def has_dimension_records(self) -> bool: + # Docstring inherited. + return self._spec.include_dimension_records + + def with_dimension_records(self) -> SingleTypeDatasetQueryResults: + # Docstring inherited. + if self.has_dimension_records: + return self + return self._copy(tree=self._tree, include_dimension_records=True) + + def by_dataset_type(self) -> Iterator[SingleTypeDatasetQueryResults]: + # Docstring inherited. + return iter((self,)) + + def count(self, *, exact: bool = True, discard: bool = False) -> int: + # Docstring inherited. + return self._driver.count( + self._tree, + self._spec.get_result_columns(), + find_first_dataset=self._spec.find_first_dataset, + exact=exact, + discard=discard, + ) + + def _copy(self, tree: QueryTree, **kwargs: Any) -> SingleTypeDatasetQueryResults: + return SingleTypeDatasetQueryResults( + self._driver, + self._tree, + self._spec.model_copy(update=kwargs), + ) + + def _get_datasets(self) -> frozenset[str]: + return frozenset({self.dataset_type.name}) + + +class ChainedDatasetQueryResults(DatasetQueryResults): + """Implementation of `DatasetQueryResults` that delegates to a sequence + of `SingleTypeDatasetQueryResults`. 
+ + Parameters + ---------- + by_dataset_type : `tuple` [ `SingleTypeDatasetQueryResults` ] + Tuple of single-dataset-type query result objects to combine. + + Notes + ----- + Ideally this will eventually just be "DatasetQueryResults", because we + won't need an ABC if this is the only implementation. + """ + + def __init__(self, by_dataset_type: tuple[SingleTypeDatasetQueryResults, ...]): + self._by_dataset_type = by_dataset_type + + def __iter__(self) -> Iterator[DatasetRef]: + return itertools.chain.from_iterable(self._by_dataset_type) + + def by_dataset_type(self) -> Iterator[SingleTypeDatasetQueryResults]: + # Docstring inherited. + return iter(self._by_dataset_type) + + @property + def has_dimension_records(self) -> bool: + # Docstring inherited. + return all(single_type_results.has_dimension_records for single_type_results in self._by_dataset_type) + + def with_dimension_records(self) -> ChainedDatasetQueryResults: + # Docstring inherited. + return ChainedDatasetQueryResults( + tuple( + [ + single_type_results.with_dimension_records() + for single_type_results in self._by_dataset_type + ] + ) + ) + + def any(self, *, execute: bool = True, exact: bool = True) -> bool: + # Docstring inherited. + return any( + single_type_results.any(execute=execute, exact=exact) + for single_type_results in self._by_dataset_type + ) + + def explain_no_results(self, execute: bool = True) -> Iterable[str]: + # Docstring inherited. + messages: list[str] = [] + for single_type_results in self._by_dataset_type: + messages.extend(single_type_results.explain_no_results(execute=execute)) + return messages + + def count(self, *, exact: bool = True, discard: bool = False) -> int: + return sum( + single_type_results.count(exact=exact, discard=discard) + for single_type_results in self._by_dataset_type + ) diff --git a/python/lsst/daf/butler/queries/_dimension_record_query_results.py b/python/lsst/daf/butler/queries/_dimension_record_query_results.py new file mode 100644 index 0000000000..6663fcc0af --- /dev/null +++ b/python/lsst/daf/butler/queries/_dimension_record_query_results.py @@ -0,0 +1,109 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
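# Sketch of the dataset-results API above: one dataset type yields a
# `SingleTypeDatasetQueryResults`, several yield a chained object that can
# be regrouped with `by_dataset_type` (dataset type and collection names are
# placeholders):

with butler._query() as query:
    refs = query.datasets(["raw", "calexp"], collections=["LSSTCam/defaults"])
    for single_type in refs.by_dataset_type():
        print(single_type.dataset_type.name, single_type.count(exact=False))
    # Iterating the chained object yields DatasetRef objects of every type.
    for ref in refs.with_dimension_records():
        ...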
+ +from __future__ import annotations + +__all__ = ("DimensionRecordQueryResults",) + +from collections.abc import Iterator +from typing import Any + +from ..dimensions import DimensionElement, DimensionRecord, DimensionRecordSet, DimensionRecordTable +from ._base import QueryResultsBase +from .driver import QueryDriver +from .result_specs import DimensionRecordResultSpec +from .tree import QueryTree + + +class DimensionRecordQueryResults(QueryResultsBase): + """A method-chaining builder for butler queries that return data IDs. + + Parameters + ---------- + driver : `QueryDriver` + Implementation object that knows how to actually execute queries. + tree : `QueryTree` + Description of the query as a tree of joins and column expressions. + The instance returned directly by the `Butler._query` entry point + should be constructed via `make_unit_query_tree`. + spec : `DimensionRecordResultSpec` + Specification of the query result rows, including output columns, + ordering, and slicing. + + Notes + ----- + This refines the `DimensionRecordQueryResults` ABC defined in + `lsst.daf.butler._query_results`, but the intent is to replace that ABC + with this concrete class, rather than inherit from it. + """ + + def __init__(self, driver: QueryDriver, tree: QueryTree, spec: DimensionRecordResultSpec): + spec.validate_tree(tree) + super().__init__(driver, tree) + self._spec = spec + + def __iter__(self) -> Iterator[DimensionRecord]: + page = self._driver.execute(self._spec, self._tree) + yield from page.rows + while page.next_key is not None: + page = self._driver.fetch_next_page(self._spec, page.next_key) + yield from page.rows + + def iter_table_pages(self) -> Iterator[DimensionRecordTable]: + page = self._driver.execute(self._spec, self._tree) + yield page.as_table() + while page.next_key is not None: + page = self._driver.fetch_next_page(self._spec, page.next_key) + yield page.as_table() + + def iter_set_pages(self) -> Iterator[DimensionRecordSet]: + page = self._driver.execute(self._spec, self._tree) + yield page.as_set() + while page.next_key is not None: + page = self._driver.fetch_next_page(self._spec, page.next_key) + yield page.as_set() + + @property + def element(self) -> DimensionElement: + # Docstring inherited. + return self._spec.element + + def count(self, *, exact: bool = True, discard: bool = False) -> int: + # Docstring inherited. + return self._driver.count( + self._tree, + self._spec.get_result_columns(), + find_first_dataset=None, + exact=exact, + discard=discard, + ) + + def _copy(self, tree: QueryTree, **kwargs: Any) -> DimensionRecordQueryResults: + return DimensionRecordQueryResults(self._driver, tree, self._spec.model_copy(update=kwargs)) + + def _get_datasets(self) -> frozenset[str]: + return frozenset() diff --git a/python/lsst/daf/butler/queries/_query.py b/python/lsst/daf/butler/queries/_query.py new file mode 100644 index 0000000000..c8ee728a87 --- /dev/null +++ b/python/lsst/daf/butler/queries/_query.py @@ -0,0 +1,511 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. 
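# Sketch of the dimension-record results API above, including the page-wise
# accessors (element and data ID values are placeholders):

with butler._query() as query:
    records = query.where(instrument="LSSTCam").dimension_records("detector")
    for record in records:                    # one DimensionRecord at a time
        ...
    for table in records.iter_table_pages():  # or one DimensionRecordTable per page
        ...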
If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("Query",) + +from collections.abc import Iterable, Mapping, Set +from types import EllipsisType +from typing import Any, overload + +from lsst.utils.iteration import ensure_iterable + +from .._dataset_type import DatasetType +from ..dimensions import DataCoordinate, DataId, DataIdValue, DimensionGroup +from ..registry import DatasetTypeError, MissingDatasetTypeError +from ._base import HomogeneousQueryBase +from ._data_coordinate_query_results import DataCoordinateQueryResults +from ._dataset_query_results import ( + ChainedDatasetQueryResults, + DatasetQueryResults, + SingleTypeDatasetQueryResults, +) +from ._dimension_record_query_results import DimensionRecordQueryResults +from .convert_args import convert_where_args +from .driver import QueryDriver +from .expression_factory import ExpressionFactory +from .result_specs import DataCoordinateResultSpec, DatasetRefResultSpec, DimensionRecordResultSpec +from .tree import ( + DatasetSearch, + InvalidQueryTreeError, + Predicate, + QueryTree, + make_dimension_query_tree, + make_unit_query_tree, +) + + +class Query(HomogeneousQueryBase): + """A method-chaining builder for butler queries. + + Parameters + ---------- + driver : `QueryDriver` + Implementation object that knows how to actually execute queries. + tree : `QueryTree` + Description of the query as a tree of joins and column expressions. + The instance returned directly by the `Butler._query` entry point + should be constructed via `make_unit_query_tree`. + + Notes + ----- + This largely mimics and considerably expands the `Query` ABC defined in + `lsst.daf.butler._query`, but the intent is to replace that ABC with this + concrete class, rather than inherit from it. + """ + + def __init__(self, driver: QueryDriver, tree: QueryTree): + super().__init__(driver, tree) + + @property + def dataset_types(self) -> Set[str]: + """The names of all dataset types joined into the query. + + These dataset types are usable in 'where' expressions, but may or may + not be available to result rows. + """ + return self._tree.datasets.keys() + + @property + def expression_factory(self) -> ExpressionFactory: + """A factory for column expressions using overloaded operators. 
+ + Notes + ----- + Typically this attribute will be assigned to a single-character local + variable, and then its (dynamic) attributes can be used to obtain + references to columns that can be included in a query:: + + with butler._query() as query: + x = query.expression_factory + query = query.where( + x.instrument == "LSSTCam", + x.visit.day_obs > 20240701, + x.any(x.band == 'u', x.band == 'y'), + ) + + As shown above, the returned object also has an `any` method to create + combine expressions with logical OR (as well as `not_` and `all`, + though the latter is rarely necessary since `where` already combines + its arguments with AND). + + Proxies for fields associated with dataset types (``dataset_id``, + ``ingest_date``, ``run``, ``collection``, as well as ``timespan`` for + `~CollectionType.CALIBRATION` collection searches) can be obtained with + dict-like access instead:: + + with butler._query() as query: + query = query.order_by(x["raw"].ingest_date) + + Expression proxy objects that correspond to scalar columns overload the + standard comparison operators (``==``, ``!=``, ``<``, ``>``, ``<=``, + ``>=``) and provide `~ScalarExpressionProxy.in_range`, + `~ScalarExpressionProxy.in_iterable`, and + `~ScalarExpressionProxy.in_query` methods for membership tests. For + `order_by` contexts, they also have a `~ScalarExpressionProxy.desc` + property to indicate that the sort order for that expression should be + reversed. + + Proxy objects for region and timespan fields have an `overlaps` method, + and timespans also have `~TimespanProxy.begin` and `~TimespanProxy.end` + properties to access scalar expression proxies for the bounds. + + All proxy objects also have a `~ExpressionProxy.is_null` property. + + Literal values can be created by calling `ExpressionFactory.literal`, + but can almost always be created implicitly via overloaded operators + instead. + """ + return ExpressionFactory(self._driver.universe) + + def data_ids( + self, + dimensions: DimensionGroup | Iterable[str] | str, + ) -> DataCoordinateQueryResults: + """Query for data IDs matching user-provided criteria. + + Parameters + ---------- + dimensions : `DimensionGroup`, `str`, or \ + `~collections.abc.Iterable` [`str`] + The dimensions of the data IDs to yield, as either `DimensionGroup` + instances or `str`. Will be automatically expanded to a complete + `DimensionGroup`. + + Returns + ------- + dataIds : `DataCoordinateQueryResults` + Data IDs matching the given query parameters. These are guaranteed + to identify all dimensions (`DataCoordinate.hasFull` returns + `True`), but will not contain `DimensionRecord` objects + (`DataCoordinate.hasRecords` returns `False`). Call + `~DataCoordinateQueryResults.with_dimension_records` on the + returned object to fetch those. + """ + dimensions = self._driver.universe.conform(dimensions) + tree = self._tree + if not dimensions >= self._tree.dimensions: + tree = tree.join(make_dimension_query_tree(dimensions)) + result_spec = DataCoordinateResultSpec(dimensions=dimensions, include_dimension_records=False) + return DataCoordinateQueryResults(self._driver, tree, result_spec) + + @overload + def datasets( + self, + dataset_type: str | DatasetType, + collections: str | Iterable[str] | None = None, + *, + find_first: bool = True, + ) -> SingleTypeDatasetQueryResults: ... + + @overload + def datasets( + self, + dataset_type: Iterable[str | DatasetType] | EllipsisType, + collections: str | Iterable[str] | None = None, + *, + find_first: bool = True, + ) -> DatasetQueryResults: ... 
+ + def datasets( + self, + dataset_type: str | DatasetType | Iterable[str | DatasetType] | EllipsisType, + collections: str | Iterable[str] | None = None, + *, + find_first: bool = True, + ) -> DatasetQueryResults: + """Query for and iterate over dataset references matching user-provided + criteria. + + Parameters + ---------- + dataset_type : `str`, `DatasetType`, \ + `~collections.abc.Iterable` [ `str` or `DatasetType` ], \ + or ``...`` + The dataset type or types to search for. Passing ``...`` searches + for all datasets in the given collections. + collections : `str` or `~collections.abc.Iterable` [ `str` ], optional + The collection or collections to search, in order. If not provided + or `None`, and the dataset has not already been joined into the + query, the default collection search path for this butler is used. + find_first : `bool`, optional + If `True` (default), for each result data ID, only yield one + `DatasetRef` of each `DatasetType`, from the first collection in + which a dataset of that dataset type appears (according to the + order of ``collections`` passed in). If `True`, ``collections`` + must not contain regular expressions and may not be ``...``. + + Returns + ------- + refs : `.queries.DatasetQueryResults` + Dataset references matching the given query criteria. Nested data + IDs are guaranteed to include values for all implied dimensions + (i.e. `DataCoordinate.hasFull` will return `True`), but will not + include dimension records (`DataCoordinate.hasRecords` will be + `False`) unless + `~.queries.DatasetQueryResults.with_dimension_records` is + called on the result object (which returns a new one). + + Raises + ------ + lsst.daf.butler.registry.DatasetTypeExpressionError + Raised when ``dataset_type`` expression is invalid. + TypeError + Raised when the arguments are incompatible, such as when a + collection wildcard is passed when ``find_first`` is `True`, or + when ``collections`` is `None` and default butler collections are + not defined. + + Notes + ----- + When multiple dataset types are queried in a single call, the + results of this operation are equivalent to querying for each dataset + type separately in turn, and no information about the relationships + between datasets of different types is included. + """ + if collections is None: + collections = self._driver.get_default_collections() + collections = tuple(ensure_iterable(collections)) + resolved_dataset_searches = self._driver.convert_dataset_search_args(dataset_type, collections) + single_type_results: list[SingleTypeDatasetQueryResults] = [] + for resolved_dataset_type in resolved_dataset_searches: + tree = self._tree + if resolved_dataset_type.name not in tree.datasets: + tree = tree.join_dataset( + resolved_dataset_type.name, + DatasetSearch.model_construct( + dimensions=resolved_dataset_type.dimensions.as_group(), + collections=collections, + ), + ) + elif collections is not None: + raise InvalidQueryTreeError( + f"Dataset type {resolved_dataset_type.name!r} was already joined into this query " + f"but new collections {collections!r} were still provided." 
+ ) + spec = DatasetRefResultSpec.model_construct( + dataset_type_name=resolved_dataset_type.name, + dimensions=resolved_dataset_type.dimensions.as_group(), + storage_class_name=resolved_dataset_type.storageClass_name, + include_dimension_records=False, + find_first=find_first, + ) + single_type_results.append(SingleTypeDatasetQueryResults(self._driver, tree=tree, spec=spec)) + if len(single_type_results) == 1: + return single_type_results[0] + else: + return ChainedDatasetQueryResults(tuple(single_type_results)) + + def dimension_records(self, element: str) -> DimensionRecordQueryResults: + """Query for dimension information matching user-provided criteria. + + Parameters + ---------- + element : `str` + The name of a dimension element to obtain records for. + + Returns + ------- + records : `.queries.DimensionRecordQueryResults` + Data IDs matching the given query parameters. + """ + tree = self._tree + if element not in tree.dimensions.elements: + tree = tree.join(make_dimension_query_tree(self._driver.universe[element].minimal_group)) + result_spec = DimensionRecordResultSpec(element=self._driver.universe[element]) + return DimensionRecordQueryResults(self._driver, tree, result_spec) + + # TODO: add general, dict-row results method and QueryResults. + + def materialize( + self, + *, + dimensions: Iterable[str] | DimensionGroup | None = None, + datasets: Iterable[str] | None = None, + ) -> Query: + """Execute the query, save its results to a temporary location, and + return a new query that represents fetching or joining against those + saved results. + + Parameters + ---------- + dimensions : `~collections.abc.Iterable` [ `str` ] or \ + `DimensionGroup`, optional + Dimensions to include in the temporary results. Default is to + include all dimensions in the query. + datasets : `~collections.abc.Iterable` [ `str` ], optional + Names of dataset types that should be included in the new query; + default is to include `result_dataset_types`. Only resolved + dataset UUIDs will actually be materialized; datasets whose UUIDs + cannot be resolved will continue to be represented in the query via + a join on their dimensions. + + Returns + ------- + query : `Query` + A new query object whose that represents the materialized rows. + """ + if datasets is None: + datasets = frozenset(self.dataset_types) + else: + datasets = frozenset(datasets) + if not (datasets <= self.dataset_types): + raise InvalidQueryTreeError( + f"Dataset(s) {datasets - self.dataset_types} are present in the query." + ) + if dimensions is None: + dimensions = self._tree.dimensions + else: + dimensions = self._driver.universe.conform(dimensions) + key = self._driver.materialize(self._tree, dimensions, datasets) + tree = make_unit_query_tree(self._driver.universe).join_materialization(key, dimensions=dimensions) + for dataset_type_name in datasets: + tree = tree.join_dataset(dataset_type_name, self._tree.datasets[dataset_type_name]) + return Query(self._driver, tree) + + def join_dataset_search( + self, + dataset_type: str, + collections: Iterable[str] | None = None, + dimensions: DimensionGroup | None = None, + ) -> Query: + """Return a new query with a search for a dataset joined in. + + Parameters + ---------- + dataset_type : `str` + Name of the dataset type. May not refer to a dataset component. + collections : `~collections.abc.Iterable` [ `str` ], optional + Iterable of collections to search. 
Order is preserved, but will + not matter if the dataset search is only used as a constraint on + dimensions or if ``find_first=False`` when requesting results. If + not present or `None`, the default collection search path will be + used. + dimensions : `DimensionGroup`, optional + The dimensions to assume for the dataset type if it is not + registered, or check if it is. When the dataset is not registered + and this is not provided, `MissingDatasetTypeError` is raised, + since we cannot construct a query without knowing the dataset's + dimensions; providing this argument causes the returned query to + instead return no rows. + + Returns + ------- + query : `Query` + A new query object with dataset columns available and rows + restricted to those consistent with the found data IDs. + + Raises + ------ + DatasetTypeError + Raised if the dimensions were provided but they do not match the + registered dataset type. + MissingDatasetTypeError + Raised if the dimensions were not provided and the dataset type was + not registered. + """ + if collections is None: + collections = self._driver.get_default_collections() + collections = tuple(ensure_iterable(collections)) + assert isinstance(dataset_type, str), "DatasetType instances not supported here for simplicity." + try: + resolved_dimensions = self._driver.get_dataset_type(dataset_type).dimensions.as_group() + except MissingDatasetTypeError: + if dimensions is None: + raise + resolved_dimensions = dimensions + else: + if dimensions is not None and dimensions != resolved_dimensions: + raise DatasetTypeError( + f"Given dimensions {dimensions} for dataset type {dataset_type!r} do not match the " + f"registered dimensions {resolved_dimensions}." + ) + return Query( + tree=self._tree.join_dataset( + dataset_type, + DatasetSearch.model_construct(collections=collections, dimensions=resolved_dimensions), + ), + driver=self._driver, + ) + + def join_data_coordinates(self, iterable: Iterable[DataCoordinate]) -> Query: + """Return a new query that joins in an explicit table of data IDs. + + Parameters + ---------- + iterable : `~collections.abc.Iterable` [ `DataCoordinate` ] + Iterable of `DataCoordinate`. All items must have the same + dimensions. Must have at least one item. + + Returns + ------- + query : `Query` + A new query object with the data IDs joined in. + """ + rows: set[tuple[DataIdValue, ...]] = set() + dimensions: DimensionGroup | None = None + for data_coordinate in iterable: + if dimensions is None: + dimensions = data_coordinate.dimensions + elif dimensions != data_coordinate.dimensions: + raise RuntimeError(f"Inconsistent dimensions: {dimensions} != {data_coordinate.dimensions}.") + rows.add(data_coordinate.required_values) + if dimensions is None: + raise RuntimeError("Cannot upload an empty data coordinate set.") + key = self._driver.upload_data_coordinates(dimensions, rows) + return Query( + tree=self._tree.join_data_coordinate_upload(dimensions=dimensions, key=key), driver=self._driver + ) + + def join_dimensions(self, dimensions: Iterable[str] | DimensionGroup) -> Query: + """Return a new query that joins the logical tables for additional + dimensions. + + Parameters + ---------- + dimensions : `~collections.abc.Iterable` [ `str` ] or `DimensionGroup` + Names of dimensions to join in. + + Returns + ------- + query : `Query` + A new query object with the dimensions joined in. 
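# Sketch of constraining a query with an explicit table of data IDs via
# `join_data_coordinates` (exposure IDs and collection names are
# placeholders; `DataCoordinate.standardize` is the existing factory for
# data IDs):

from lsst.daf.butler import DataCoordinate

with butler._query() as query:
    universe = butler.dimensions
    uploads = [
        DataCoordinate.standardize(instrument="LSSTCam", exposure=e, universe=universe)
        for e in (2024061800123, 2024061800124)
    ]
    constrained = query.join_data_coordinates(uploads)
    refs = constrained.datasets("raw", collections=["LSSTCam/raw/all"])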
+ """ + dimensions = self._driver.universe.conform(dimensions) + return Query( + tree=self._tree.join(make_dimension_query_tree(dimensions)), + driver=self._driver, + ) + + def where( + self, + *args: str | Predicate | DataId, + bind: Mapping[str, Any] | None = None, + **kwargs: Any, + ) -> Query: + """Return a query with a boolean-expression filter on its rows. + + Parameters + ---------- + *args + Constraints to apply, combined with logical AND. Arguments may be + `str` expressions to parse, `Predicate` objects (these are + typically constructed via `expression_factory`) or data IDs. + bind : `~collections.abc.Mapping` + Mapping from string identifier appearing in a string expression to + a literal value that should be substituted for it. This is + recommended instead of embedding literals directly into the + expression, especially for strings, timespans, or other types where + quoting or formatting is nontrivial. + **kwargs + Data ID key value pairs that extend and override any present in + ``*args``. + + Returns + ------- + query : `Query` + A new query object with the given row filters as well as any + already present in ``self`` (combined with logical AND). + + Notes + ----- + If an expression references a dimension or dimension element that is + not already present in the query, it will be joined in, but dataset + searches must already be joined into a query in order to reference + their fields in expressions. + + Data ID values are not checked for consistency; they are extracted from + ``args`` and then ``kwargs`` and combined, with later values overriding + earlier ones. + """ + return Query( + tree=self._tree.where( + *convert_where_args(self.dimensions, self.dataset_types, *args, bind=bind, **kwargs) + ), + driver=self._driver, + ) diff --git a/python/lsst/daf/butler/queries/convert_args.py b/python/lsst/daf/butler/queries/convert_args.py new file mode 100644 index 0000000000..b4f18e9045 --- /dev/null +++ b/python/lsst/daf/butler/queries/convert_args.py @@ -0,0 +1,244 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
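# Sketch of the `where` overloads described above, combining an
# expression-factory predicate with data ID keyword arguments; string
# expressions are not wired up yet in this patch (values are placeholders):

with butler._query() as query:
    x = query.expression_factory
    query = query.where(
        x.visit.day_obs > 20240701,        # Predicate from column proxies
        x.band.in_iterable(["g", "r"]),
        instrument="LSSTCam",              # data ID kwarg, ANDed with the rest
    )
    print(query.data_ids(["visit"]).count(exact=False))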
+ +from __future__ import annotations + +__all__ = ( + "convert_where_args", + "convert_order_by_args", +) + +import itertools +from collections.abc import Mapping, Set +from typing import Any, cast + +from ..dimensions import DataCoordinate, DataId, Dimension, DimensionGroup +from .expression_factory import ExpressionProxy +from .tree import ( + DATASET_FIELD_NAMES, + ColumnExpression, + DatasetFieldName, + DatasetFieldReference, + DimensionFieldReference, + DimensionKeyReference, + InvalidQueryTreeError, + OrderExpression, + Predicate, + Reversed, + make_column_literal, +) + + +def convert_where_args( + dimensions: DimensionGroup, + datasets: Set[str], + *args: str | Predicate | DataId, + bind: Mapping[str, Any] | None = None, + **kwargs: Any, +) -> Predicate: + """Convert ``where`` arguments to a sequence of column expressions. + + Parameters + ---------- + dimensions : `DimensionGroup` + Dimensions already present in the query this filter is being applied + to. Returned predicates may reference dimensions outside this set. + datasets : `~collections.abc.Set` [ `str` ] + Dataset types already present in the query this filter is being applied + to. Returned predicates may still reference datasets outside this set; + this may be an error at a higher level, but it is not necessarily + checked here. + *args : `str`, `Predicate`, `DataCoordinate`, or `~collections.abc.Mapping` + Expressions to convert into predicates. + bind : `~collections.abc.Mapping`, optional + Mapping from identifier to literal value used when parsing string + expressions. + **kwargs : `object` + Additional data ID key-value pairs. + + Returns + ------- + predicate : `Predicate` + Standardized predicate object. + + Notes + ----- + Data ID values are not checked for consistency; they are extracted from + args and then kwargs and combined, with later extractions taking + precedence. + """ + result = Predicate.from_bool(True) + data_id_dict: dict[str, Any] = {} + for arg in args: + match arg: + case str(): + raise NotImplementedError("TODO: plug in registry.queries.expressions.parser") + case Predicate(): + result = result.logical_and(arg) + case DataCoordinate(): + data_id_dict.update(arg.mapping) + case _: + data_id_dict.update(arg) + data_id_dict.update(kwargs) + for k, v in data_id_dict.items(): + result = result.logical_and( + Predicate.compare( + DimensionKeyReference.model_construct(dimension=dimensions.universe.dimensions[k]), + "==", + make_column_literal(v), + ) + ) + return result + + +def convert_order_by_args( + dimensions: DimensionGroup, datasets: Set[str], *args: str | OrderExpression | ExpressionProxy +) -> tuple[OrderExpression, ...]: + """Convert ``order_by`` arguments to a sequence of column expressions. + + Parameters + ---------- + dimensions : `DimensionGroup` + Dimensions already present in the query whose rows are being sorted. + Returned expressions may reference dimensions outside this set; this + may be an error at a higher level, but it is not necessarily checked + here. + datasets : `~collections.abc.Set` [ `str` ] + Dataset types already present in the query whose rows are being sorted. + Returned expressions may reference datasets outside this set; this may + be an error at a higher level, but it is not necessarily checked here. + *args : `OrderExpression`, `str`, or `ExpressionObject` + Expression or column names to sort by. + + Returns + ------- + expressions : `tuple` [ `OrderExpression`, ... ] + Standardized expression objects. 
+ """ + result: list[OrderExpression] = [] + for arg in args: + match arg: + case str(): + reverse = False + if arg.startswith("-"): + reverse = True + arg = arg[1:] + arg = interpret_identifier(dimensions, datasets, arg, {}) + if reverse: + arg = Reversed.model_construct(operand=arg) + case ExpressionProxy(): + arg = arg._expression + if not hasattr(arg, "expression_type"): + raise TypeError(f"Unrecognized order-by argument: {arg!r}.") + result.append(arg) + return tuple(result) + + +def interpret_identifier( + dimensions: DimensionGroup, datasets: Set[str], identifier: str, bind: Mapping[str, Any] +) -> ColumnExpression: + """Associate an identifier in a ``where`` or ``order_by`` expression with + a query column or bind literal. + + Parameters + ---------- + dimensions : `DimensionGroup` + Dimensions already present in the query this filter is being applied + to. Returned expressions may reference dimensions outside this set. + datasets : `~collections.abc.Set` [ `str` ] + Dataset types already present in the query this filter is being applied + to. Returned expressions may still reference datasets outside this + set. + identifier : `str` + String identifier to process. + bind : `~collections.abc.Mapping` [ `str`, `object` ] + Dictionary of bind literals to match identifiers against first. + + Returns + ------- + expression : `ColumnExpression` + Column expression corresponding to the identifier. + """ + if identifier in bind: + return make_column_literal(bind[identifier]) + first, _, second = identifier.partition(".") + if not second: + if first in dimensions.universe.dimensions: + return DimensionKeyReference.model_construct(dimension=dimensions.universe.dimensions[first]) + else: + element_matches: set[str] = set() + for element_name in dimensions.elements: + element = dimensions.universe[element_name] + if first in element.schema.names: + element_matches.add(element_name) + if first in DATASET_FIELD_NAMES: + dataset_matches = set(datasets) + else: + dataset_matches = set() + if len(element_matches) + len(dataset_matches) > 1: + match_str = ", ".join( + f"'{x}.{first}'" for x in sorted(itertools.chain(element_matches, dataset_matches)) + ) + raise InvalidQueryTreeError( + f"Ambiguous identifier {first!r} matches multiple fields: {match_str}." + ) + elif element_matches: + element = dimensions.universe[element_matches.pop()] + return DimensionFieldReference.model_construct(element=element, field=first) + elif dataset_matches: + return DatasetFieldReference.model_construct( + dataset_type=dataset_matches.pop(), field=cast(DatasetFieldName, first) + ) + else: + if first in dimensions.universe.elements: + element = dimensions.universe[first] + if second in element.schema.dimensions.names: + if isinstance(element, Dimension) and second == element.primary_key.name: + # Identifier is something like "visit.id" which we want to + # interpret the same way as just "visit". + return DimensionKeyReference.model_construct(dimension=element) + else: + # Identifier is something like "visit.instrument", which we + # want to interpret the same way as just "instrument". 
+                    dimension = dimensions.universe.dimensions[second]
+                    return DimensionKeyReference.model_construct(dimension=dimension)
+            elif second in element.schema.remainder.names:
+                return DimensionFieldReference.model_construct(element=element, field=second)
+            else:
+                raise InvalidQueryTreeError(f"Unrecognized field {second!r} for {first}.")
+        elif second in DATASET_FIELD_NAMES:
+            # We just assume the dataset type is okay; it's the job of
+            # higher-level code to complain otherwise.
+            return DatasetFieldReference.model_construct(
+                dataset_type=first, field=cast(DatasetFieldName, second)
+            )
+        elif first in datasets:
+            raise InvalidQueryTreeError(
+                f"Identifier {identifier!r} references dataset type {first!r} but field "
+                f"{second!r} is not a valid field for datasets."
+            )
+    raise InvalidQueryTreeError(f"Unrecognized identifier {identifier!r}.")
diff --git a/python/lsst/daf/butler/queries/driver.py b/python/lsst/daf/butler/queries/driver.py
new file mode 100644
index 0000000000..3a07e2dc42
--- /dev/null
+++ b/python/lsst/daf/butler/queries/driver.py
@@ -0,0 +1,512 @@
+# This file is part of daf_butler.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+ +from __future__ import annotations + +__all__ = ( + "QueryDriver", + "PageKey", + "ResultPage", + "DataCoordinateResultPage", + "DimensionRecordResultPage", + "DatasetRefResultPage", + "GeneralResultPage", +) + +import dataclasses +import uuid +from abc import abstractmethod +from collections.abc import Iterable, Sequence +from contextlib import AbstractContextManager +from types import EllipsisType +from typing import Annotated, Any, TypeAlias, Union, overload + +import pydantic +from lsst.utils.iteration import ensure_iterable + +from .._dataset_ref import DatasetRef +from .._dataset_type import DatasetType +from ..dimensions import ( + DataCoordinate, + DataIdValue, + DimensionGroup, + DimensionRecord, + DimensionRecordSet, + DimensionRecordTable, + DimensionUniverse, +) +from ..registry import CollectionSummary, DatasetTypeError, DatasetTypeExpressionError +from ..registry.interfaces import CollectionRecord +from .result_specs import ( + DataCoordinateResultSpec, + DatasetRefResultSpec, + DimensionRecordResultSpec, + GeneralResultSpec, + ResultSpec, +) +from .tree import ColumnSet, DataCoordinateUploadKey, MaterializationKey, QueryTree + +PageKey: TypeAlias = uuid.UUID + + +class DataCoordinateResultPage(pydantic.BaseModel): + """A single page of results from a data coordinate query.""" + + spec: DataCoordinateResultSpec + next_key: PageKey | None + + # TODO: On DM-41114 this will become a custom container that normalizes out + # attached DimensionRecords and is Pydantic-friendly. Right now this model + # isn't actually serializable. + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) + rows: list[DataCoordinate] + + +@dataclasses.dataclass +class DimensionRecordResultPage: + """A single page of results from a dimension record query.""" + + spec: DimensionRecordResultSpec + next_key: PageKey | None + rows: Iterable[DimensionRecord] + + def as_table(self) -> DimensionRecordTable: + if isinstance(self.rows, DimensionRecordTable): + return self.rows + else: + return DimensionRecordTable(self.spec.element, self.rows) + + def as_set(self) -> DimensionRecordSet: + if isinstance(self.rows, DimensionRecordSet): + return self.rows + else: + return DimensionRecordSet(self.spec.element, self.rows) + + +class DatasetRefResultPage(pydantic.BaseModel): + """A single page of results from a dataset ref query.""" + + spec: DatasetRefResultSpec + next_key: PageKey | None + + # TODO: On DM-41115 this will become a custom container that normalizes out + # attached DimensionRecords and is Pydantic-friendly. Right now this model + # isn't actually serializable. + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) + rows: list[DatasetRef] + + +class GeneralResultPage(pydantic.BaseModel): + """A single page of results from a general query.""" + + spec: GeneralResultSpec + next_key: PageKey | None + + # Raw tabular data, with columns in the same order as spec.columns. + rows: list[tuple[Any, ...]] + + +ResultPage: TypeAlias = Annotated[ + Union[DataCoordinateResultPage, DimensionRecordResultPage, DatasetRefResultPage, GeneralResultPage], + pydantic.Field(discriminator=lambda x: x.spec.result_type), +] + + +class QueryDriver(AbstractContextManager[None]): + """Base class for the implementation object inside `Query2` objects + that is specialized for DirectButler vs. RemoteButler. + + Notes + ----- + Implementations should be context managers. 
This allows them to manage the + lifetime of server-side state, such as: + + - a SQL transaction, when necessary (DirectButler); + - SQL cursors for queries that were not fully iterated over (DirectButler); + - temporary database tables (DirectButler); + - result-page Parquet files that were never fetched (RemoteButler); + - uploaded Parquet files used to fill temporary database tables + (RemoteButler); + - cached content needed to construct query trees, like collection summaries + (potentially all Butlers). + + When possible, these sorts of things should be cleaned up earlier when they + are no longer needed, and the Butler server will still have to guard + against the context manager's ``__exit__`` signal never reaching it, but a + context manager will take care of these much more often than relying on + garbage collection and ``__del__`` would. + """ + + @property + @abstractmethod + def universe(self) -> DimensionUniverse: + """Object that defines all dimensions.""" + raise NotImplementedError() + + @overload + def execute(self, result_spec: DataCoordinateResultSpec, tree: QueryTree) -> DataCoordinateResultPage: ... + + @overload + def execute( + self, result_spec: DimensionRecordResultSpec, tree: QueryTree + ) -> DimensionRecordResultPage: ... + + @overload + def execute(self, result_spec: DatasetRefResultSpec, tree: QueryTree) -> DatasetRefResultPage: ... + + @overload + def execute(self, result_spec: GeneralResultSpec, tree: QueryTree) -> GeneralResultPage: ... + + @abstractmethod + def execute(self, result_spec: ResultSpec, tree: QueryTree) -> ResultPage: + """Execute a query and return the first result page. + + Parameters + ---------- + result_spec : `ResultSpec` + The kind of results the user wants from the query. This can affect + the actual query (i.e. SQL and Python postprocessing) that is run, + e.g. by changing what is in the SQL SELECT clause and even what + tables are joined in, but it never changes the number or order of + result rows. + tree : `QueryTree` + Query tree to evaluate. + + Returns + ------- + first_page : `ResultPage` + A page whose type corresponds to the type of ``result_spec``, with + at least the initial rows from the query. This should have an + empty ``rows`` attribute if the query returned no results, and a + ``next_key`` attribute that is not `None` if there were more + results than could be returned in a single page. + """ + raise NotImplementedError() + + @overload + def fetch_next_page( + self, result_spec: DataCoordinateResultSpec, key: PageKey + ) -> DataCoordinateResultPage: ... + + @overload + def fetch_next_page( + self, result_spec: DimensionRecordResultSpec, key: PageKey + ) -> DimensionRecordResultPage: ... + + @overload + def fetch_next_page(self, result_spec: DatasetRefResultSpec, key: PageKey) -> DatasetRefResultPage: ... + + @overload + def fetch_next_page(self, result_spec: GeneralResultSpec, key: PageKey) -> GeneralResultPage: ... + + @abstractmethod + def fetch_next_page(self, result_spec: ResultSpec, key: PageKey) -> ResultPage: + """Fetch the next page of results from an already-executed query. + + Parameters + ---------- + result_spec : `ResultSpec` + The kind of results the user wants from the query. This must be + identical to the ``result_spec`` passed to `execute`, but + implementations are not *required* to check this. + key : `PageKey` + Key included in the previous page from this query. This key may + become unusable or even be reused after this call. 
+ + Returns + ------- + next_page : `ResultPage` + The next page of query results. + """ + # We can put off dealing with pagination initially by just making an + # implementation of this method raise. + # + # In RemoteButler I expect this to work by having the call to execute + # continue to write Parquet files (or whatever) to some location until + # its cursor is exhausted, and then delete those files as they are + # fetched (or, failing that, when receiving a signal from + # ``__exit__``). + # + # In DirectButler I expect to have a dict[PageKey, Cursor], fetch a + # blocks of rows from it, and just reuse the page key for the next page + # until the cursor is exactly. + raise NotImplementedError() + + @abstractmethod + def materialize( + self, + tree: QueryTree, + dimensions: DimensionGroup, + datasets: frozenset[str], + ) -> MaterializationKey: + """Execute a query tree, saving results to temporary storage for use + in later queries. + + Parameters + ---------- + tree : `QueryTree` + Query tree to evaluate. + dimensions : `DimensionGroup` + Dimensions whose key columns should be preserved. + datasets : `frozenset` [ `str` ] + Names of dataset types whose ID columns may be materialized. It + is implementation-defined whether they actually are. + + Returns + ------- + key : `MaterializationKey` + Unique identifier for the result rows that allows them to be + referenced in a `QueryTree`. + """ + raise NotImplementedError() + + @abstractmethod + def upload_data_coordinates( + self, dimensions: DimensionGroup, rows: Iterable[tuple[DataIdValue, ...]] + ) -> DataCoordinateUploadKey: + """Upload a table of data coordinates for use in later queries. + + Parameters + ---------- + dimensions : `DimensionGroup` + Dimensions of the data coordinates. + rows : `Iterable` [ `tuple` ] + Tuples of data coordinate values, covering just the "required" + subset of ``dimensions``. + + Returns + ------- + key + Unique identifier for the upload that allows it to be referenced in + a `QueryTree`. + """ + raise NotImplementedError() + + @abstractmethod + def count( + self, + tree: QueryTree, + columns: ColumnSet, + find_first_dataset: str | None, + *, + exact: bool, + discard: bool, + ) -> int: + """Return the number of rows a query would return. + + Parameters + ---------- + tree : `QueryTree` + Query tree to evaluate. + columns : `ColumnSet` + Columns over which rows should have unique values before they are + counted. + find_first_dataset : `str` or `None` + Perform a search for this dataset type to reject all but the first + result in the collection search path for each data ID, before + counting the result rows. + exact : `bool`, optional + If `True`, run the full query and perform post-query filtering if + needed to account for that filtering in the count. If `False`, the + result may be an upper bound. + discard : `bool`, optional + If `True`, compute the exact count even if it would require running + the full query and then throwing away the result rows after + counting them. If `False`, this is an error, as the user would + usually be better off executing the query first to fetch its rows + into a new query (or passing ``exact=False``). Ignored if + ``exact=False``. + """ + raise NotImplementedError() + + @abstractmethod + def any(self, tree: QueryTree, *, execute: bool, exact: bool) -> bool: + """Test whether the query would return any rows. + + Parameters + ---------- + tree : `QueryTree` + Query tree to evaluate. 
+ execute : `bool`, optional + If `True`, execute at least a ``LIMIT 1`` query if it cannot be + determined prior to execution that the query would return no rows. + exact : `bool`, optional + If `True`, run the full query and perform post-query filtering if + needed, until at least one result row is found. If `False`, the + returned result does not account for post-query filtering, and + hence may be `True` even when all result rows would be filtered + out. + + Returns + ------- + any : `bool` + `True` if the query would (or might, depending on arguments) yield + result rows. `False` if it definitely would not. + """ + raise NotImplementedError() + + @abstractmethod + def explain_no_results(self, tree: QueryTree, execute: bool) -> Iterable[str]: + """Return human-readable messages that may help explain why the query + yields no results. + + Parameters + ---------- + tree : `QueryTree` + Query tree to evaluate. + execute : `bool`, optional + If `True` (default) execute simplified versions (e.g. ``LIMIT 1``) + of aspects of the tree to more precisely determine where rows were + filtered out. + + Returns + ------- + messages : `~collections.abc.Iterable` [ `str` ] + String messages that describe reasons the query might not yield any + results. + """ + raise NotImplementedError() + + @abstractmethod + def get_default_collections(self) -> tuple[str, ...]: + """Return the default collection search path. + + Returns + ------- + collections : `tuple` [ `str`, ... ] + The default collection search path as a tuple of `str`. + + Raises + ------ + NoDefaultCollectionError + Raised if there are no default collections. + """ + raise NotImplementedError() + + @abstractmethod + def resolve_collection_path( + self, collections: Sequence[str] + ) -> list[tuple[CollectionRecord, CollectionSummary]]: + """Process a collection search path argument into a `list` of + collection records and summaries. + + Parameters + ---------- + collections : `~collections.abc.Sequence` [ `str` ] + The collection or collections to search. + + Returns + ------- + collection_info : `list` [ `tuple` [ `CollectionRecord`, \ + `CollectionSummary` ] ] + A `list` of pairs of `CollectionRecord` and `CollectionSummary` + that flattens out all `~CollectionType.CHAINED` collections into + their children while maintaining the same order and avoiding + duplicates. + + Raises + ------ + MissingCollectionError + Raised if any collection in ``collections`` does not exist. + + Notes + ----- + Implementations are generally expected to cache the collection records + and summaries they obtain (including the records for + `~CollectionType.CHAINED` collections that are not returned) in order + to optimize multiple calls with collections in common. + """ + raise NotImplementedError() + + @abstractmethod + def get_dataset_type(self, name: str) -> DatasetType: + """Return the dimensions for a dataset type. + + Parameters + ---------- + name : `str` + Name of the dataset type. + + Returns + ------- + dataset_type : `DatasetType` + Dimensions of the dataset type. + + Raises + ------ + MissingDatasetTypeError + Raised if the dataset type is not registered. + """ + raise NotImplementedError() + + def convert_dataset_search_args( + self, + dataset_type: str | DatasetType | Iterable[str | DatasetType] | EllipsisType, + collections: Sequence[str], + ) -> list[DatasetType]: + """Resolve dataset type and collections argument. 
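        For example (an illustrative sketch; the dataset type and collection
        names below are hypothetical)::

            driver.convert_dataset_search_args("raw", ["HSC/defaults"])
            driver.convert_dataset_search_args(..., ["HSC/defaults"])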
+ + Parameters + ---------- + dataset_type : `str`, `DatasetType`, \ + `~collections.abc.Iterable` [ `str` or `DatasetType` ], \ + or ``...`` + The dataset type or types to search for. Passing ``...`` searches + for all datasets in the given collections. + collections : `~collections.abc.Sequence` [ `str` ] + The collection or collections to search. + + Returns + ------- + resolved : `list` [ `DatasetType` ] + Matching dataset types. + """ + if dataset_type is ...: + dataset_type = set() + for _, summary in self.resolve_collection_path(collections): + dataset_type.update(summary.dataset_types.names) + result: list[DatasetType] = [] + for arg in ensure_iterable(dataset_type): + given_dataset_type: DatasetType | None + if isinstance(arg, str): + dataset_type_name = arg + given_dataset_type = None + elif isinstance(arg, DatasetType): + dataset_type_name = arg.name + given_dataset_type = arg + else: + raise DatasetTypeExpressionError(f"Unsupported object {arg} in dataset type expression.") + resolved_dataset_type: DatasetType = self.get_dataset_type(dataset_type_name) + if given_dataset_type is not None and not given_dataset_type.is_compatible_with( + resolved_dataset_type + ): + raise DatasetTypeError( + f"Given dataset type {given_dataset_type} is not compatible with the " + f"registered version {resolved_dataset_type}." + ) + result.append(resolved_dataset_type) + return result diff --git a/python/lsst/daf/butler/queries/expression_factory.py b/python/lsst/daf/butler/queries/expression_factory.py new file mode 100644 index 0000000000..32cec1e2c5 --- /dev/null +++ b/python/lsst/daf/butler/queries/expression_factory.py @@ -0,0 +1,428 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("ExpressionFactory", "ExpressionProxy", "ScalarExpressionProxy", "TimespanProxy", "RegionProxy") + +from collections.abc import Iterable +from typing import TYPE_CHECKING + +from lsst.sphgeom import Region + +from ..dimensions import DimensionElement, DimensionUniverse +from . 
import tree + +if TYPE_CHECKING: + from .._timespan import Timespan + from ._query import Query + +# This module uses ExpressionProxy and its subclasses to wrap ColumnExpression, +# but it just returns OrderExpression and Predicate objects directly, because +# we don't need to overload any operators or define any methods on those. + + +class ExpressionProxy: + """A wrapper for column expressions that overloads comparison operators + to return new expression proxies. + + Parameters + ---------- + expression : `tree.ColumnExpression` + Underlying expression object. + """ + + def __init__(self, expression: tree.ColumnExpression): + self._expression = expression + + def __repr__(self) -> str: + return str(self._expression) + + @property + def is_null(self) -> tree.Predicate: + """A boolean expression that tests whether this expression is NULL.""" + return tree.Predicate.is_null(self._expression) + + @staticmethod + def _make_expression(other: object) -> tree.ColumnExpression: + if isinstance(other, ExpressionProxy): + return other._expression + else: + return tree.make_column_literal(other) + + def _make_comparison(self, other: object, operator: tree.ComparisonOperator) -> tree.Predicate: + return tree.Predicate.compare(a=self._expression, b=self._make_expression(other), operator=operator) + + +class ScalarExpressionProxy(ExpressionProxy): + """An `ExpressionProxy` specialized for simple single-value columns.""" + + @property + def desc(self) -> tree.Reversed: + """An ordering expression that indicates that the sort on this + expression should be reversed. + """ + return tree.Reversed.model_construct(operand=self._expression) + + def __eq__(self, other: object) -> tree.Predicate: # type: ignore[override] + return self._make_comparison(other, "==") + + def __ne__(self, other: object) -> tree.Predicate: # type: ignore[override] + return self._make_comparison(other, "!=") + + def __lt__(self, other: object) -> tree.Predicate: # type: ignore[override] + return self._make_comparison(other, "<") + + def __le__(self, other: object) -> tree.Predicate: # type: ignore[override] + return self._make_comparison(other, "<=") + + def __gt__(self, other: object) -> tree.Predicate: # type: ignore[override] + return self._make_comparison(other, ">") + + def __ge__(self, other: object) -> tree.Predicate: # type: ignore[override] + return self._make_comparison(other, ">=") + + def in_range(self, start: int = 0, stop: int | None = None, step: int = 1) -> tree.Predicate: + """Return a boolean expression that tests whether this expression is + within a literal integer range. + + Parameters + ---------- + start : `int`, optional + Lower bound (inclusive) for the slice. + stop : `int` or `None`, optional + Upper bound (exclusive) for the slice, or `None` for no bound. + step : `int`, optional + Spacing between integers in the range. + + Returns + ------- + predicate : `tree.Predicate` + Boolean expression object. + """ + return tree.Predicate.in_range(self._expression, start=start, stop=stop, step=step) + + def in_iterable(self, others: Iterable) -> tree.Predicate: + """Return a boolean expression that tests whether this expression + evaluates to a value that is in an iterable of other expressions. + + Parameters + ---------- + others : `collections.abc.Iterable` + An iterable of `ExpressionProxy` or values to be interpreted as + literals. + + Returns + ------- + predicate : `tree.Predicate` + Boolean expression object. 
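        For example (an illustrative sketch; ``x`` is assumed to be an
        `ExpressionFactory` and ``detector`` a dimension in its universe)::

            x.detector.in_iterable([10, 11, 12])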
+        """
+        return tree.Predicate.in_container(self._expression, [self._make_expression(item) for item in others])
+
+    def in_query(self, column: ExpressionProxy, query: Query) -> tree.Predicate:
+        """Return a boolean expression that tests whether this expression
+        evaluates to a value that is in a single-column selection from another
+        query.
+
+        Parameters
+        ----------
+        column : `ExpressionProxy`
+            Proxy for the column to extract from ``query``.
+        query : `Query`
+            Query to select from.
+
+        Returns
+        -------
+        predicate : `tree.Predicate`
+            Boolean expression object.
+        """
+        return tree.Predicate.in_query_tree(self._expression, column._expression, query._tree)
+
+
+class TimespanProxy(ExpressionProxy):
+    """An `ExpressionProxy` specialized for timespan columns and literals."""
+
+    @property
+    def begin(self) -> ExpressionProxy:
+        """An expression representing the lower bound (inclusive)."""
+        return ExpressionProxy(
+            tree.UnaryExpression.model_construct(operand=self._expression, operator="begin_of")
+        )
+
+    @property
+    def end(self) -> ExpressionProxy:
+        """An expression representing the upper bound (exclusive)."""
+        return ExpressionProxy(
+            tree.UnaryExpression.model_construct(operand=self._expression, operator="end_of")
+        )
+
+    def overlaps(self, other: TimespanProxy | Timespan) -> tree.Predicate:
+        """Return a boolean expression representing an overlap test between
+        this timespan and another.
+
+        Parameters
+        ----------
+        other : `TimespanProxy` or `Timespan`
+            Expression or literal to compare to.
+
+        Returns
+        -------
+        predicate : `tree.Predicate`
+            Boolean expression object.
+        """
+        return self._make_comparison(other, "overlaps")
+
+
+class RegionProxy(ExpressionProxy):
+    """An `ExpressionProxy` specialized for region columns and literals."""
+
+    def overlaps(self, other: RegionProxy | Region) -> tree.Predicate:
+        """Return a boolean expression representing an overlap test between
+        this region and another.
+
+        Parameters
+        ----------
+        other : `RegionProxy` or `Region`
+            Expression or literal to compare to.
+
+        Returns
+        -------
+        predicate : `tree.Predicate`
+            Boolean expression object.
+        """
+        return self._make_comparison(other, "overlaps")
+
+
+class DimensionElementProxy:
+    """An expression-creation proxy for a dimension element logical table.
+
+    Parameters
+    ----------
+    element : `DimensionElement`
+        Element this object wraps.
+
+    Notes
+    -----
+    The (dynamic) attributes of this object are expression proxies for the
+    non-dimension fields of the element's records.
+    """
+
+    def __init__(self, element: DimensionElement):
+        self._element = element
+
+    def __repr__(self) -> str:
+        return self._element.name
+
+    def __getattr__(self, field: str) -> ExpressionProxy:
+        expression = tree.DimensionFieldReference(element=self._element.name, field=field)
+        match field:
+            case "region":
+                return RegionProxy(expression)
+            case "timespan":
+                return TimespanProxy(expression)
+        return ScalarExpressionProxy(expression)
+
+    def __dir__(self) -> list[str]:
+        result = list(super().__dir__())
+        result.extend(self._element.RecordClass.fields.facts.names)
+        if self._element.spatial:
+            result.append("region")
+        if self._element.temporal:
+            result.append("timespan")
+        return result
+
+
+class DimensionProxy(ScalarExpressionProxy, DimensionElementProxy):
+    """An expression-creation proxy for a dimension logical table.
+
+    Parameters
+    ----------
+    dimension : `DimensionElement`
+        Element this object wraps.
+
+    Notes
+    -----
+    This class combines record-field attribute access from `DimensionElement`
+    proxy with direct interpretation as a dimension key column via
+    `ScalarExpressionProxy`. For example::
+
+        x = query.expression_factory
+        query.where(
+            x.detector.purpose == "SCIENCE",  # field access
+            x.detector > 100,  # direct usage as an expression
+        )
+    """
+
+    def __init__(self, dimension: DimensionElement):
+        ScalarExpressionProxy.__init__(self, tree.DimensionKeyReference(dimension=dimension.name))
+        DimensionElementProxy.__init__(self, dimension)
+
+
+class DatasetTypeProxy:
+    """An expression-creation proxy for a dataset type's logical table.
+
+    Parameters
+    ----------
+    dataset_type : `str`
+        Dataset type name or wildcard. Wildcards are usable only when the
+        query contains exactly one dataset type or a wildcard.
+
+    Notes
+    -----
+    The attributes of this object are expression proxies for the fields
+    associated with datasets rather than their dimensions.
+    """
+
+    def __init__(self, dataset_type: str):
+        self._dataset_type = dataset_type
+
+    def __repr__(self) -> str:
+        return self._dataset_type
+
+    # Attributes are actually fixed, but we implement them with __getattr__
+    # and __dir__ to avoid repeating the list. And someday they might expand
+    # to include Datastore record fields.
+
+    def __getattr__(self, field: str) -> ExpressionProxy:
+        if field not in tree.DATASET_FIELD_NAMES:
+            raise AttributeError(field)
+        expression = tree.DatasetFieldReference(dataset_type=self._dataset_type, field=field)
+        if field == "timespan":
+            return TimespanProxy(expression)
+        return ScalarExpressionProxy(expression)
+
+    def __dir__(self) -> list[str]:
+        result = list(super().__dir__())
+        result.extend(tree.DATASET_FIELD_NAMES)
+        return result
+
+
+class ExpressionFactory:
+    """A factory for creating column expressions that uses operator overloading
+    to form a mini-language.
+
+    Instances of this class are usually obtained from
+    `Query.expression_factory`; see that property's documentation for
+    more information.
+
+    Parameters
+    ----------
+    universe : `DimensionUniverse`
+        Object that describes all dimensions.
+    """
+
+    def __init__(self, universe: DimensionUniverse):
+        self._universe = universe
+
+    def __getattr__(self, name: str) -> DimensionElementProxy:
+        element = self._universe.elements[name]
+        if element in self._universe.dimensions:
+            return DimensionProxy(element)
+        return DimensionElementProxy(element)
+
+    def __getitem__(self, name: str) -> DatasetTypeProxy:
+        return DatasetTypeProxy(name)
+
+    def not_(self, operand: tree.Predicate) -> tree.Predicate:
+        """Apply a logical NOT operation to a boolean expression.
+
+        Parameters
+        ----------
+        operand : `tree.Predicate`
+            Expression to invert.
+
+        Returns
+        -------
+        logical_not : `tree.Predicate`
+            A boolean expression that evaluates to the opposite of ``operand``.
+        """
+        return operand.logical_not()
+
+    def all(self, first: tree.Predicate, /, *args: tree.Predicate) -> tree.Predicate:
+        """Combine a sequence of boolean expressions with logical AND.
+
+        Parameters
+        ----------
+        first : `tree.Predicate`
+            First operand (required).
+        *args
+            Additional operands.
+
+        Returns
+        -------
+        logical_and : `tree.Predicate`
+            A boolean expression that evaluates to `True` only if all operands
+            evaluate to `True`.
+        """
+        return first.logical_and(*args)
+
+    def any(self, first: tree.Predicate, /, *args: tree.Predicate) -> tree.Predicate:
+        """Combine a sequence of boolean expressions with logical OR.
+ + Parameters + ---------- + first : `tree.Predicate` + First operand (required). + *args + Additional operands. + + Returns + ------- + logical_or : `tree.Predicate` + A boolean expression that evaluates to `True` if any operand + evaluates to `True. + """ + return first.logical_or(*args) + + @staticmethod + def literal(value: object) -> ExpressionProxy: + """Return an expression proxy that represents a literal value. + + Expression proxy objects obtained from this factory can generally be + compared directly to literals, so calling this method directly in user + code should rarely be necessary. + + Parameters + ---------- + value : `object` + Value to include as a literal in an expression tree. + + Returns + ------- + expression : `ExpressionProxy` + Expression wrapper for this literal. + """ + expression = tree.make_column_literal(value) + match expression.expression_type: + case "timespan": + return TimespanProxy(expression) + case "region": + return RegionProxy(expression) + case "bool": + raise NotImplementedError("Boolean literals are not supported.") + case _: + return ScalarExpressionProxy(expression) diff --git a/python/lsst/daf/butler/queries/overlaps.py b/python/lsst/daf/butler/queries/overlaps.py new file mode 100644 index 0000000000..7f92038cf6 --- /dev/null +++ b/python/lsst/daf/butler/queries/overlaps.py @@ -0,0 +1,466 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("OverlapsVisitor",) + +import itertools +from collections.abc import Hashable, Iterable, Sequence, Set +from typing import Generic, Literal, TypeVar, cast + +from lsst.sphgeom import Region + +from .._topology import TopologicalFamily +from ..dimensions import DimensionElement, DimensionGroup +from . import tree +from .visitors import PredicateVisitFlags, SimplePredicateVisitor + +_T = TypeVar("_T", bound=Hashable) + + +class _NaiveDisjointSet(Generic[_T]): + """A very naive (but simple) implementation of a "disjoint set" data + structure for strings, with mostly O(N) performance. + + This class should not be used in any context where the number of elements + in the data structure is large. 
It intentionally implements a subset of
+    the interface of `scipy.cluster.hierarchy.DisjointSet` so that a non-naive
+    implementation could be swapped in if desired.
+
+    Parameters
+    ----------
+    superset : `~collections.abc.Iterable` [ `str` ]
+        Elements to initialize the disjoint set, with each in its own
+        single-element subset.
+    """
+
+    def __init__(self, superset: Iterable[_T]):
+        self._subsets = [{k} for k in superset]
+        self._subsets.sort(key=len, reverse=True)
+
+    def add(self, k: _T) -> bool:  # numpydoc ignore=PR04
+        """Add a new element as its own single-element subset unless it is
+        already present.
+
+        Parameters
+        ----------
+        k
+            Value to add.
+
+        Returns
+        -------
+        added : `bool`
+            `True` if the value was actually added, `False` if it was already
+            present.
+        """
+        for subset in self._subsets:
+            if k in subset:
+                return False
+        self._subsets.append({k})
+        return True
+
+    def merge(self, a: _T, b: _T) -> bool:  # numpydoc ignore=PR04
+        """Merge the subsets containing the given elements.
+
+        Parameters
+        ----------
+        a :
+            Element whose subset should be merged.
+        b :
+            Element whose subset should be merged.
+
+        Returns
+        -------
+        merged : `bool`
+            `True` if a merge occurred, `False` if the elements were already in
+            the same subset.
+        """
+        for i, subset in enumerate(self._subsets):
+            if a in subset:
+                break
+        else:
+            raise KeyError(f"Merge argument {a!r} not in disjoint set {self._subsets}.")
+        for j, subset in enumerate(self._subsets):
+            if b in subset:
+                break
+        else:
+            raise KeyError(f"Merge argument {b!r} not in disjoint set {self._subsets}.")
+        if i == j:
+            return False
+        i, j = sorted((i, j))
+        self._subsets[i].update(self._subsets[j])
+        del self._subsets[j]
+        self._subsets.sort(key=len, reverse=True)
+        return True
+
+    def subsets(self) -> Sequence[Set[_T]]:
+        """Return the current subsets, ordered from largest to smallest."""
+        return self._subsets
+
+    @property
+    def n_subsets(self) -> int:
+        """The number of subsets."""
+        return len(self._subsets)
+
+
+class OverlapsVisitor(SimplePredicateVisitor):
+    """A helper class for dealing with spatial and temporal overlaps in a
+    query.
+
+    Parameters
+    ----------
+    dimensions : `DimensionGroup`
+        Dimensions of the query.
+
+    Notes
+    -----
+    This class includes logic for extracting explicit spatial and temporal
+    joins from a WHERE-clause predicate and computing automatic joins given the
+    dimensions of the query. It is designed to be subclassed by query driver
+    implementations that want to rewrite the predicate at the same time.
+    """
+
+    def __init__(self, dimensions: DimensionGroup):
+        self.dimensions = dimensions
+        self._spatial_connections = _NaiveDisjointSet(self.dimensions.spatial)
+        self._temporal_connections = _NaiveDisjointSet(self.dimensions.temporal)
+
+    def run(self, predicate: tree.Predicate, join_operands: Iterable[DimensionGroup]) -> tree.Predicate:
+        """Process the given predicate to extract spatial and temporal
+        overlaps.
+
+        Parameters
+        ----------
+        predicate : `tree.Predicate`
+            Predicate to process.
+        join_operands : `~collections.abc.Iterable` [ `DimensionGroup` ]
+            The dimensions of logical tables being joined into this query;
+            these can include embedded spatial and temporal joins that can
+            make it unnecessary to add new ones.
+
+        Returns
+        -------
+        predicate : `tree.Predicate`
+            A possibly-modified predicate that should replace the original.
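        For example (an illustrative sketch; ``predicate`` and ``dimensions``
        are assumed to come from an existing query)::

            visitor = OverlapsVisitor(dimensions)
            predicate = visitor.run(predicate, join_operands=[dimensions])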
+ """ + result = predicate.visit(self) + if result is None: + result = predicate + for join_operand_dimensions in join_operands: + self.add_join_operand_connections(join_operand_dimensions) + for a, b in self.compute_automatic_spatial_joins(): + join_predicate = self.visit_spatial_join(a, b, PredicateVisitFlags.HAS_AND_SIBLINGS) + if join_predicate is None: + join_predicate = tree.Predicate.compare( + tree.DimensionFieldReference.model_construct(element=a, field="region"), + "overlaps", + tree.DimensionFieldReference.model_construct(element=b, field="region"), + ) + result = result.logical_and(join_predicate) + for a, b in self.compute_automatic_temporal_joins(): + join_predicate = self.visit_temporal_dimension_join(a, b, PredicateVisitFlags.HAS_AND_SIBLINGS) + if join_predicate is None: + join_predicate = tree.Predicate.compare( + tree.DimensionFieldReference.model_construct(element=a, field="timespan"), + "overlaps", + tree.DimensionFieldReference.model_construct(element=b, field="timespan"), + ) + result = result.logical_and(join_predicate) + return result + + def visit_comparison( + self, + a: tree.ColumnExpression, + operator: tree.ComparisonOperator, + b: tree.ColumnExpression, + flags: PredicateVisitFlags, + ) -> tree.Predicate | None: + # Docstring inherited. + if operator == "overlaps": + if a.column_type == "region": + return self.visit_spatial_overlap(a, b, flags) + elif b.column_type == "timespan": + return self.visit_temporal_overlap(a, b, flags) + else: + raise AssertionError(f"Unexpected column type {a.column_type} for overlap.") + return None + + def add_join_operand_connections(self, operand_dimensions: DimensionGroup) -> None: + """Add overlap connections implied by a table or subquery. + + Parameters + ---------- + operand_dimensions : `DimensionGroup` + Dimensions of of the table or subquery. + + Notes + ----- + We assume each join operand to a `tree.Select` has its own + complete set of spatial and temporal joins that went into generating + its rows. That will naturally be true for relations originating from + the butler database, like dataset searches and materializations, and if + it isn't true for a data ID upload, that would represent an intentional + association between non-overlapping things that we'd want to respect by + *not* adding a more restrictive automatic join. + """ + for a_family, b_family in itertools.pairwise(operand_dimensions.spatial): + self._spatial_connections.merge(a_family, b_family) + for a_family, b_family in itertools.pairwise(operand_dimensions.temporal): + self._temporal_connections.merge(a_family, b_family) + + def compute_automatic_spatial_joins(self) -> list[tuple[DimensionElement, DimensionElement]]: + """Return pairs of dimension elements that should be spatially joined. + + Returns + ------- + joins : `list` [ `tuple` [ `DimensionElement`, `DimensionElement` ] ] + Automatic joins. + + Notes + ----- + This takes into account explicit joins extracted by `process` and + implicit joins added by `add_join_operand_connections`, and only + returns additional joins if there is an unambiguous way to spatially + connect any dimensions that are not already spatially connected. + Automatic joins are always the most fine-grained join between sets of + dimensions (i.e. ``visit_detector_region`` and ``patch`` instead of + ``visit`` and ``tract``), but explicitly adding a coarser join between + sets of elements will prevent the fine-grained join from being added. 
+        """
+        return self._compute_automatic_joins("spatial", self._spatial_connections)
+
+    def compute_automatic_temporal_joins(self) -> list[tuple[DimensionElement, DimensionElement]]:
+        """Return pairs of dimension elements that should be temporally joined.
+
+        Returns
+        -------
+        joins : `list` [ `tuple` [ `DimensionElement`, `DimensionElement` ] ]
+            Automatic joins.
+
+        Notes
+        -----
+        See `compute_automatic_spatial_joins` for information on how automatic
+        joins are determined. Joins to dataset validity ranges are never
+        automatic.
+        """
+        return self._compute_automatic_joins("temporal", self._temporal_connections)
+
+    def _compute_automatic_joins(
+        self, kind: Literal["spatial", "temporal"], connections: _NaiveDisjointSet[TopologicalFamily]
+    ) -> list[tuple[DimensionElement, DimensionElement]]:
+        if connections.n_subsets == 1:
+            # All of the joins we need are already present.
+            return []
+        if connections.n_subsets > 2:
+            raise tree.InvalidQueryTreeError(
+                f"Too many disconnected sets of {kind} families for an automatic "
+                f"join: {connections.subsets()}. Add explicit {kind} joins to avoid this error."
+            )
+        a_subset, b_subset = connections.subsets()
+        if len(a_subset) > 1 or len(b_subset) > 1:
+            raise tree.InvalidQueryTreeError(
+                f"A {kind} join is needed between {a_subset} and {b_subset}, but which join to "
+                f"add is ambiguous. Add an explicit {kind} join to avoid this error."
+            )
+        # We have a pair of families that are not explicitly or implicitly
+        # connected to any other families; add an automatic join between their
+        # most fine-grained members.
+        (a_family,) = a_subset
+        (b_family,) = b_subset
+        return [
+            (
+                cast(DimensionElement, a_family.choose(self.dimensions.elements, self.dimensions.universe)),
+                cast(DimensionElement, b_family.choose(self.dimensions.elements, self.dimensions.universe)),
+            )
+        ]
+
+    def visit_spatial_overlap(
+        self, a: tree.ColumnExpression, b: tree.ColumnExpression, flags: PredicateVisitFlags
+    ) -> tree.Predicate | None:
+        """Dispatch a spatial overlap comparison predicate to handlers.
+
+        This method should rarely (if ever) need to be overridden.
+
+        Parameters
+        ----------
+        a : `tree.ColumnExpression`
+            First operand.
+        b : `tree.ColumnExpression`
+            Second operand.
+        flags : `tree.PredicateLeafFlags`
+            Information about where this overlap comparison appears in the
+            larger predicate tree.
+
+        Returns
+        -------
+        replaced : `tree.Predicate` or `None`
+            The predicate to be inserted instead in the processed tree, or
+            `None` if no substitution is needed.
+        """
+        match a, b:
+            case tree.DimensionFieldReference(element=a_element), tree.DimensionFieldReference(
+                element=b_element
+            ):
+                return self.visit_spatial_join(a_element, b_element, flags)
+            case tree.DimensionFieldReference(element=element), region_expression:
+                pass
+            case region_expression, tree.DimensionFieldReference(element=element):
+                pass
+            case _:
+                raise AssertionError(f"Unexpected arguments for spatial overlap: {a}, {b}.")
+        if (region := region_expression.get_literal_value()) is None:
+            raise AssertionError(f"Unexpected argument for spatial overlap: {region_expression}.")
+        return self.visit_spatial_constraint(element, region, flags)
+
+    def visit_temporal_overlap(
+        self, a: tree.ColumnExpression, b: tree.ColumnExpression, flags: PredicateVisitFlags
+    ) -> tree.Predicate | None:
+        """Dispatch a temporal overlap comparison predicate to handlers.
+
+        This method should rarely (if ever) need to be overridden.
+
+        Parameters
+        ----------
+        a : `tree.ColumnExpression`
+            First operand.
+ b : `tree.ColumnExpression` + Second operand. + flags : `tree.PredicateLeafFlags` + Information about where this overlap comparison appears in the + larger predicate tree. + + Returns + ------- + replaced : `tree.Predicate` or `None` + The predicate to be inserted instead in the processed tree, or + `None` if no substitution is needed. + """ + match a, b: + case tree.DimensionFieldReference(element=a_element), tree.DimensionFieldReference( + element=b_element + ): + return self.visit_temporal_dimension_join(a_element, b_element, flags) + case _: + # We don't bother differentiating any other kind of temporal + # comparison, because in all foreseeable database schemas we + # wouldn't have to do anything special with them, since they + # don't participate in automatic join calculations and they + # should be straightforwardly convertible to SQL. + return None + + def visit_spatial_join( + self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags + ) -> tree.Predicate | None: + """Handle a spatial overlap comparison between two dimension elements. + + The default implementation updates the set of known spatial connections + (for use by `compute_automatic_spatial_joins`) and returns `None`. + + Parameters + ---------- + a : `DimensionElement` + One element in the join. + b : `DimensionElement` + The other element in the join. + flags : `tree.PredicateLeafFlags` + Information about where this overlap comparison appears in the + larger predicate tree. + + Returns + ------- + replaced : `tree.Predicate` or `None` + The predicate to be inserted instead in the processed tree, or + `None` if no substitution is needed. + """ + if a.spatial == b.spatial: + raise tree.InvalidQueryTreeError(f"Spatial join between {a} and {b} is not necessary.") + self._spatial_connections.merge( + cast(TopologicalFamily, a.spatial), cast(TopologicalFamily, b.spatial) + ) + return None + + def visit_spatial_constraint( + self, + element: DimensionElement, + region: Region, + flags: PredicateVisitFlags, + ) -> tree.Predicate | None: + """Handle a spatial overlap comparison between a dimension element and + a literal region. + + The default implementation just returns `None`. + + Parameters + ---------- + element : `DimensionElement` + The dimension element in the comparison. + region : `lsst.sphgeom.Region` + The literal region in the comparison. + flags : `tree.PredicateLeafFlags` + Information about where this overlap comparison appears in the + larger predicate tree. + + Returns + ------- + replaced : `tree.Predicate` or `None` + The predicate to be inserted instead in the processed tree, or + `None` if no substitution is needed. + """ + return None + + def visit_temporal_dimension_join( + self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags + ) -> tree.Predicate | None: + """Handle a temporal overlap comparison between two dimension elements. + + The default implementation updates the set of known temporal + connections (for use by `compute_automatic_temporal_joins`) and returns + `None`. + + Parameters + ---------- + a : `DimensionElement` + One element in the join. + b : `DimensionElement` + The other element in the join. + flags : `tree.PredicateLeafFlags` + Information about where this overlap comparison appears in the + larger predicate tree. + + Returns + ------- + replaced : `tree.Predicate` or `None` + The predicate to be inserted instead in the processed tree, or + `None` if no substitution is needed. 
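        For example, a driver subclass might record the join for later SQL
        generation (an illustrative sketch; ``self._joins`` is hypothetical)::

            class _DriverOverlapsVisitor(OverlapsVisitor):
                def visit_temporal_dimension_join(self, a, b, flags):
                    super().visit_temporal_dimension_join(a, b, flags)
                    self._joins.append((a, b))
                    return None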
+ """ + if a.temporal == b.temporal: + raise tree.InvalidQueryTreeError(f"Temporal join between {a} and {b} is not necessary.") + self._temporal_connections.merge( + cast(TopologicalFamily, a.temporal), cast(TopologicalFamily, b.temporal) + ) + return None diff --git a/python/lsst/daf/butler/queries/result_specs.py b/python/lsst/daf/butler/queries/result_specs.py new file mode 100644 index 0000000000..df5e07423e --- /dev/null +++ b/python/lsst/daf/butler/queries/result_specs.py @@ -0,0 +1,232 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ( + "ResultSpecBase", + "DataCoordinateResultSpec", + "DimensionRecordResultSpec", + "DatasetRefResultSpec", +) + +from collections.abc import Mapping +from typing import Annotated, Literal, TypeAlias, Union, cast + +import pydantic + +from ..dimensions import DimensionElement, DimensionGroup +from .tree import ColumnSet, DatasetFieldName, InvalidQueryTreeError, OrderExpression, QueryTree + + +class ResultSpecBase(pydantic.BaseModel): + """Base class for all query-result specification objects.""" + + order_by: tuple[OrderExpression, ...] = () + """Expressions to sort the rows by.""" + + offset: int = 0 + """Index of the first row to return.""" + + limit: int | None = None + """Maximum number of rows to return, or `None` for no bound.""" + + def validate_tree(self, tree: QueryTree) -> None: + """Check that this result object is consistent with a query tree. + + Parameters + ---------- + tree : `QueryTree` + Query tree that defines the joins and row-filtering that these + results will come from. + """ + spec = cast(ResultSpec, self) + if not spec.dimensions <= tree.dimensions: + raise InvalidQueryTreeError( + f"Query result specification has dimensions {spec.dimensions} that are not a subset of the " + f"query's dimensions {tree.dimensions}." 
+ ) + result_columns = spec.get_result_columns() + assert result_columns.dimensions == spec.dimensions, "enforced by ResultSpec implementations" + for dataset_type in result_columns.dataset_fields: + if dataset_type not in tree.datasets: + raise InvalidQueryTreeError(f"Dataset {dataset_type!r} is not available from this query.") + if not (tree.datasets[dataset_type].dimensions <= spec.dimensions): + raise InvalidQueryTreeError( + f"Result dataset type {dataset_type!r} has dimensions " + f"{tree.datasets[dataset_type].dimensions} that are not a subset of the result " + f"dimensions {spec.dimensions}." + ) + order_by_columns = ColumnSet(spec.dimensions) + for term in spec.order_by: + term.gather_required_columns(order_by_columns) + if not (order_by_columns.dimensions <= spec.dimensions): + raise InvalidQueryTreeError( + "Order-by expression may not reference columns that are not in the result dimensions." + ) + for dataset_type in order_by_columns.dataset_fields.keys(): + if dataset_type not in tree.datasets: + raise InvalidQueryTreeError( + f"Dataset type {dataset_type!r} in order-by expression is not part of the query." + ) + if not (tree.datasets[dataset_type].dimensions <= spec.dimensions): + raise InvalidQueryTreeError( + f"Dataset type {dataset_type!r} in order-by expression has dimensions " + f"{tree.datasets[dataset_type].dimensions} that are not a subset of the result " + f"dimensions {spec.dimensions}." + ) + + @property + def find_first_dataset(self) -> str | None: + return None + + +class DataCoordinateResultSpec(ResultSpecBase): + """Specification for a query that yields `DataCoordinate` objects.""" + + result_type: Literal["data_coordinate"] = "data_coordinate" + dimensions: DimensionGroup + include_dimension_records: bool + + def get_result_columns(self) -> ColumnSet: + """Return the columns included in the actual result rows. + + This does not necessarily include all columns required by the + `order_by` terms that are also a part of this spec. + """ + result = ColumnSet(self.dimensions) + if self.include_dimension_records: + for element_name in self.dimensions.elements: + element = self.dimensions.universe[element_name] + if not element.is_cached: + result.dimension_fields[element_name].update(element.schema.remainder.names) + return result + + +class DimensionRecordResultSpec(ResultSpecBase): + """Specification for a query that yields `DimensionRecord` objects.""" + + result_type: Literal["dimension_record"] = "dimension_record" + element: DimensionElement + + @property + def dimensions(self) -> DimensionGroup: + return self.element.minimal_group + + def get_result_columns(self) -> ColumnSet: + """Return the columns included in the actual result rows. + + This does not necessarily include all columns required by the + `order_by` terms that are also a part of this spec. + """ + result = ColumnSet(self.element.minimal_group) + result.dimension_fields[self.element.name].update(self.element.schema.remainder.names) + return result + + +class DatasetRefResultSpec(ResultSpecBase): + """Specification for a query that yields `DatasetRef` objects.""" + + result_type: Literal["dataset_ref"] = "dataset_ref" + dataset_type_name: str + dimensions: DimensionGroup + storage_class_name: str + include_dimension_records: bool + find_first: bool + + @property + def find_first_dataset(self) -> str | None: + return self.dataset_type_name if self.find_first else None + + def get_result_columns(self) -> ColumnSet: + """Return the columns included in the actual result rows. 
+ + This does not necessarily include all columns required by the + `order_by` terms that are also a part of this spec. + """ + result = ColumnSet(self.dimensions) + result.dataset_fields[self.dataset_type_name].update({"dataset_id", "run"}) + if self.include_dimension_records: + for element_name in self.dimensions.elements: + element = self.dimensions.universe[element_name] + if not element.is_cached: + result.dimension_fields[element_name].update(element.schema.remainder.names) + return result + + +class GeneralResultSpec(ResultSpecBase): + """Specification for a query that yields a table with + an explicit list of columns. + """ + + result_type: Literal["general"] = "general" + dimensions: DimensionGroup + dimension_fields: Mapping[str, set[str]] + dataset_fields: Mapping[str, set[DatasetFieldName]] + find_first: bool + + @property + def find_first_dataset(self) -> str | None: + if self.find_first: + (dataset_type,) = self.dataset_fields.keys() + return dataset_type + return None + + def get_result_columns(self) -> ColumnSet: + """Return the columns included in the actual result rows. + + This does not necessarily include all columns required by the + `order_by` terms that are also a part of this spec. + """ + result = ColumnSet(self.dimensions) + for element_name, fields_for_element in self.dimension_fields.items(): + result.dimension_fields[element_name].update(fields_for_element) + for dataset_type, fields_for_dataset in self.dataset_fields.items(): + result.dataset_fields[dataset_type].update(fields_for_dataset) + return result + + @pydantic.model_validator(mode="after") + def _validate(self) -> GeneralResultSpec: + if self.find_first and len(self.dataset_fields) != 1: + raise InvalidQueryTreeError("find_first=True requires exactly one result dataset type.") + for element_name, fields_for_element in self.dimension_fields.items(): + if element_name not in self.dimensions.elements: + raise InvalidQueryTreeError(f"Dimension element {element_name} is not in {self.dimensions}.") + if not fields_for_element: + raise InvalidQueryTreeError( + f"Empty dimension element field set for {element_name!r} is not permitted." + ) + for dataset_type, fields_for_dataset in self.dataset_fields.items(): + if not fields_for_dataset: + raise InvalidQueryTreeError(f"Empty dataset field set for {dataset_type!r} is not permitted.") + return self + + +ResultSpec: TypeAlias = Annotated[ + Union[DataCoordinateResultSpec, DimensionRecordResultSpec, DatasetRefResultSpec, GeneralResultSpec], + pydantic.Field(discriminator="result_type"), +] diff --git a/python/lsst/daf/butler/queries/tree/__init__.py b/python/lsst/daf/butler/queries/tree/__init__.py new file mode 100644 index 0000000000..e320695f62 --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/__init__.py @@ -0,0 +1,40 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. 
If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from ._base import * +from ._column_expression import * +from ._column_literal import * +from ._column_reference import * +from ._column_set import * +from ._predicate import * +from ._predicate import LogicalNot +from ._query_tree import * + +LogicalNot.model_rebuild() +del LogicalNot + +Predicate.model_rebuild() diff --git a/python/lsst/daf/butler/queries/tree/_base.py b/python/lsst/daf/butler/queries/tree/_base.py new file mode 100644 index 0000000000..e07de32b25 --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/_base.py @@ -0,0 +1,252 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ( + "QueryTreeBase", + "ColumnExpressionBase", + "DatasetFieldName", + "InvalidQueryTreeError", + "DATASET_FIELD_NAMES", +) + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeAlias, TypeVar, cast, get_args + +import pydantic + +if TYPE_CHECKING: + from ...column_spec import ColumnType + from ..visitors import ColumnExpressionVisitor, PredicateVisitFlags, PredicateVisitor + from ._column_literal import ColumnLiteral + from ._column_set import ColumnSet + from ._predicate import PredicateLeaf + + +DatasetFieldName: TypeAlias = Literal["dataset_id", "ingest_date", "run", "collection", "timespan"] + +DATASET_FIELD_NAMES: tuple[DatasetFieldName, ...] 
= tuple(get_args(DatasetFieldName)) + +_T = TypeVar("_T") +_L = TypeVar("_L") +_A = TypeVar("_A") +_O = TypeVar("_O") + + +class InvalidQueryTreeError(RuntimeError): + """Exception raised when a query tree is or would not be valid.""" + + +class QueryTreeBase(pydantic.BaseModel): + """Base class for all non-primitive types in a query tree.""" + + model_config = pydantic.ConfigDict(frozen=True, extra="forbid", strict=True) + + +class ColumnExpressionBase(QueryTreeBase, ABC): + """Base class for objects that represent non-boolean column expressions in + a query tree. + + Notes + ----- + This is a closed hierarchy whose concrete, `~typing.final` derived classes + are members of the `ColumnExpression` union. That union should generally + be used in type annotations rather than the technically-open base class. + """ + + expression_type: str + + is_literal: ClassVar[bool] = False + """Whether this expression wraps a literal Python value.""" + + @property + @abstractmethod + def precedence(self) -> int: + """Operator precedence for this operation. + + Lower values bind more tightly, so parentheses are needed when printing + an expression where an operand has a higher value than the expression + itself. + """ + raise NotImplementedError() + + @property + @abstractmethod + def column_type(self) -> ColumnType: + """A string enumeration value representing the type of the column + expression. + """ + raise NotImplementedError() + + def get_literal_value(self) -> Any | None: + """Return the literal value wrapped by this expression, or `None` if + it is not a literal. + """ + return None + + @abstractmethod + def gather_required_columns(self, columns: ColumnSet) -> None: + """Add any columns required to evaluate this expression to the + given column set. + + Parameters + ---------- + columns : `ColumnSet` + Set of columns to modify in place. + """ + raise NotImplementedError() + + @abstractmethod + def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T: + """Invoke the visitor interface. + + Parameters + ---------- + visitor : `ColumnExpressionVisitor` + Visitor to invoke a method on. + + Returns + ------- + result : `object` + Forwarded result from the visitor. + """ + raise NotImplementedError() + + +class ColumnLiteralBase(ColumnExpressionBase): + """Base class for objects that represent literal values as column + expressions in a query tree. + + Notes + ----- + This is a closed hierarchy whose concrete, `~typing.final` derived classes + are members of the `ColumnLiteral` union. That union should generally be + used in type annotations rather than the technically-open base class. The + concrete members of that union are only semi-public; they appear in the + serialized form of a column expression tree, but should only be constructed + via the `make_column_literal` factory function. All concrete members of + the union are also guaranteed to have a read-only ``value`` attribute + holding the wrapped literal, but it is unspecified whether that is a + regular attribute or a `property`. + """ + + is_literal: ClassVar[bool] = True + """Whether this expression wraps a literal Python value.""" + + @property + def precedence(self) -> int: + # Docstring inherited. + return 0 + + def get_literal_value(self) -> Any: + # Docstring inherited. + return cast("ColumnLiteral", self).value + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + pass + + @property + def column_type(self) -> ColumnType: + # Docstring inherited. 
+ return cast(ColumnType, self.expression_type) + + def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T: + # Docstring inherited + return visitor.visit_literal(cast("ColumnLiteral", self)) + + +class PredicateLeafBase(QueryTreeBase, ABC): + """Base class for leaf nodes of the `Predicate` tree. + + Notes + ----- + This is a closed hierarchy whose concrete, `~typing.final` derived classes + are members of the `PredicateLeaf` union. That union should generally be + used in type annotations rather than the technically-open base class. The + concrete members of that union are only semi-public; they appear in the + serialized form of a `Predicate`, but should only be constructed + via various `Predicate` factory methods. + """ + + @property + @abstractmethod + def precedence(self) -> int: + """Operator precedence for this operation. + + Lower values bind more tightly, so parentheses are needed when printing + an expression where an operand has a higher value than the expression + itself. + """ + raise NotImplementedError() + + @property + def column_type(self) -> Literal["bool"]: + """A string enumeration value representing the type of the column + expression. + """ + return "bool" + + @abstractmethod + def gather_required_columns(self, columns: ColumnSet) -> None: + """Add any columns required to evaluate this predicate leaf to the + given column set. + + Parameters + ---------- + columns : `ColumnSet` + Set of columns to modify in place. + """ + raise NotImplementedError() + + def invert(self) -> PredicateLeaf: + """Return a new leaf that is the logical not of this one.""" + from ._predicate import LogicalNot, LogicalNotOperand + + # This implementation works for every subclass other than LogicalNot + # itself, which overrides this method. + return LogicalNot.model_construct(operand=cast(LogicalNotOperand, self)) + + @abstractmethod + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + """Invoke the visitor interface. + + Parameters + ---------- + visitor : `PredicateVisitor` + Visitor to invoke a method on. + flags : `PredicateVisitFlags` + Flags that provide information about where this leaf appears in the + larger predicate tree. + + Returns + ------- + result : `object` + Forwarded result from the visitor. + """ + raise NotImplementedError() diff --git a/python/lsst/daf/butler/queries/tree/_column_expression.py b/python/lsst/daf/butler/queries/tree/_column_expression.py new file mode 100644 index 0000000000..d792245aec --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/_column_expression.py @@ -0,0 +1,257 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
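Illustrative sketch only, not part of the patch: this shows what the frozen, strictly-validated pydantic configuration on QueryTreeBase and the ColumnLiteralBase helpers mean in practice. It uses make_column_literal, which is added later in this patch, and assumes the new lsst.daf.butler.queries.tree package is importable as laid out here.

import pydantic

from lsst.daf.butler.queries.tree import make_column_literal

expr = make_column_literal(5)        # concrete literal class chosen from the Python type
print(expr.is_literal)               # True: literal subclasses override the class attribute
print(expr.get_literal_value())      # 5
print(expr.column_type)              # "int"; the discriminator string doubles as the column type
try:
    expr.value = 6                   # frozen=True makes every tree node immutable
except pydantic.ValidationError:
    print("tree nodes cannot be mutated after construction")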
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+__all__ = (
+    "ColumnExpression",
+    "OrderExpression",
+    "UnaryExpression",
+    "BinaryExpression",
+    "Reversed",
+    "UnaryOperator",
+    "BinaryOperator",
+)
+
+from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias, TypeVar, Union, final
+
+import pydantic
+
+from ...column_spec import ColumnType
+from ._base import ColumnExpressionBase, InvalidQueryTreeError
+from ._column_literal import ColumnLiteral
+from ._column_reference import _ColumnReference
+from ._column_set import ColumnSet
+
+if TYPE_CHECKING:
+    from ..visitors import ColumnExpressionVisitor
+
+
+_T = TypeVar("_T")
+
+
+UnaryOperator: TypeAlias = Literal["-", "begin_of", "end_of"]
+BinaryOperator: TypeAlias = Literal["+", "-", "*", "/", "%"]
+
+
+@final
+class UnaryExpression(ColumnExpressionBase):
+    """A unary operation on a column expression that returns a non-bool."""
+
+    expression_type: Literal["unary"] = "unary"
+
+    operand: ColumnExpression
+    """Expression this one operates on."""
+
+    operator: UnaryOperator
+    """Operator this expression applies."""
+
+    def gather_required_columns(self, columns: ColumnSet) -> None:
+        # Docstring inherited.
+        self.operand.gather_required_columns(columns)
+
+    @property
+    def precedence(self) -> int:
+        # Docstring inherited.
+        return 1
+
+    @property
+    def column_type(self) -> ColumnType:
+        # Docstring inherited.
+        match self.operator:
+            case "-":
+                return self.operand.column_type
+            case "begin_of" | "end_of":
+                return "datetime"
+        raise AssertionError(f"Invalid unary expression operator {self.operator}.")
+
+    def __str__(self) -> str:
+        s = str(self.operand)
+        if self.operand.precedence >= self.precedence:
+            s = f"({s})"
+        match self.operator:
+            case "-":
+                return f"-{s}"
+            case "begin_of":
+                return f"{s}.begin"
+            case "end_of":
+                return f"{s}.end"
+
+    @pydantic.model_validator(mode="after")
+    def _validate_types(self) -> UnaryExpression:
+        match (self.operator, self.operand.column_type):
+            case ("-", "int" | "float"):
+                pass
+            case ("begin_of" | "end_of", "timespan"):
+                pass
+            case _:
+                raise InvalidQueryTreeError(
+                    f"Invalid column type {self.operand.column_type} for operator {self.operator!r}."
+                )
+        return self
+
+    def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T:
+        # Docstring inherited.
+        return visitor.visit_unary_expression(self)
+
+
+@final
+class BinaryExpression(ColumnExpressionBase):
+    """A binary operation on column expressions that returns a non-bool."""
+
+    expression_type: Literal["binary"] = "binary"
+
+    a: ColumnExpression
+    """Left-hand side expression this one operates on."""
+
+    b: ColumnExpression
+    """Right-hand side expression this one operates on."""
+
+    operator: BinaryOperator
+    """Operator this expression applies.
+
+    Integer '/' and '%' are defined as in SQL, not Python (though the
+    definitions are the same for positive arguments).
+    """
+
+    def gather_required_columns(self, columns: ColumnSet) -> None:
+        # Docstring inherited.
+        self.a.gather_required_columns(columns)
+        self.b.gather_required_columns(columns)
+
+    @property
+    def precedence(self) -> int:
+        # Docstring inherited.
+ match self.operator: + case "*" | "/" | "%": + return 2 + case "+" | "-": + return 3 + + @property + def column_type(self) -> ColumnType: + # Docstring inherited. + return self.a.column_type + + def __str__(self) -> str: + a = str(self.a) + b = str(self.b) + match self.operator: + case "*" | "+": + if self.a.precedence > self.precedence: + a = f"({a})" + if self.b.precedence > self.precedence: + b = f"({b})" + case _: + if self.a.precedence >= self.precedence: + a = f"({a})" + if self.b.precedence >= self.precedence: + b = f"({b})" + return f"({a} {self.operator} {b})" + + @pydantic.model_validator(mode="after") + def _validate_types(self) -> BinaryExpression: + if self.a.column_type != self.b.column_type: + raise InvalidQueryTreeError( + f"Column types for operator {self.operator} do not agree " + f"({self.a.column_type}, {self.b.column_type})." + ) + match (self.operator, self.a.column_type): + case ("+" | "-" | "*" | "/", "int" | "float"): + pass + case ("%", "int"): + pass + case _: + raise InvalidQueryTreeError( + f"Invalid column type {self.a.column_type} for operator {self.operator!r}." + ) + return self + + def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T: + # Docstring inherited. + return visitor.visit_binary_expression(self) + + +# Union without Pydantic annotation for the discriminator, for use in nesting +# in other unions that will add that annotation. It's not clear whether it +# would work to just nest the annotated ones, but it seems safest not to rely +# on undocumented behavior. +_ColumnExpression: TypeAlias = Union[ + ColumnLiteral, + _ColumnReference, + UnaryExpression, + BinaryExpression, +] + + +ColumnExpression: TypeAlias = Annotated[_ColumnExpression, pydantic.Field(discriminator="expression_type")] + + +@final +class Reversed(ColumnExpressionBase): + """A tag wrapper for `AbstractExpression` that indicate sorting in + reverse order. + """ + + expression_type: Literal["reversed"] = "reversed" + + operand: ColumnExpression + """Expression to sort on in reverse.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + self.operand.gather_required_columns(columns) + + @property + def precedence(self) -> int: + # Docstring inherited. + return self.operand.precedence + + @property + def column_type(self) -> ColumnType: + # Docstring inherited. + return self.operand.column_type + + def __str__(self) -> str: + return f"{self.operand} DESC" + + def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T: + # Docstring inherited. + return visitor.visit_reversed(self) + + +def _validate_order_expression(expression: _ColumnExpression | Reversed) -> _ColumnExpression | Reversed: + if expression.column_type not in ("int", "string", "float", "datetime"): + raise InvalidQueryTreeError(f"Column type {expression.column_type} of {expression} is not ordered.") + return expression + + +OrderExpression: TypeAlias = Annotated[ + Union[_ColumnExpression, Reversed], + pydantic.Field(discriminator="expression_type"), + pydantic.AfterValidator(_validate_order_expression), +] diff --git a/python/lsst/daf/butler/queries/tree/_column_literal.py b/python/lsst/daf/butler/queries/tree/_column_literal.py new file mode 100644 index 0000000000..17ef812b14 --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/_column_literal.py @@ -0,0 +1,372 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). 
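Illustrative sketch only, not part of the patch: how the expression classes just added compose and how the precedence properties drive parenthesization when an expression is printed. It uses make_column_literal from the _column_literal.py module that follows and assumes the tree package is importable as laid out in this patch.

from lsst.daf.butler.queries.tree import BinaryExpression, UnaryExpression, make_column_literal

two = make_column_literal(2)
three = make_column_literal(3)
total = BinaryExpression(a=two, b=three, operator="+")      # validated: both operands are "int"
scaled = BinaryExpression(a=total, b=make_column_literal(4), operator="*")
print(scaled)                                               # the lower-precedence "+" operand is parenthesized
print(UnaryExpression(operand=total, operator="-"))         # unary negation binds most tightly

# Type checking happens in the model validators; mixing an int and a string
# operand is expected to raise InvalidQueryTreeError:
# BinaryExpression(a=two, b=make_column_literal("x"), operator="+")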
+# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ( + "ColumnLiteral", + "make_column_literal", +) + +import uuid +import warnings +from base64 import b64decode, b64encode +from functools import cached_property +from typing import Literal, TypeAlias, Union, final + +import astropy.time +import erfa +from lsst.sphgeom import Region + +from ..._timespan import Timespan +from ...time_utils import TimeConverter +from ._base import ColumnLiteralBase + +LiteralValue: TypeAlias = Union[int, str, float, bytes, uuid.UUID, astropy.time.Time, Timespan, Region] + + +@final +class IntColumnLiteral(ColumnLiteralBase): + """A literal `int` value in a column expression.""" + + expression_type: Literal["int"] = "int" + + value: int + """The wrapped value after base64 encoding.""" + + @classmethod + def from_value(cls, value: int) -> IntColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `int` + Value to wrap. + + Returns + ------- + expression : `IntColumnLiteral` + Literal expression object. + """ + return cls.model_construct(value=value) + + def __str__(self) -> str: + return repr(self.value) + + +@final +class StringColumnLiteral(ColumnLiteralBase): + """A literal `str` value in a column expression.""" + + expression_type: Literal["string"] = "string" + + value: str + """The wrapped value after base64 encoding.""" + + @classmethod + def from_value(cls, value: str) -> StringColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `str` + Value to wrap. + + Returns + ------- + expression : `StrColumnLiteral` + Literal expression object. + """ + return cls.model_construct(value=value) + + def __str__(self) -> str: + return repr(self.value) + + +@final +class FloatColumnLiteral(ColumnLiteralBase): + """A literal `float` value in a column expression.""" + + expression_type: Literal["float"] = "float" + + value: float + """The wrapped value after base64 encoding.""" + + @classmethod + def from_value(cls, value: float) -> FloatColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `float` + Value to wrap. + + Returns + ------- + expression : `FloatColumnLiteral` + Literal expression object. 
+ """ + return cls.model_construct(value=value) + + def __str__(self) -> str: + return repr(self.value) + + +@final +class HashColumnLiteral(ColumnLiteralBase): + """A literal `bytes` value representing a hash in a column expression. + + The original value is base64-encoded when serialized and decoded on first + use. + """ + + expression_type: Literal["hash"] = "hash" + + encoded: bytes + """The wrapped value after base64 encoding.""" + + @cached_property + def value(self) -> bytes: + """The wrapped value.""" + return b64decode(self.encoded) + + @classmethod + def from_value(cls, value: bytes) -> HashColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `bytes` + Value to wrap. + + Returns + ------- + expression : `HashColumnLiteral` + Literal expression object. + """ + return cls.model_construct(encoded=b64encode(value)) + + def __str__(self) -> str: + return "(bytes)" + + +@final +class UUIDColumnLiteral(ColumnLiteralBase): + """A literal `uuid.UUID` value in a column expression.""" + + expression_type: Literal["uuid"] = "uuid" + + value: uuid.UUID + + @classmethod + def from_value(cls, value: uuid.UUID) -> UUIDColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `uuid.UUID` + Value to wrap. + + Returns + ------- + expression : `UUIDColumnLiteral` + Literal expression object. + """ + return cls.model_construct(value=value) + + def __str__(self) -> str: + return str(self.value) + + +@final +class DateTimeColumnLiteral(ColumnLiteralBase): + """A literal `astropy.time.Time` value in a column expression. + + The time is converted into TAI nanoseconds since 1970-01-01 when serialized + and restored from that on first use. + """ + + expression_type: Literal["datetime"] = "datetime" + + nsec: int + """TAI nanoseconds since 1970-01-01.""" + + @cached_property + def value(self) -> astropy.time.Time: + """The wrapped value.""" + return TimeConverter().nsec_to_astropy(self.nsec) + + @classmethod + def from_value(cls, value: astropy.time.Time) -> DateTimeColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `astropy.time.Time` + Value to wrap. + + Returns + ------- + expression : `DateTimeColumnLiteral` + Literal expression object. + """ + return cls.model_construct(nsec=TimeConverter().astropy_to_nsec(value)) + + def __str__(self) -> str: + # Trap dubious year warnings in case we have timespans from + # simulated data in the future + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=erfa.ErfaWarning) + return self.value.tai.strftime("%Y-%m-%dT%H:%M:%S") + + +@final +class TimespanColumnLiteral(ColumnLiteralBase): + """A literal `Timespan` value in a column expression. + + The timespan bounds are converted into TAI nanoseconds since 1970-01-01 + when serialized and the timespan is restored from that on first use. + """ + + expression_type: Literal["timespan"] = "timespan" + + begin_nsec: int + """TAI nanoseconds since 1970-01-01 for the lower bound of the timespan + (inclusive). + """ + + end_nsec: int + """TAI nanoseconds since 1970-01-01 for the upper bound of the timespan + (exclusive). + """ + + @cached_property + def value(self) -> astropy.time.Time: + """The wrapped value.""" + return Timespan(None, None, _nsec=(self.begin_nsec, self.end_nsec)) + + @classmethod + def from_value(cls, value: Timespan) -> TimespanColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `..Timespan` + Value to wrap. 
+ + Returns + ------- + expression : `TimespanColumnLiteral` + Literal expression object. + """ + return cls.model_construct(begin_nsec=value._nsec[0], end_nsec=value._nsec[1]) + + def __str__(self) -> str: + return str(self.value) + + +@final +class RegionColumnLiteral(ColumnLiteralBase): + """A literal `lsst.sphgeom.Region` value in a column expression. + + The region is encoded to base64 `bytes` when serialized, and decoded on + first use. + """ + + expression_type: Literal["region"] = "region" + + encoded: bytes + """The wrapped value after base64 encoding.""" + + @cached_property + def value(self) -> bytes: + """The wrapped value.""" + return Region.decode(b64decode(self.encoded)) + + @classmethod + def from_value(cls, value: Region) -> RegionColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `..Region` + Value to wrap. + + Returns + ------- + expression : `RegionColumnLiteral` + Literal expression object. + """ + return cls.model_construct(encoded=b64encode(value.encode())) + + def __str__(self) -> str: + return "(region)" + + +ColumnLiteral: TypeAlias = Union[ + IntColumnLiteral, + StringColumnLiteral, + FloatColumnLiteral, + HashColumnLiteral, + UUIDColumnLiteral, + DateTimeColumnLiteral, + TimespanColumnLiteral, + RegionColumnLiteral, +] + + +def make_column_literal(value: LiteralValue) -> ColumnLiteral: + """Construct a `ColumnLiteral` from the value it will wrap. + + Parameters + ---------- + value : `LiteralValue` + Value to wrap. + + Returns + ------- + expression : `ColumnLiteral` + Literal expression object. + """ + match value: + case int(): + return IntColumnLiteral.from_value(value) + case str(): + return StringColumnLiteral.from_value(value) + case float(): + return FloatColumnLiteral.from_value(value) + case uuid.UUID(): + return UUIDColumnLiteral.from_value(value) + case bytes(): + return HashColumnLiteral.from_value(value) + case astropy.time.Time(): + return DateTimeColumnLiteral.from_value(value) + case Timespan(): + return TimespanColumnLiteral.from_value(value) + case Region(): + return RegionColumnLiteral.from_value(value) + raise TypeError(f"Invalid type {type(value).__name__} of value {value!r} for column literal.") diff --git a/python/lsst/daf/butler/queries/tree/_column_reference.py b/python/lsst/daf/butler/queries/tree/_column_reference.py new file mode 100644 index 0000000000..41434a6bcd --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/_column_reference.py @@ -0,0 +1,179 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
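Illustrative sketch only, not part of the patch: the make_column_literal factory defined above dispatches on the Python type of the value, and each literal model serializes with its expression_type discriminator. Import paths assume the new modules are installed as laid out in this patch.

import uuid

import astropy.time

from lsst.daf.butler.queries.tree import make_column_literal

print(make_column_literal(42).model_dump_json())          # e.g. {"expression_type": "int", "value": 42}
print(make_column_literal("DECam").expression_type)       # "string"
print(make_column_literal(uuid.uuid4()).expression_type)  # "uuid"
moment = astropy.time.Time("2024-01-01T00:00:00", scale="tai")
print(make_column_literal(moment).expression_type)        # "datetime"; serialized as TAI nanoseconds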
+# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("ColumnReference", "DimensionKeyReference", "DimensionFieldReference", "DatasetFieldReference") + +from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias, TypeVar, Union, final + +import pydantic + +from ...column_spec import ColumnType +from ...dimensions import Dimension, DimensionElement +from ._base import ColumnExpressionBase, DatasetFieldName, InvalidQueryTreeError + +if TYPE_CHECKING: + from ..visitors import ColumnExpressionVisitor + from ._column_set import ColumnSet + + +_T = TypeVar("_T") + + +@final +class DimensionKeyReference(ColumnExpressionBase): + """A column expression that references a dimension primary key column.""" + + expression_type: Literal["dimension_key"] = "dimension_key" + + dimension: Dimension + """Definition of this dimension.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + columns.update_dimensions(self.dimension.minimal_group) + + @property + def precedence(self) -> int: + # Docstring inherited. + return 0 + + @property + def column_type(self) -> ColumnType: + # Docstring inherited. + return self.dimension.primary_key.type + + def __str__(self) -> str: + return self.dimension.name + + def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T: + # Docstring inherited. + return visitor.visit_dimension_key_reference(self) + + +@final +class DimensionFieldReference(ColumnExpressionBase): + """A column expression that references a dimension record column that is + not a primary key. + """ + + expression_type: Literal["dimension_field"] = "dimension_field" + + element: DimensionElement + """Definition of the dimension element.""" + + field: str + """Name of the field (i.e. column) in the element's logical table.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + columns.update_dimensions(self.element.minimal_group) + columns.dimension_fields[self.element.name].add(self.field) + + @property + def precedence(self) -> int: + # Docstring inherited. + return 0 + + @property + def column_type(self) -> ColumnType: + # Docstring inherited. + return self.element.schema.remainder[self.field].type + + def __str__(self) -> str: + return f"{self.element}.{self.field}" + + def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T: + # Docstring inherited. + return visitor.visit_dimension_field_reference(self) + + @pydantic.model_validator(mode="after") + def _validate_field(self) -> DimensionFieldReference: + if self.field not in self.element.schema.remainder.names: + raise InvalidQueryTreeError(f"Dimension field {self.element.name}.{self.field} does not exist.") + return self + + +@final +class DatasetFieldReference(ColumnExpressionBase): + """A column expression that references a column associated with a dataset + type. + """ + + expression_type: Literal["dataset_field"] = "dataset_field" + + dataset_type: str + """Name of the dataset type to match any dataset type.""" + + field: DatasetFieldName + """Name of the field (i.e. column) in the dataset's logical table.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. 
+ columns.dataset_fields[self.dataset_type].add(self.field) + + @property + def precedence(self) -> int: + # Docstring inherited. + return 0 + + @property + def column_type(self) -> ColumnType: + # Docstring inherited. + match self.field: + case "dataset_id": + return "uuid" + case "ingest_date": + return "datetime" + case "run": + return "string" + case "collection": + return "string" + case "timespan": + return "timespan" + raise AssertionError(f"Invalid field {self.field!r} for dataset.") + + def __str__(self) -> str: + return f"{self.dataset_type}.{self.field}" + + def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T: + # Docstring inherited. + return visitor.visit_dataset_field_reference(self) + + +# Union without Pydantic annotation for the discriminator, for use in nesting +# in other unions that will add that annotation. It's not clear whether it +# would work to just nest the annotated ones, but it seems safest not to rely +# on undocumented behavior. +_ColumnReference: TypeAlias = Union[ + DimensionKeyReference, + DimensionFieldReference, + DatasetFieldReference, +] + +ColumnReference: TypeAlias = Annotated[_ColumnReference, pydantic.Field(discriminator="expression_type")] diff --git a/python/lsst/daf/butler/queries/tree/_column_set.py b/python/lsst/daf/butler/queries/tree/_column_set.py new file mode 100644 index 0000000000..582b68af28 --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/_column_set.py @@ -0,0 +1,186 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("ColumnSet",) + +from collections.abc import Iterable, Iterator, Mapping +from typing import Literal + +from ... 
import column_spec
+from ...dimensions import DimensionGroup
+from ...nonempty_mapping import NonemptyMapping
+from ._base import DATASET_FIELD_NAMES, DatasetFieldName
+
+
+class ColumnSet:
+    def __init__(self, dimensions: DimensionGroup) -> None:
+        self._dimensions = dimensions
+        self._removed_dimension_keys: set[str] = set()
+        self._dimension_fields: dict[str, set[str]] = {name: set() for name in dimensions.elements}
+        self._dataset_fields = NonemptyMapping[str, set[DatasetFieldName | Literal["collection_key"]]](set)
+
+    @property
+    def dimensions(self) -> DimensionGroup:
+        return self._dimensions
+
+    @property
+    def dimension_fields(self) -> Mapping[str, set[str]]:
+        return self._dimension_fields
+
+    @property
+    def dataset_fields(self) -> Mapping[str, set[DatasetFieldName | Literal["collection_key"]]]:
+        return self._dataset_fields
+
+    def __bool__(self) -> bool:
+        return bool(self._dimensions) or any(self._dataset_fields.values())
+
+    def issubset(self, other: ColumnSet) -> bool:
+        return (
+            self._dimensions.issubset(other._dimensions)
+            and all(
+                fields.issubset(other._dimension_fields[element_name])
+                for element_name, fields in self._dimension_fields.items()
+            )
+            and all(
+                fields.issubset(other._dataset_fields.get(dataset_type, frozenset()))
+                for dataset_type, fields in self._dataset_fields.items()
+            )
+        )
+
+    def issuperset(self, other: ColumnSet) -> bool:
+        return other.issubset(self)
+
+    def isdisjoint(self, other: ColumnSet) -> bool:
+        # Note that if the dimensions are disjoint, the dimension fields are
+        # also disjoint, and if the dimensions are not disjoint, we already
+        # have our answer.  The same is not true for dataset fields, but only
+        # because of the edge case of dataset types with empty dimensions.
+        return self._dimensions.isdisjoint(other._dimensions) and (
+            self._dataset_fields.keys().isdisjoint(other._dataset_fields)
+            or all(
+                fields.isdisjoint(other._dataset_fields[dataset_type])
+                for dataset_type, fields in self._dataset_fields.items()
+            )
+        )
+
+    def copy(self) -> ColumnSet:
+        result = ColumnSet(self._dimensions)
+        for element_name, element_fields in self._dimension_fields.items():
+            result._dimension_fields[element_name].update(element_fields)
+        for dataset_type, dataset_fields in self._dataset_fields.items():
+            result._dataset_fields[dataset_type].update(dataset_fields)
+        return result
+
+    def update_dimensions(self, dimensions: DimensionGroup) -> None:
+        if not dimensions.issubset(self._dimensions):
+            self._dimensions = dimensions | self._dimensions
+            self._dimension_fields = {
+                name: self._dimension_fields.get(name, set()) for name in self._dimensions.elements
+            }
+
+    def update(self, other: ColumnSet) -> None:
+        self.update_dimensions(other.dimensions)
+        self._removed_dimension_keys.intersection_update(other._removed_dimension_keys)
+        for element_name, element_fields in other._dimension_fields.items():
+            self._dimension_fields[element_name].update(element_fields)
+        for dataset_type, dataset_fields in other._dataset_fields.items():
+            self._dataset_fields[dataset_type].update(dataset_fields)
+
+    def drop_dimension_keys(self, names: Iterable[str]) -> ColumnSet:
+        self._removed_dimension_keys.update(names)
+        return self
+
+    def drop_implied_dimension_keys(self) -> ColumnSet:
+        self._removed_dimension_keys.update(self._dimensions.implied)
+        return self
+
+    def restore_dimension_keys(self) -> None:
+        self._removed_dimension_keys.clear()
+
+    def __iter__(self) -> Iterator[tuple[str, str | None]]:
+        for dimension_name in self._dimensions.data_coordinate_keys:
+            if dimension_name
not in self._removed_dimension_keys: + yield dimension_name, None + # We iterate over DimensionElements and their DimensionRecord columns + # in order to make sure that's predictable. We might want to extract + # these query results positionally in some contexts. + for element_name in self._dimensions.elements: + element = self._dimensions.universe[element_name] + fields_for_element = self._dimension_fields[element_name] + for spec in element.schema.remainder: + if spec.name in fields_for_element: + yield element_name, spec.name + # We sort dataset types and lexicographically just to keep our queries + # from having any dependence on set-iteration order. + for dataset_type in sorted(self._dataset_fields): + fields_for_dataset_type = self._dataset_fields[dataset_type] + for field in DATASET_FIELD_NAMES: + if field in fields_for_dataset_type: + yield dataset_type, field + + def is_timespan(self, logical_table: str, field: str | None) -> bool: + return field == "timespan" + + @staticmethod + def get_qualified_name(logical_table: str, field: str | None) -> str: + return logical_table if field is None else f"{logical_table}:{field}" + + def get_column_spec(self, logical_table: str, field: str | None) -> column_spec.ColumnSpec: + qualified_name = self.get_qualified_name(logical_table, field) + if field is None: + return self._dimensions.universe.dimensions[logical_table].primary_key.model_copy( + update=dict(name=qualified_name) + ) + if logical_table in self._dimension_fields: + return ( + self._dimensions.universe[logical_table] + .schema.all[field] + .model_copy(update=dict(name=qualified_name)) + ) + match field: + case "dataset_id": + return column_spec.UUIDColumnSpec.model_construct(name=qualified_name, nullable=False) + case "ingest_date": + return column_spec.DateTimeColumnSpec.model_construct(name=qualified_name) + case "run": + # TODO: string length matches the one defined in the + # CollectionManager implementations; we need to find a way to + # avoid hard-coding the value in multiple places. + return column_spec.StringColumnSpec.model_construct( + name=qualified_name, nullable=False, length=128 + ) + case "collection": + return column_spec.StringColumnSpec.model_construct( + name=qualified_name, nullable=False, length=128 + ) + case "rank": + return column_spec.IntColumnSpec.model_construct(name=qualified_name, nullable=False) + case "timespan": + return column_spec.TimespanColumnSpec.model_construct(name=qualified_name, nullable=False) + raise AssertionError(f"Unrecognized column identifiers: {logical_table}, {field}.") diff --git a/python/lsst/daf/butler/queries/tree/_predicate.py b/python/lsst/daf/butler/queries/tree/_predicate.py new file mode 100644 index 0000000000..b0d68bbbae --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/_predicate.py @@ -0,0 +1,678 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. 
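Illustrative sketch only, not part of the patch: the _predicate.py module that follows keeps predicates in an AND-of-ORs (conjunctive normal) form and exposes only classmethod factories plus logical_and/logical_or/logical_not for composition. This sketch assumes the tree package is importable as laid out in this patch and uses only literal operands, so no butler repository is needed.

from lsst.daf.butler.queries.tree import Predicate, make_column_literal

five = make_column_literal(5)
seven = make_column_literal(7)
lt = Predicate.compare(five, "<", seven)
eq = Predicate.compare(five, "==", seven)
print(lt.logical_and(eq))                # two AND terms, one leaf each
print(lt.logical_or(eq))                 # a single OR group with two leaves
print(lt.logical_and(eq).logical_not())  # De Morgan: an OR group of inverted leaves
print(Predicate.from_bool(True))         # "True": the empty AND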
If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ( + "Predicate", + "PredicateLeaf", + "LogicalNotOperand", + "PredicateOperands", + "ComparisonOperator", +) + +import itertools +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Annotated, Iterable, Literal, TypeAlias, TypeVar, Union, cast, final + +import pydantic + +from ._base import InvalidQueryTreeError, QueryTreeBase +from ._column_expression import ColumnExpression + +if TYPE_CHECKING: + from ..visitors import PredicateVisitFlags, PredicateVisitor + from ._column_set import ColumnSet + from ._query_tree import QueryTree + +ComparisonOperator: TypeAlias = Literal["==", "!=", "<", ">", ">=", "<=", "overlaps"] + + +_L = TypeVar("_L") +_A = TypeVar("_A") +_O = TypeVar("_O") + + +class PredicateLeafBase(QueryTreeBase, ABC): + """Base class for leaf nodes of the `Predicate` tree. + + This is a closed hierarchy whose concrete, `~typing.final` derived classes + are members of the `PredicateLeaf` union. That union should generally + be used in type annotations rather than the technically-open base class. + """ + + @property + @abstractmethod + def precedence(self) -> int: + """Operator precedence for this operation. + + Lower values bind more tightly, so parentheses are needed when printing + an expression where an operand has a higher value than the expression + itself. + """ + raise NotImplementedError() + + @property + def column_type(self) -> Literal["bool"]: + """A string enumeration value representing the type of the column + expression. + """ + return "bool" + + @abstractmethod + def gather_required_columns(self, columns: ColumnSet) -> None: + """Add any columns required to evaluate this predicate leaf to the + given column set. + + Parameters + ---------- + columns : `ColumnSet` + Set of columns to modify in place. + """ + raise NotImplementedError() + + def invert(self) -> PredicateLeaf: + """Return a new leaf that is the logical not of this one.""" + return LogicalNot.model_construct(operand=cast("LogicalNotOperand", self)) + + @abstractmethod + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + """Invoke the visitor interface. + + Parameters + ---------- + visitor : `PredicateVisitor` + Visitor to invoke a method on. + flags : `PredicateVisitFlags` + Flags that provide information about where this leaf appears in the + larger predicate tree. + + Returns + ------- + result : `object` + Forwarded result from the visitor. + """ + raise NotImplementedError() + + +@final +class Predicate(QueryTreeBase): + """A boolean column expression. 
+ + Notes + ----- + Predicate is the only class representing a boolean column expression that + should be used outside of this module (though the objects it nests appear + in its serialized form and hence are not fully private). It provides + several `classmethod` factories for constructing those nested types inside + a `Predicate` instance, and `PredicateVisitor` subclasses should be used + to process them. + """ + + operands: PredicateOperands + """Nested tuple of operands, with outer items combined via AND and inner + items combined via OR. + """ + + @property + def column_type(self) -> Literal["bool"]: + """A string enumeration value representing the type of the column + expression. + """ + return "bool" + + @classmethod + def from_bool(cls, value: bool) -> Predicate: + """Construct a predicate that always evaluates to `True` or `False`. + + Parameters + ---------- + value : `bool` + Value the predicate should evaluate to. + + Returns + ------- + predicate : `Predicate` + Predicate that evaluates to the given boolean value. + """ + return cls.model_construct(operands=() if value else ((),)) + + @classmethod + def compare(cls, a: ColumnExpression, operator: ComparisonOperator, b: ColumnExpression) -> Predicate: + """Construct a predicate representing a binary comparison between + two non-boolean column expressions. + + Parameters + ---------- + a : `ColumnExpression` + First column expression in the comparison. + operator : `str` + Enumerated string representing the comparison operator to apply. + May be and of "==", "!=", "<", ">", "<=", ">=", or "overlaps". + b : `ColumnExpression` + Second column expression in the comparison. + + Returns + ------- + predicate : `Predicate` + Predicate representing the comparison. + """ + return cls._from_leaf(Comparison.model_construct(a=a, operator=operator, b=b)) + + @classmethod + def is_null(cls, operand: ColumnExpression) -> Predicate: + """Construct a predicate that tests whether a column expression is + NULL. + + Parameters + ---------- + operand : `ColumnExpression` + Column expression to test. + + Returns + ------- + predicate : `Predicate` + Predicate representing the NULL check. + """ + return cls._from_leaf(IsNull.model_construct(operand=operand)) + + @classmethod + def in_container(cls, member: ColumnExpression, container: Iterable[ColumnExpression]) -> Predicate: + """Construct a predicate that tests whether one column expression is + a member of a container of other column expressions. + + Parameters + ---------- + member : `ColumnExpression` + Column expression that may be a member of the container. + container : `~collections.abc.Iterable` [ `ColumnExpression` ] + Container of column expressions to test for membership in. + + Returns + ------- + predicate : `Predicate` + Predicate representing the membership test. + """ + return cls._from_leaf(InContainer.model_construct(member=member, container=tuple(container))) + + @classmethod + def in_range( + cls, member: ColumnExpression, start: int = 0, stop: int | None = None, step: int = 1 + ) -> Predicate: + """Construct a predicate that tests whether an integer column + expression is part of a strided range. + + Parameters + ---------- + member : `ColumnExpression` + Column expression that may be a member of the range. + start : `int`, optional + Beginning of the range, inclusive. + stop : `int` or `None`, optional + End of the range, exclusive. + step : `int`, optional + Offset between values in the range. 
+ + Returns + ------- + predicate : `Predicate` + Predicate representing the membership test. + """ + return cls._from_leaf(InRange.model_construct(member=member, start=start, stop=stop, step=step)) + + @classmethod + def in_query_tree( + cls, member: ColumnExpression, column: ColumnExpression, query_tree: QueryTree + ) -> Predicate: + """Construct a predicate that tests whether a column expression is + present in a single-column projection of a query tree. + + Parameters + ---------- + member : `ColumnExpression` + Column expression that may be present in the query. + column : `ColumnExpression` + Column to project from the query. + query_tree : `QueryTree` + Query tree to select from. + + Returns + ------- + predicate : `Predicate` + Predicate representing the membership test. + """ + return cls._from_leaf( + InQueryTree.model_construct(member=member, column=column, query_tree=query_tree) + ) + + def gather_required_columns(self, columns: ColumnSet) -> None: + """Add any columns required to evaluate this predicate to the given + column set. + + Parameters + ---------- + columns : `ColumnSet` + Set of columns to modify in place. + """ + for or_group in self.operands: + for operand in or_group: + operand.gather_required_columns(columns) + + def logical_and(self, *args: Predicate) -> Predicate: + """Construct a predicate representing the logical AND of this predicate + and one or more others. + + Parameters + ---------- + *args : `Predicate` + Other predicates. + + Returns + ------- + predicate : `Predicate` + Predicate representing the logical AND. + """ + operands = self.operands + for arg in args: + operands = self._impl_and(operands, arg.operands) + if not all(operands): + # If any item in operands is an empty tuple (i.e. False), simplify. + operands = ((),) + return Predicate.model_construct(operands=operands) + + def logical_or(self, *args: Predicate) -> Predicate: + """Construct a predicate representing the logical OR of this predicate + and one or more others. + + Parameters + ---------- + *args : `Predicate` + Other predicates. + + Returns + ------- + predicate : `Predicate` + Predicate representing the logical OR. + """ + operands = self.operands + for arg in args: + operands = self._impl_or(operands, arg.operands) + return Predicate.model_construct(operands=operands) + + def logical_not(self) -> Predicate: + """Construct a predicate representing the logical NOT of this + predicate. + + Returns + ------- + predicate : `Predicate` + Predicate representing the logical NOT. + """ + new_operands: PredicateOperands = ((),) + for or_group in self.operands: + new_group: PredicateOperands = () + for leaf in or_group: + new_group = self._impl_and(new_group, ((leaf.invert(),),)) + new_operands = self._impl_or(new_operands, new_group) + return Predicate.model_construct(operands=new_operands) + + def __str__(self) -> str: + and_terms = [] + for or_group in self.operands: + match len(or_group): + case 0: + and_terms.append("False") + case 1: + and_terms.append(str(or_group[0])) + case _: + and_terms.append(f"({' OR '.join(str(operand) for operand in or_group)})") + if not and_terms: + return "True" + return " AND ".join(and_terms) + + def visit(self, visitor: PredicateVisitor[_A, _O, _L]) -> _A: + """Invoke the visitor interface. + + Parameters + ---------- + visitor : `PredicateVisitor` + Visitor to invoke a method on. + + Returns + ------- + result : `object` + Forwarded result from the visitor. 
+ """ + return visitor._visit_logical_and(self.operands) + + @classmethod + def _from_leaf(cls, leaf: PredicateLeaf) -> Predicate: + return cls._from_or_group((leaf,)) + + @classmethod + def _from_or_group(cls, or_group: tuple[PredicateLeaf, ...]) -> Predicate: + return Predicate.model_construct(operands=(or_group,)) + + @classmethod + def _impl_and(cls, a: PredicateOperands, b: PredicateOperands) -> PredicateOperands: + return a + b + + @classmethod + def _impl_or(cls, a: PredicateOperands, b: PredicateOperands) -> PredicateOperands: + return tuple([a_operand + b_operand for a_operand, b_operand in itertools.product(a, b)]) + + +@final +class LogicalNot(PredicateLeafBase): + """A boolean column expression that inverts its operand.""" + + predicate_type: Literal["not"] = "not" + + operand: LogicalNotOperand + """Upstream boolean expression to invert.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + self.operand.gather_required_columns(columns) + + @property + def precedence(self) -> int: + # Docstring inherited. + return 4 + + def __str__(self) -> str: + if self.operand.precedence <= self.precedence: + return f"NOT {self.operand}" + else: + return f"NOT ({self.operand})" + + def invert(self) -> LogicalNotOperand: + # Docstring inherited. + return self.operand + + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + # Docstring inherited. + return visitor._visit_logical_not(self.operand, flags) + + +@final +class IsNull(PredicateLeafBase): + """A boolean column expression that tests whether its operand is NULL.""" + + predicate_type: Literal["is_null"] = "is_null" + + operand: ColumnExpression + """Upstream expression to test.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + self.operand.gather_required_columns(columns) + + @property + def precedence(self) -> int: + # Docstring inherited. + return 5 + + def __str__(self) -> str: + if self.operand.precedence <= self.precedence: + return f"{self.operand} IS NULL" + else: + return f"({self.operand}) IS NULL" + + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + # Docstring inherited. + return visitor.visit_is_null(self.operand, flags) + + +@final +class Comparison(PredicateLeafBase): + """A boolean columns expression formed by comparing two non-boolean + expressions. + """ + + predicate_type: Literal["comparison"] = "comparison" + + a: ColumnExpression + """Left-hand side expression for the comparison.""" + + b: ColumnExpression + """Right-hand side expression for the comparison.""" + + operator: ComparisonOperator + """Comparison operator.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + self.a.gather_required_columns(columns) + self.b.gather_required_columns(columns) + + @property + def precedence(self) -> int: + # Docstring inherited. + return 5 + + def __str__(self) -> str: + a = str(self.a) if self.a.precedence <= self.precedence else f"({self.a})" + b = str(self.b) if self.b.precedence <= self.precedence else f"({self.b})" + return f"{a} {self.operator.upper()} {b}" + + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + # Docstring inherited. 
+ return visitor.visit_comparison(self.a, self.operator, self.b, flags) + + @pydantic.model_validator(mode="after") + def _validate_column_types(self) -> Comparison: + if self.a.column_type != self.b.column_type: + raise InvalidQueryTreeError( + f"Column types for comparison {self} do not agree " + f"({self.a.column_type}, {self.b.column_type})." + ) + match (self.operator, self.a.column_type): + case ("==" | "!=", _): + pass + case ("<" | ">" | ">=" | "<=", "int" | "string" | "float" | "datetime"): + pass + case ("overlaps", "region" | "timespan"): + pass + case _: + raise InvalidQueryTreeError( + f"Invalid column type {self.a.column_type} for operator {self.operator!r}." + ) + return self + + +@final +class InContainer(PredicateLeafBase): + """A boolean column expression that tests whether one expression is a + member of an explicit sequence of other expressions. + """ + + predicate_type: Literal["in_container"] = "in_container" + + member: ColumnExpression + """Expression to test for membership.""" + + container: tuple[ColumnExpression, ...] + """Expressions representing the elements of the container.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + self.member.gather_required_columns(columns) + for item in self.container: + item.gather_required_columns(columns) + + @property + def precedence(self) -> int: + # Docstring inherited. + return 5 + + def __str__(self) -> str: + m = str(self.member) if self.member.precedence <= self.precedence else f"({self.member})" + return f"{m} IN [{', '.join(str(item) for item in self.container)}]" + + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + # Docstring inherited. + return visitor.visit_in_container(self.member, self.container, flags) + + @pydantic.model_validator(mode="after") + def _validate(self) -> InContainer: + if self.member.column_type == "timespan" or self.member.column_type == "region": + raise InvalidQueryTreeError( + f"Timespan or region column {self.member} may not be used in IN expressions." + ) + if not all(item.column_type == self.member.column_type for item in self.container): + raise InvalidQueryTreeError(f"Column types for membership test {self} do not agree.") + return self + + +@final +class InRange(PredicateLeafBase): + """A boolean column expression that tests whether its expression is + included in an integer range. + """ + + predicate_type: Literal["in_range"] = "in_range" + + member: ColumnExpression + """Expression to test for membership.""" + + start: int = 0 + """Inclusive lower bound for the range.""" + + stop: int | None = None + """Exclusive upper bound for the range.""" + + step: int = 1 + """Difference between values in the range.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + self.member.gather_required_columns(columns) + + @property + def precedence(self) -> int: + # Docstring inherited. 
+ return 5 + + def __str__(self) -> str: + s = f"{self.start if self.start else ''}..{self.stop if self.stop is not None else ''}" + if self.step != 1: + s = f"{s}:{self.step}" + m = str(self.member) if self.member.precedence <= self.precedence else f"({self.member})" + return f"{m} IN {s}" + + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + return visitor.visit_in_range(self.member, self.start, self.stop, self.step, flags) + + @pydantic.model_validator(mode="after") + def _validate(self) -> InRange: + if self.member.column_type != "int": + raise InvalidQueryTreeError(f"Column {self.member} is not an integer.") + return self + + +@final +class InQueryTree(PredicateLeafBase): + """A boolean column expression that tests whether its expression is + included single-column projection of a relation. + + This is primarily intended to be used on dataset ID columns, but it may + be useful for other columns as well. + """ + + predicate_type: Literal["in_relation"] = "in_relation" + + member: ColumnExpression + """Expression to test for membership.""" + + column: ColumnExpression + """Expression to extract from `query_tree`.""" + + query_tree: QueryTree + """Relation whose rows from `column` represent the container.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + # We're only gathering columns from the query_tree this predicate is + # attached to, not `self.column`, which belongs to `self.query_tree`. + self.member.gather_required_columns(columns) + + @property + def precedence(self) -> int: + # Docstring inherited. + return 5 + + def __str__(self) -> str: + m = str(self.member) if self.member.precedence <= self.precedence else f"({self.member})" + c = str(self.column) if self.column.precedence <= self.precedence else f"({self.column})" + return f"{m} IN [{{{self.query_tree}}}.{c}]" + + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + # Docstring inherited. + return visitor.visit_in_query_tree(self.member, self.column, self.query_tree, flags) + + @pydantic.model_validator(mode="after") + def _validate_column_types(self) -> InQueryTree: + if self.member.column_type == "timespan" or self.member.column_type == "region": + raise InvalidQueryTreeError( + f"Timespan or region column {self.member} may not be used in IN expressions." + ) + if self.member.column_type != self.column.column_type: + raise InvalidQueryTreeError( + f"Column types for membership test {self} do not agree " + f"({self.member.column_type}, {self.column.column_type})." + ) + + from ._column_set import ColumnSet + + columns_required_in_tree = ColumnSet(self.query_tree.dimensions) + self.column.gather_required_columns(columns_required_in_tree) + if columns_required_in_tree.dimensions != self.query_tree.dimensions: + raise InvalidQueryTreeError( + f"Column {self.column} requires dimensions {columns_required_in_tree.dimensions}, " + f"but query tree only has {self.query_tree.dimensions}." + ) + if not columns_required_in_tree.dataset_fields.keys() <= self.query_tree.datasets.keys(): + raise InvalidQueryTreeError( + f"Column {self.column} requires dataset types " + f"{set(columns_required_in_tree.dataset_fields.keys())} that are not present in query tree." 
+ ) + return self + + +LogicalNotOperand: TypeAlias = Union[ + IsNull, + Comparison, + InContainer, + InRange, + InQueryTree, +] +PredicateLeaf: TypeAlias = Annotated[ + Union[LogicalNotOperand, LogicalNot], pydantic.Field(discriminator="predicate_type") +] + +PredicateOperands: TypeAlias = tuple[tuple[PredicateLeaf, ...], ...] diff --git a/python/lsst/daf/butler/queries/tree/_query_tree.py b/python/lsst/daf/butler/queries/tree/_query_tree.py new file mode 100644 index 0000000000..1f18c49889 --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/_query_tree.py @@ -0,0 +1,349 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ( + "QueryTree", + "make_unit_query_tree", + "make_dimension_query_tree", + "DataCoordinateUploadKey", + "MaterializationKey", + "DatasetSearch", + "DeferredValidationQueryTree", +) + +import uuid +from collections.abc import Mapping +from functools import cached_property +from typing import TypeAlias, final + +import pydantic + +from ...dimensions import DimensionGroup, DimensionUniverse +from ...pydantic_utils import DeferredValidation +from ._base import InvalidQueryTreeError, QueryTreeBase +from ._column_set import ColumnSet +from ._predicate import Predicate + +DataCoordinateUploadKey: TypeAlias = uuid.UUID + +MaterializationKey: TypeAlias = uuid.UUID + + +def make_unit_query_tree(universe: DimensionUniverse) -> QueryTree: + """Make an initial query tree with empty dimensions and a single logical + row. + + This method should be used by `Butler._query` to construct the initial + query tree. This tree is a useful initial state because it is the + identity for joins, in that joining any other query tree to this + query tree yields that query tree. + + Parameters + ---------- + universe : `..DimensionUniverse` + Definitions for all dimensions. + + Returns + ------- + tree : `QueryTree` + A tree with empty dimensions. + """ + return make_dimension_query_tree(universe.empty.as_group()) + + +def make_dimension_query_tree(dimensions: DimensionGroup) -> QueryTree: + """Make an initial query tree with the given dimensions. + + Parameters + ---------- + dimensions : `..DimensionGroup` + Definitions for all dimensions. + + Returns + ------- + tree : `QueryTree` + A tree with the given dimensions. 
+ """ + return QueryTree.model_construct(dimensions=dimensions) + + +@final +class DatasetSearch(QueryTreeBase): + """Information about a dataset search joined into a query tree. + + The dataset type name is the key of the dictionary (in `QueryTree`) where + this type is used as a value. + """ + + collections: tuple[str, ...] + """The collections to search. + + Order matters if this dataset type is later referenced by a `FindFirst` + operation. Collection wildcards are always resolved before being included + in a dataset search. + """ + + dimensions: DimensionGroup + """The dimensions of the dataset type. + + This must match the dimensions of the dataset type as already defined in + the butler database, but this cannot generally be verified when a relation + tree is validated (since it requires a database query) and hence must be + checked later. + """ + + +@final +class QueryTree(QueryTreeBase): + """A declarative, serializable description of a butler query. + + This class's attributes describe the columns that "available" to be + returned or used in ``where`` or ``order_by`` expressions, but it does not + carry information about the columns that are actually included in result + rows, or what kind of butler primitive (e.g. `DataCoordinate` or + `DatasetRef`) those rows might be transformed into. + """ + + dimensions: DimensionGroup + """The dimensions whose keys are joined into the query. + """ + + datasets: Mapping[str, DatasetSearch] = pydantic.Field(default_factory=dict) + """Dataset searches that have been joined into the query.""" + + data_coordinate_uploads: Mapping[DataCoordinateUploadKey, DimensionGroup] = pydantic.Field( + default_factory=dict + ) + """Uploaded tables of data ID values that have been joined into the query. + """ + + materializations: Mapping[MaterializationKey, DimensionGroup] = pydantic.Field(default_factory=dict) + """Tables of result rows from other queries that have been stored + temporarily on the server. + """ + + predicate: Predicate = Predicate.from_bool(True) + """Boolean expression trees whose logical AND defines a row filter.""" + + @cached_property + def join_operand_dimensions(self) -> frozenset[DimensionGroup]: + """A set of sets of the dimensions of all data coordinate uploads, + dataset searches, and materializations. + """ + result: set[DimensionGroup] = set(self.data_coordinate_uploads.values()) + result.update(self.materializations.values()) + for dataset_spec in self.datasets.values(): + result.add(dataset_spec.dimensions) + return frozenset(result) + + def join(self, other: QueryTree) -> QueryTree: + """Return a new tree that represents a join between ``self`` and + ``other``. + + Parameters + ---------- + other : `QueryTree` + Tree to join to this one. + + Returns + ------- + result : `QueryTree` + A new tree that joins ``self`` and ``other``. + + Raises + ------ + InvalidQueryTreeError + Raised if the join is ambiguous or otherwise invalid. + """ + if not self.datasets.keys().isdisjoint(other.datasets.keys()): + raise InvalidQueryTreeError( + "Cannot join when both sides include the same dataset type: " + f"{self.datasets.keys() & other.datasets.keys()}." 
+            )
+        return QueryTree.model_construct(
+            dimensions=self.dimensions | other.dimensions,
+            datasets={**self.datasets, **other.datasets},
+            data_coordinate_uploads={**self.data_coordinate_uploads, **other.data_coordinate_uploads},
+            materializations={**self.materializations, **other.materializations},
+            predicate=self.predicate.logical_and(other.predicate),
+        )
+
+    def join_data_coordinate_upload(
+        self, key: DataCoordinateUploadKey, dimensions: DimensionGroup
+    ) -> QueryTree:
+        """Return a new tree that joins in an uploaded table of data ID values.
+
+        Parameters
+        ----------
+        key : `DataCoordinateUploadKey`
+            Unique identifier for this upload, as assigned by a `QueryDriver`.
+        dimensions : `DimensionGroup`
+            Dimensions of the data IDs.
+
+        Returns
+        -------
+        result : `QueryTree`
+            A new tree that joins in the data ID table.
+        """
+        if key in self.data_coordinate_uploads:
+            assert (
+                dimensions == self.data_coordinate_uploads[key]
+            ), f"Different dimensions for the same data coordinate upload key {key}!"
+            return self
+        data_coordinate_uploads = dict(self.data_coordinate_uploads)
+        data_coordinate_uploads[key] = dimensions
+        return self.model_copy(
+            update=dict(
+                dimensions=self.dimensions | dimensions, data_coordinate_uploads=data_coordinate_uploads
+            )
+        )
+
+    def join_materialization(self, key: MaterializationKey, dimensions: DimensionGroup) -> QueryTree:
+        """Return a new tree that joins in temporarily stored results from
+        another query.
+
+        Parameters
+        ----------
+        key : `MaterializationKey`
+            Unique identifier for this materialization, as assigned by a
+            `QueryDriver`.
+        dimensions : `DimensionGroup`
+            The dimensions stored in the materialization.
+
+        Returns
+        -------
+        result : `QueryTree`
+            A new tree that joins in the materialization.
+        """
+        if key in self.materializations:
+            assert (
+                dimensions == self.materializations[key]
+            ), f"Different dimensions for the same materialization {key}!"
+            return self
+        materializations = dict(self.materializations)
+        materializations[key] = dimensions
+        return self.model_copy(
+            update=dict(dimensions=self.dimensions | dimensions, materializations=materializations)
+        )
+
+    def join_dataset(self, dataset_type: str, spec: DatasetSearch) -> QueryTree:
+        """Return a new tree that joins in a search for a dataset.
+
+        Parameters
+        ----------
+        dataset_type : `str`
+            Name of dataset type to join in.
+        spec : `DatasetSearch`
+            Struct containing the collection search path and dataset type
+            dimensions.
+
+        Returns
+        -------
+        result : `QueryTree`
+            A new tree that joins in the dataset search.
+
+        Raises
+        ------
+        InvalidQueryTreeError
+            Raised if this dataset type is already present in the query tree.
+        """
+        if dataset_type in self.datasets:
+            if spec != self.datasets[dataset_type]:
+                raise InvalidQueryTreeError(
+                    f"Dataset type {dataset_type!r} is already present in the query, with different "
+                    "collections and/or dimensions."
+                )
+            return self
+        datasets = dict(self.datasets)
+        datasets[dataset_type] = spec
+        return self.model_copy(update=dict(dimensions=self.dimensions | spec.dimensions, datasets=datasets))
+
+    def where(self, *terms: Predicate) -> QueryTree:
+        """Return a new tree that adds row filtering via a boolean column
+        expression.
+
+        Parameters
+        ----------
+        *terms : `Predicate`
+            Boolean column expressions that filter rows. Arguments are
+            combined with logical AND.
+
+        Returns
+        -------
+        result : `QueryTree`
+            A new tree with row filtering applied.
+
+        Raises
+        ------
+        InvalidQueryTreeError
+            Raised if a column expression requires a dataset column that is not
+            already present in the query tree.
+
+        Notes
+        -----
+        If an expression references a dimension or dimension element that is
+        not already present in the query tree, it will be joined in, but
+        datasets must already be joined into a query tree in order to reference
+        their fields in expressions.
+        """
+        where_predicate = self.predicate
+        columns = ColumnSet(self.dimensions)
+        for where_term in terms:
+            where_term.gather_required_columns(columns)
+            where_predicate = where_predicate.logical_and(where_term)
+        if not (columns.dataset_fields.keys() <= self.datasets.keys()):
+            raise InvalidQueryTreeError(
+                f"Cannot reference dataset type(s) {columns.dataset_fields.keys() - self.datasets.keys()} "
+                "that have not been joined."
+            )
+        return self.model_copy(update=dict(dimensions=columns.dimensions, predicate=where_predicate))
+
+    @pydantic.model_validator(mode="after")
+    def _validate_join_operands(self) -> QueryTree:
+        for dimensions in self.join_operand_dimensions:
+            if not dimensions.issubset(self.dimensions):
+                raise InvalidQueryTreeError(
+                    f"Dimensions {dimensions} of join operand are not a "
+                    f"subset of the query tree's dimensions {self.dimensions}."
+                )
+        return self
+
+    @pydantic.model_validator(mode="after")
+    def _validate_required_columns(self) -> QueryTree:
+        columns = ColumnSet(self.dimensions)
+        self.predicate.gather_required_columns(columns)
+        if not columns.dimensions.issubset(self.dimensions):
+            raise InvalidQueryTreeError("Predicate requires dimensions beyond those in the query tree.")
+        if not columns.dataset_fields.keys() <= self.datasets.keys():
+            raise InvalidQueryTreeError("Predicate requires dataset columns that are not in the query tree.")
+        return self
+
+
+class DeferredValidationQueryTree(DeferredValidation[QueryTree]):
+    pass
diff --git a/python/lsst/daf/butler/queries/visitors.py b/python/lsst/daf/butler/queries/visitors.py
new file mode 100644
index 0000000000..8340389e19
--- /dev/null
+++ b/python/lsst/daf/butler/queries/visitors.py
@@ -0,0 +1,540 @@
+# This file is part of daf_butler.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
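The module added below defines visitor interfaces over the predicate trees introduced above. Those predicates are held in an AND-of-ORs form: `PredicateOperands` is a tuple of OR groups that are combined with logical AND, and that is the shape `apply_logical_and` and `apply_logical_or` traverse. A minimal standalone sketch of that evaluation order follows, using plain booleans in place of `PredicateLeaf` objects; the `evaluate` helper and the mapping of the two degenerate cases to `Predicate.from_bool` are illustrative assumptions, not part of the daf_butler API.

# Illustrative sketch only: plain booleans stand in for already-evaluated
# PredicateLeaf objects, and `evaluate` is hypothetical, not daf_butler API.
Operands = tuple[tuple[bool, ...], ...]  # outer tuple = AND, inner tuples = OR groups


def evaluate(operands: Operands) -> bool:
    # AND over zero OR groups is True; an OR group with zero leaves is False.
    return all(any(or_group) for or_group in operands)


assert evaluate(()) is True              # behaves like Predicate.from_bool(True)
assert evaluate(((),)) is False          # behaves like Predicate.from_bool(False)
assert evaluate(((True, False), (True,))) is True
assert evaluate(((False, False), (True,))) is False

This is also why `SimplePredicateVisitor`, defined at the end of this module, can rebuild trees with `Predicate.from_bool(True).logical_and(...)` and `Predicate.from_bool(False).logical_or(...)`: those values are the identities for AND and OR respectively.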
+ +from __future__ import annotations + +__all__ = ( + "ColumnExpressionVisitor", + "PredicateVisitor", + "SimplePredicateVisitor", + "PredicateVisitFlags", +) + +import enum +from abc import abstractmethod +from typing import Generic, TypeVar, final + +from . import tree + +_T = TypeVar("_T") +_L = TypeVar("_L") +_A = TypeVar("_A") +_O = TypeVar("_O") + + +class PredicateVisitFlags(enum.Flag): + """Flags that provide information about the location of a predicate term + in the larger tree. + """ + + HAS_AND_SIBLINGS = enum.auto() + HAS_OR_SIBLINGS = enum.auto() + INVERTED = enum.auto() + + +class ColumnExpressionVisitor(Generic[_T]): + """A visitor interface for traversing a `ColumnExpression` tree. + + Notes + ----- + Unlike `Predicate`, the concrete column expression types need to be + public for various reasons, and hence the visitor interface uses them + directly in its arguments. + + This interface includes `Reversed` (which is part of the `OrderExpression` + union but not the `ColumnExpression` union) because it is simpler to have + just one visitor interface and disable support for it at runtime as + appropriate. + """ + + @abstractmethod + def visit_literal(self, expression: tree.ColumnLiteral) -> _T: + """Visit a column expression that wraps a literal value. + + Parameters + ---------- + expression : `tree.ColumnLiteral` + Expression to visit. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_dimension_key_reference(self, expression: tree.DimensionKeyReference) -> _T: + """Visit a column expression that represents a dimension column. + + Parameters + ---------- + expression : `tree.DimensionKeyReference` + Expression to visit. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_dimension_field_reference(self, expression: tree.DimensionFieldReference) -> _T: + """Visit a column expression that represents a dimension record field. + + Parameters + ---------- + expression : `tree.DimensionFieldReference` + Expression to visit. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_dataset_field_reference(self, expression: tree.DatasetFieldReference) -> _T: + """Visit a column expression that represents a dataset field. + + Parameters + ---------- + expression : `tree.DatasetFieldReference` + Expression to visit. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_unary_expression(self, expression: tree.UnaryExpression) -> _T: + """Visit a column expression that represents a unary operation. + + Parameters + ---------- + expression : `tree.UnaryExpression` + Expression to visit. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_binary_expression(self, expression: tree.BinaryExpression) -> _T: + """Visit a column expression that wraps a binary operation. + + Parameters + ---------- + expression : `tree.BinaryExpression` + Expression to visit. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_reversed(self, expression: tree.Reversed) -> _T: + """Visit a column expression that switches sort order from ascending + to descending. 
+
+        Parameters
+        ----------
+        expression : `tree.Reversed`
+            Expression to visit.
+
+        Returns
+        -------
+        result : `object`
+            Implementation-defined.
+        """
+        raise NotImplementedError()
+
+
+class PredicateVisitor(Generic[_A, _O, _L]):
+    """A visitor interface for traversing a `Predicate`.
+
+    Notes
+    -----
+    The concrete `PredicateLeaf` types are only semi-public (they appear in
+    the serialized form of a `Predicate`, but their types should not generally
+    be referenced directly outside of the module in which they are defined).
+    As a result, visiting these objects unpacks their attributes into the
+    visit method arguments.
+    """
+
+    @abstractmethod
+    def visit_comparison(
+        self,
+        a: tree.ColumnExpression,
+        operator: tree.ComparisonOperator,
+        b: tree.ColumnExpression,
+        flags: PredicateVisitFlags,
+    ) -> _L:
+        """Visit a binary comparison between column expressions.
+
+        Parameters
+        ----------
+        a : `tree.ColumnExpression`
+            First column expression in the comparison.
+        operator : `str`
+            Enumerated string representing the comparison operator to apply.
+            May be any of "==", "!=", "<", ">", "<=", ">=", or "overlaps".
+        b : `tree.ColumnExpression`
+            Second column expression in the comparison.
+        flags : `PredicateVisitFlags`
+            Information about where this leaf appears in the larger predicate
+            tree.
+
+        Returns
+        -------
+        result : `object`
+            Implementation-defined.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def visit_is_null(self, operand: tree.ColumnExpression, flags: PredicateVisitFlags) -> _L:
+        """Visit a predicate leaf that tests whether a column expression is
+        NULL.
+
+        Parameters
+        ----------
+        operand : `tree.ColumnExpression`
+            Column expression to test.
+        flags : `PredicateVisitFlags`
+            Information about where this leaf appears in the larger predicate
+            tree.
+
+        Returns
+        -------
+        result : `object`
+            Implementation-defined.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def visit_in_container(
+        self,
+        member: tree.ColumnExpression,
+        container: tuple[tree.ColumnExpression, ...],
+        flags: PredicateVisitFlags,
+    ) -> _L:
+        """Visit a predicate leaf that tests whether a column expression is
+        a member of a container.
+
+        Parameters
+        ----------
+        member : `tree.ColumnExpression`
+            Column expression that may be a member of the container.
+        container : `~collections.abc.Iterable` [ `tree.ColumnExpression` ]
+            Container of column expressions to test for membership in.
+        flags : `PredicateVisitFlags`
+            Information about where this leaf appears in the larger predicate
+            tree.
+
+        Returns
+        -------
+        result : `object`
+            Implementation-defined.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def visit_in_range(
+        self,
+        member: tree.ColumnExpression,
+        start: int,
+        stop: int | None,
+        step: int,
+        flags: PredicateVisitFlags,
+    ) -> _L:
+        """Visit a predicate leaf that tests whether a column expression is
+        a member of an integer range.
+
+        Parameters
+        ----------
+        member : `tree.ColumnExpression`
+            Column expression that may be a member of the range.
+        start : `int`, optional
+            Beginning of the range, inclusive.
+        stop : `int` or `None`, optional
+            End of the range, exclusive.
+        step : `int`, optional
+            Offset between values in the range.
+        flags : `PredicateVisitFlags`
+            Information about where this leaf appears in the larger predicate
+            tree.
+
+        Returns
+        -------
+        result : `object`
+            Implementation-defined.
+ """ + raise NotImplementedError() + + @abstractmethod + def visit_in_query_tree( + self, + member: tree.ColumnExpression, + column: tree.ColumnExpression, + query_tree: tree.QueryTree, + flags: PredicateVisitFlags, + ) -> _L: + """Visit a predicate leaf that tests whether a column expression is + a member of a container. + + Parameters + ---------- + member : `tree.ColumnExpression` + Column expression that may be present in the query. + column : `tree.ColumnExpression` + Column to project from the query. + query_tree : `QueryTree` + Query tree to select from. + flags : `PredicateVisitFlags` + Information about where this leaf appears in the larger predicate + tree. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def apply_logical_not(self, original: tree.PredicateLeaf, result: _L, flags: PredicateVisitFlags) -> _L: + """Apply a logical NOT to the result of visiting an inverted predicate + leaf. + + Parameters + ---------- + original : `PredicateLeaf` + The original operand of the logical NOT operation. + result : `object` + Implementation-defined result of visiting the operand. + flags : `PredicateVisitFlags` + Information about where this leaf appears in the larger predicate + tree. Never has `PredicateVisitFlags.INVERTED` set. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def apply_logical_or( + self, + originals: tuple[tree.PredicateLeaf, ...], + results: tuple[_L, ...], + flags: PredicateVisitFlags, + ) -> _O: + """Apply a logical OR operation to the result of visiting a `tuple` of + predicate leaf objects. + + Parameters + ---------- + originals : `tuple` [ `PredicateLeaf`, ... ] + Original leaf objects in the logical OR. + results : `tuple` [ `object`, ... ] + Result of visiting the leaf objects. + flags : `PredicateVisitFlags` + Information about where this leaf appears in the larger predicate + tree. Never has `PredicateVisitFlags.INVERTED` or + `PredicateVisitFlags.HAS_OR_SIBLINGS` set. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def apply_logical_and(self, originals: tree.PredicateOperands, results: tuple[_O, ...]) -> _A: + """Apply a logical AND operation to the result of visiting a nested + `tuple` of predicate leaf objects. + + Parameters + ---------- + originals : `tuple` [ `tuple` [ `PredicateLeaf`, ... ], ... ] + Nested tuple of predicate leaf objects, with inner tuples + corresponding to groups that should be combined with logical OR. + results : `tuple` [ `object`, ... ] + Result of visiting the leaf objects. + + Returns + ------- + result : `object` + Implementation-defined. 
+ """ + raise NotImplementedError() + + @final + def _visit_logical_not(self, operand: tree.LogicalNotOperand, flags: PredicateVisitFlags) -> _L: + return self.apply_logical_not( + operand, operand.visit(self, flags | PredicateVisitFlags.INVERTED), flags + ) + + @final + def _visit_logical_or(self, operands: tuple[tree.PredicateLeaf, ...], flags: PredicateVisitFlags) -> _O: + nested_flags = flags + if len(operands) > 1: + nested_flags |= PredicateVisitFlags.HAS_OR_SIBLINGS + return self.apply_logical_or( + operands, tuple([operand.visit(self, nested_flags) for operand in operands]), flags + ) + + @final + def _visit_logical_and(self, operands: tree.PredicateOperands) -> _A: + if len(operands) > 1: + nested_flags = PredicateVisitFlags.HAS_AND_SIBLINGS + else: + nested_flags = PredicateVisitFlags(0) + return self.apply_logical_and( + operands, tuple([self._visit_logical_or(or_group, nested_flags) for or_group in operands]) + ) + + +class SimplePredicateVisitor( + PredicateVisitor[tree.Predicate | None, tree.Predicate | None, tree.Predicate | None] +): + """An intermediate base class for predicate visitor implementations that + either return `None` or a new `Predicate`. + + Notes + ----- + This class implements all leaf-node visitation methods to return `None`, + which is interpreted by the ``apply*`` method implementations as indicating + that the leaf is unmodified. Subclasses can thus override only certain + visitation methods and either return `None` if there is no result, or + return a replacement `Predicate` to construct a new tree. + """ + + def visit_comparison( + self, + a: tree.ColumnExpression, + operator: tree.ComparisonOperator, + b: tree.ColumnExpression, + flags: PredicateVisitFlags, + ) -> tree.Predicate | None: + # Docstring inherited. + return None + + def visit_is_null( + self, operand: tree.ColumnExpression, flags: PredicateVisitFlags + ) -> tree.Predicate | None: + # Docstring inherited. + return None + + def visit_in_container( + self, + member: tree.ColumnExpression, + container: tuple[tree.ColumnExpression, ...], + flags: PredicateVisitFlags, + ) -> tree.Predicate | None: + # Docstring inherited. + return None + + def visit_in_range( + self, + member: tree.ColumnExpression, + start: int, + stop: int | None, + step: int, + flags: PredicateVisitFlags, + ) -> tree.Predicate | None: + # Docstring inherited. + return None + + def visit_in_query_tree( + self, + member: tree.ColumnExpression, + column: tree.ColumnExpression, + query_tree: tree.QueryTree, + flags: PredicateVisitFlags, + ) -> tree.Predicate | None: + # Docstring inherited. + return None + + def apply_logical_not( + self, original: tree.PredicateLeaf, result: tree.Predicate | None, flags: PredicateVisitFlags + ) -> tree.Predicate | None: + # Docstring inherited. + if result is None: + return None + from . import tree + + return tree.Predicate._from_leaf(original).logical_not() + + def apply_logical_or( + self, + originals: tuple[tree.PredicateLeaf, ...], + results: tuple[tree.Predicate | None, ...], + flags: PredicateVisitFlags, + ) -> tree.Predicate | None: + # Docstring inherited. + if all(result is None for result in results): + return None + from . 
import tree + + return tree.Predicate.from_bool(False).logical_or( + *[ + tree.Predicate._from_leaf(original) if result is None else result + for original, result in zip(originals, results) + ] + ) + + def apply_logical_and( + self, + originals: tree.PredicateOperands, + results: tuple[tree.Predicate | None, ...], + ) -> tree.Predicate | None: + # Docstring inherited. + if all(result is None for result in results): + return None + from . import tree + + return tree.Predicate.from_bool(True).logical_and( + *[ + tree.Predicate._from_or_group(original) if result is None else result + for original, result in zip(originals, results) + ] + )
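To customize traversal, a concrete visitor only needs to override the leaf methods it cares about; `SimplePredicateVisitor` already supplies `apply_*` implementations that splice replacements back into a rebuilt tree. A minimal sketch follows, assuming the module paths added in this change and that a visitor is applied through an entry point such as `Predicate.visit` (not shown in this diff); `DropComparisons` is a hypothetical name, not part of the package.

# Sketch of a SimplePredicateVisitor subclass; relies only on the interface
# shown above.  How the visitor is invoked on a Predicate is assumed.
from lsst.daf.butler.queries import tree
from lsst.daf.butler.queries.visitors import PredicateVisitFlags, SimplePredicateVisitor


class DropComparisons(SimplePredicateVisitor):
    """Illustrative visitor that replaces every comparison leaf with TRUE."""

    def visit_comparison(
        self,
        a: tree.ColumnExpression,
        operator: tree.ComparisonOperator,
        b: tree.ColumnExpression,
        flags: PredicateVisitFlags,
    ) -> tree.Predicate | None:
        # Returning a Predicate (rather than None) tells the inherited
        # apply_* methods to substitute it for the original leaf.
        return tree.Predicate.from_bool(True)

All other leaves fall through to the base-class implementations, which return `None` and therefore leave those parts of the tree unchanged.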