From e090ce436513a411d58e7324af84f5bcd05c7cdb Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Tue, 30 Jan 2024 17:49:15 -0500 Subject: [PATCH] Prototyping new query system. --- python/lsst/daf/butler/_query_results.py | 5 + .../butler/direct_query_driver/__init__.py | 29 + .../direct_query_driver/_analyzed_query.py | 170 ++++ .../direct_query_driver/_convert_results.py | 61 ++ .../daf/butler/direct_query_driver/_driver.py | 789 ++++++++++++++++++ .../direct_query_driver/_postprocessing.py | 138 +++ .../direct_query_driver/_sql_builder.py | 257 ++++++ .../_sql_column_visitor.py | 239 ++++++ python/lsst/daf/butler/queries/__init__.py | 32 + python/lsst/daf/butler/queries/_base.py | 195 +++++ .../queries/_data_coordinate_query_results.py | 142 ++++ .../butler/queries/_dataset_query_results.py | 232 +++++ .../_dimension_record_query_results.py | 109 +++ python/lsst/daf/butler/queries/_query.py | 511 ++++++++++++ .../lsst/daf/butler/queries/convert_args.py | 244 ++++++ python/lsst/daf/butler/queries/driver.py | 512 ++++++++++++ .../daf/butler/queries/expression_factory.py | 428 ++++++++++ python/lsst/daf/butler/queries/overlaps.py | 466 +++++++++++ .../lsst/daf/butler/queries/result_specs.py | 232 +++++ .../lsst/daf/butler/queries/tree/__init__.py | 40 + python/lsst/daf/butler/queries/tree/_base.py | 252 ++++++ .../butler/queries/tree/_column_expression.py | 257 ++++++ .../butler/queries/tree/_column_literal.py | 372 +++++++++ .../butler/queries/tree/_column_reference.py | 179 ++++ .../daf/butler/queries/tree/_column_set.py | 186 +++++ .../daf/butler/queries/tree/_predicate.py | 678 +++++++++++++++ .../daf/butler/queries/tree/_query_tree.py | 349 ++++++++ python/lsst/daf/butler/queries/visitors.py | 540 ++++++++++++ .../butler/registry/collections/nameKey.py | 6 + .../registry/collections/synthIntKey.py | 11 + .../datasets/byDimensions/_storage.py | 212 ++++- .../daf/butler/registry/dimensions/static.py | 198 ++++- .../registry/interfaces/_collections.py | 26 + .../butler/registry/interfaces/_datasets.py | 16 +- .../butler/registry/interfaces/_dimensions.py | 19 +- 35 files changed, 8123 insertions(+), 9 deletions(-) create mode 100644 python/lsst/daf/butler/direct_query_driver/__init__.py create mode 100644 python/lsst/daf/butler/direct_query_driver/_analyzed_query.py create mode 100644 python/lsst/daf/butler/direct_query_driver/_convert_results.py create mode 100644 python/lsst/daf/butler/direct_query_driver/_driver.py create mode 100644 python/lsst/daf/butler/direct_query_driver/_postprocessing.py create mode 100644 python/lsst/daf/butler/direct_query_driver/_sql_builder.py create mode 100644 python/lsst/daf/butler/direct_query_driver/_sql_column_visitor.py create mode 100644 python/lsst/daf/butler/queries/__init__.py create mode 100644 python/lsst/daf/butler/queries/_base.py create mode 100644 python/lsst/daf/butler/queries/_data_coordinate_query_results.py create mode 100644 python/lsst/daf/butler/queries/_dataset_query_results.py create mode 100644 python/lsst/daf/butler/queries/_dimension_record_query_results.py create mode 100644 python/lsst/daf/butler/queries/_query.py create mode 100644 python/lsst/daf/butler/queries/convert_args.py create mode 100644 python/lsst/daf/butler/queries/driver.py create mode 100644 python/lsst/daf/butler/queries/expression_factory.py create mode 100644 python/lsst/daf/butler/queries/overlaps.py create mode 100644 python/lsst/daf/butler/queries/result_specs.py create mode 100644 python/lsst/daf/butler/queries/tree/__init__.py create mode 100644 
python/lsst/daf/butler/queries/tree/_base.py create mode 100644 python/lsst/daf/butler/queries/tree/_column_expression.py create mode 100644 python/lsst/daf/butler/queries/tree/_column_literal.py create mode 100644 python/lsst/daf/butler/queries/tree/_column_reference.py create mode 100644 python/lsst/daf/butler/queries/tree/_column_set.py create mode 100644 python/lsst/daf/butler/queries/tree/_predicate.py create mode 100644 python/lsst/daf/butler/queries/tree/_query_tree.py create mode 100644 python/lsst/daf/butler/queries/visitors.py diff --git a/python/lsst/daf/butler/_query_results.py b/python/lsst/daf/butler/_query_results.py index 30c18327e9..fe35561a04 100644 --- a/python/lsst/daf/butler/_query_results.py +++ b/python/lsst/daf/butler/_query_results.py @@ -562,6 +562,11 @@ def dataset_type(self) -> DatasetType: """ raise NotImplementedError() + @property + def dimensions(self) -> DimensionGroup: + """The dimensions of the dataset type returned by this query.""" + return self.dataset_type.dimensions.as_group() + @property @abstractmethod def data_ids(self) -> DataCoordinateQueryResults: diff --git a/python/lsst/daf/butler/direct_query_driver/__init__.py b/python/lsst/daf/butler/direct_query_driver/__init__.py new file mode 100644 index 0000000000..d8aae48e4b --- /dev/null +++ b/python/lsst/daf/butler/direct_query_driver/__init__.py @@ -0,0 +1,29 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from ._postprocessing import Postprocessing +from ._sql_builder import SqlBuilder diff --git a/python/lsst/daf/butler/direct_query_driver/_analyzed_query.py b/python/lsst/daf/butler/direct_query_driver/_analyzed_query.py new file mode 100644 index 0000000000..71d98c7898 --- /dev/null +++ b/python/lsst/daf/butler/direct_query_driver/_analyzed_query.py @@ -0,0 +1,170 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. 
Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("AnalyzedQuery", "AnalyzedDatasetSearch", "DataIdExtractionVisitor") + +import dataclasses +from collections.abc import Iterator +from typing import Any + +from ..dimensions import DataIdValue, DimensionElement, DimensionGroup, DimensionUniverse +from ..queries import tree as qt +from ..queries.visitors import ColumnExpressionVisitor, PredicateVisitFlags, SimplePredicateVisitor +from ..registry.interfaces import CollectionRecord +from ._postprocessing import Postprocessing + + +@dataclasses.dataclass +class AnalyzedDatasetSearch: + name: str + shrunk: str + dimensions: DimensionGroup + collection_records: list[CollectionRecord] = dataclasses.field(default_factory=list) + messages: list[str] = dataclasses.field(default_factory=list) + is_calibration_search: bool = False + + +@dataclasses.dataclass +class AnalyzedQuery: + predicate: qt.Predicate + postprocessing: Postprocessing + base_columns: qt.ColumnSet + projection_columns: qt.ColumnSet + final_columns: qt.ColumnSet + find_first_dataset: str | None + materializations: dict[qt.MaterializationKey, DimensionGroup] = dataclasses.field(default_factory=dict) + datasets: dict[str, AnalyzedDatasetSearch] = dataclasses.field(default_factory=dict) + messages: list[str] = dataclasses.field(default_factory=list) + constraint_data_id: dict[str, DataIdValue] = dataclasses.field(default_factory=dict) + data_coordinate_uploads: dict[qt.DataCoordinateUploadKey, DimensionGroup] = dataclasses.field( + default_factory=dict + ) + needs_dimension_distinct: bool = False + needs_find_first_resolution: bool = False + projection_region_aggregates: list[DimensionElement] = dataclasses.field(default_factory=list) + + @property + def universe(self) -> DimensionUniverse: + return self.base_columns.dimensions.universe + + @property + def needs_projection(self) -> bool: + return self.needs_dimension_distinct or self.postprocessing.check_validity_match_count + + def iter_mandatory_base_elements(self) -> Iterator[DimensionElement]: + for element_name in self.base_columns.dimensions.elements: + element = self.universe[element_name] + if self.base_columns.dimension_fields[element_name]: + # We need to get dimension record fields for this element, and + # its table is the only place to get those. + yield element + elif element.defines_relationships: + # We also need to join in dimension element tables that define + # one-to-many and many-to-many relationships, but data + # coordinate uploads, materializations, and datasets can also + # provide these relationships.
Data coordinate uploads and + # dataset tables only have required dimensions, and can hence + # only provide relationships involving those. + if any( + element.minimal_group.names <= upload_dimensions.required + for upload_dimensions in self.data_coordinate_uploads.values() + ): + continue + if any( + element.minimal_group.names <= dataset_spec.dimensions.required + for dataset_spec in self.datasets.values() + ): + continue + # Materializations have all key columns for their dimensions. + if any( + element in materialization_dimensions.names + for materialization_dimensions in self.materializations.values() + ): + continue + yield element + + +class DataIdExtractionVisitor( + SimplePredicateVisitor, + ColumnExpressionVisitor[tuple[str, None] | tuple[None, Any] | tuple[None, None]], +): + def __init__(self, data_id: dict[str, DataIdValue], messages: list[str]): + self.data_id = data_id + self.messages = messages + + def visit_comparison( + self, + a: qt.ColumnExpression, + operator: qt.ComparisonOperator, + b: qt.ColumnExpression, + flags: PredicateVisitFlags, + ) -> None: + if flags & PredicateVisitFlags.HAS_OR_SIBLINGS: + return None + if flags & PredicateVisitFlags.INVERTED: + if operator == "!=": + operator = "==" + else: + return None + if operator != "==": + return None + k_a, v_a = a.visit(self) + k_b, v_b = b.visit(self) + if k_a is not None and v_b is not None: + key = k_a + value = v_b + elif k_b is not None and v_a is not None: + key = k_b + value = v_a + else: + return None + if (old := self.data_id.setdefault(key, value)) != value: + self.messages.append(f"'where' expression requires both {key}={value!r} and {key}={old!r}.") + return None + + def visit_binary_expression(self, expression: qt.BinaryExpression) -> tuple[None, None]: + return None, None + + def visit_unary_expression(self, expression: qt.UnaryExpression) -> tuple[None, None]: + return None, None + + def visit_literal(self, expression: qt.ColumnLiteral) -> tuple[None, Any]: + return None, expression.get_literal_value() + + def visit_dimension_key_reference(self, expression: qt.DimensionKeyReference) -> tuple[str, None]: + return expression.dimension.name, None + + def visit_dimension_field_reference(self, expression: qt.DimensionFieldReference) -> tuple[None, None]: + return None, None + + def visit_dataset_field_reference(self, expression: qt.DatasetFieldReference) -> tuple[None, None]: + return None, None + + def visit_reversed(self, expression: qt.Reversed) -> tuple[None, None]: + raise AssertionError("No Reversed expressions in predicates.") diff --git a/python/lsst/daf/butler/direct_query_driver/_convert_results.py b/python/lsst/daf/butler/direct_query_driver/_convert_results.py new file mode 100644 index 0000000000..899c1760b4 --- /dev/null +++ b/python/lsst/daf/butler/direct_query_driver/_convert_results.py @@ -0,0 +1,61 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. 
If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("convert_dimension_record_results",) + +from collections.abc import Iterable +from typing import TYPE_CHECKING + +import sqlalchemy + +from ..dimensions import DimensionRecordSet + +if TYPE_CHECKING: + from ..queries.driver import DimensionRecordResultPage, PageKey + from ..queries.result_specs import DimensionRecordResultSpec + from ..registry.nameShrinker import NameShrinker + + +def convert_dimension_record_results( + raw_rows: Iterable[sqlalchemy.Row], + spec: DimensionRecordResultSpec, + next_key: PageKey | None, + name_shrinker: NameShrinker, +) -> DimensionRecordResultPage: + record_set = DimensionRecordSet(spec.element) + columns = spec.get_result_columns() + column_mapping = [ + (field, name_shrinker.shrink(columns.get_qualified_name(spec.element.name, field))) + for field in spec.element.schema.names + ] + record_cls = spec.element.RecordClass + if not spec.element.temporal: + for raw_row in raw_rows: + record_set.add(record_cls(**{k: raw_row._mapping[v] for k, v in column_mapping})) + return DimensionRecordResultPage(spec=spec, next_key=next_key, rows=record_set) diff --git a/python/lsst/daf/butler/direct_query_driver/_driver.py b/python/lsst/daf/butler/direct_query_driver/_driver.py new file mode 100644 index 0000000000..654c4f80bc --- /dev/null +++ b/python/lsst/daf/butler/direct_query_driver/_driver.py @@ -0,0 +1,789 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
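The file added below defines the SQL-backed query driver used throughout this patch. As a reading aid (not part of the diff), here is a hedged sketch of the lifecycle implied by the code that follows: the driver must be entered as a context manager before execute, materialize, or upload_data_coordinates are called, because those methods register temporary tables and open cursors on an internal ExitStack. The argument names below are assumptions standing in for objects a DirectButler would normally provide.

from lsst.daf.butler.direct_query_driver._driver import DirectQueryDriver

def iter_query_pages(db, universe, managers, defaults, tree, result_spec):
    # Hypothetical usage sketch; db/universe/managers/defaults come from an
    # existing DirectButler registry and are not constructed here.
    driver = DirectQueryDriver(db, universe, managers, defaults)
    with driver:  # __enter__ opens the ExitStack that owns temp tables/cursors
        page = driver.execute(result_spec, tree)  # first page of results
        yield page.rows
        while page.next_key is not None:  # later pages are fetched lazily
            page = driver.fetch_next_page(result_spec, page.next_key)
            yield page.rows
    # Leaving the context drops temporary tables and closes open cursors.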
+ +from __future__ import annotations + +import uuid + +__all__ = ("DirectQueryDriver",) + +import logging +from collections.abc import Iterable, Iterator, Sequence +from contextlib import ExitStack +from typing import TYPE_CHECKING, Any, cast, overload + +import sqlalchemy + +from .. import ddl +from ..dimensions import DataIdValue, DimensionGroup, DimensionUniverse +from ..queries import tree as qt +from ..queries.driver import ( + DataCoordinateResultPage, + DatasetRefResultPage, + DimensionRecordResultPage, + GeneralResultPage, + PageKey, + QueryDriver, + ResultPage, +) +from ..queries.result_specs import ( + DataCoordinateResultSpec, + DatasetRefResultSpec, + DimensionRecordResultSpec, + GeneralResultSpec, + ResultSpec, +) +from ..registry import CollectionSummary, CollectionType, NoDefaultCollectionError, RegistryDefaults +from ..registry.interfaces import ChainedCollectionRecord, CollectionRecord +from ..registry.managers import RegistryManagerInstances +from ..registry.nameShrinker import NameShrinker +from ._analyzed_query import AnalyzedDatasetSearch, AnalyzedQuery, DataIdExtractionVisitor +from ._convert_results import convert_dimension_record_results +from ._sql_column_visitor import SqlColumnVisitor + +if TYPE_CHECKING: + from ..registry.interfaces import Database + from ._postprocessing import Postprocessing + from ._sql_builder import SqlBuilder + + +_LOG = logging.getLogger(__name__) + + +class DirectQueryDriver(QueryDriver): + """The `QueryDriver` implementation for `DirectButler`. + + Parameters + ---------- + db : `Database` + Abstraction for the SQL database. + universe : `DimensionUniverse` + Definitions of all dimensions. + managers : `RegistryManagerInstances` + Struct of registry manager objects. + defaults : `RegistryDefaults` + Struct holding the default collection search path and governor + dimensions. + raw_page_size : `int`, optional + Number of database rows to fetch for each result page. The actual + number of rows in a page may be smaller due to postprocessing. + postprocessing_filter_factor : `int`, optional + The number of database rows we expect to have to fetch to yield a + single output row for queries that involve postprocessing. This is + purely a performance tuning parameter that attempts to balance between + fetching too much and requiring multiple fetches; the true value is + highly dependent on the actual query. 
+ """ + + def __init__( + self, + db: Database, + universe: DimensionUniverse, + managers: RegistryManagerInstances, + defaults: RegistryDefaults, + raw_page_size: int = 10000, + postprocessing_filter_factor: int = 10, + ): + self.db = db + self.managers = managers + self._universe = universe + self._defaults = defaults + self._materializations: dict[qt.MaterializationKey, tuple[sqlalchemy.Table, Postprocessing]] = {} + self._upload_tables: dict[qt.DataCoordinateUploadKey, sqlalchemy.Table] = {} + self._exit_stack: ExitStack | None = None + self._raw_page_size = raw_page_size + self._postprocessing_filter_factor = postprocessing_filter_factor + self._active_pages: dict[PageKey, tuple[Iterator[Sequence[sqlalchemy.Row]], Postprocessing]] = {} + self._name_shrinker = NameShrinker(self.db.dialect.max_identifier_length) + + def __enter__(self) -> None: + self._exit_stack = ExitStack() + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + assert self._exit_stack is not None + self._exit_stack.__exit__(exc_type, exc_value, traceback) + self._exit_stack = None + + @property + def universe(self) -> DimensionUniverse: + return self._universe + + @overload + def execute( + self, result_spec: DataCoordinateResultSpec, tree: qt.QueryTree + ) -> DataCoordinateResultPage: ... + + @overload + def execute( + self, result_spec: DimensionRecordResultSpec, tree: qt.QueryTree + ) -> DimensionRecordResultPage: ... + + @overload + def execute(self, result_spec: DatasetRefResultSpec, tree: qt.QueryTree) -> DatasetRefResultPage: ... + + @overload + def execute(self, result_spec: GeneralResultSpec, tree: qt.QueryTree) -> GeneralResultPage: ... + + def execute(self, result_spec: ResultSpec, tree: qt.QueryTree) -> ResultPage: + # Docstring inherited. + if self._exit_stack is None: + raise RuntimeError("QueryDriver context must be entered before 'materialize' is called.") + # Make a set of the columns the query needs to make available to the + # SELECT clause and any ORDER BY or GROUP BY clauses. This does not + # include columns needed only by the WHERE or JOIN ON clauses (those + # will be handled inside `_make_vanilla_sql_builder`). + + # Build the FROM and WHERE clauses and identify any post-query + # processing we need to run. + query, sql_builder = self.analyze_query( + tree, + final_columns=result_spec.get_result_columns(), + order_by=result_spec.order_by, + find_first_dataset=result_spec.find_first_dataset, + ) + sql_builder = self.build_query(query, sql_builder) + sql_select = sql_builder.select(query.final_columns, query.postprocessing) + if result_spec.order_by: + visitor = SqlColumnVisitor(sql_builder, self) + sql_select = sql_select.order_by(*[visitor.expect_scalar(term) for term in result_spec.order_by]) + if result_spec.limit is not None: + if query.postprocessing: + query.postprocessing.limit = result_spec.limit + else: + sql_select = sql_select.limit(result_spec.limit) + if result_spec.offset: + if query.postprocessing: + sql_select = sql_select.offset(result_spec.offset) + else: + query.postprocessing.offset = result_spec.offset + if query.postprocessing.limit is not None: + # We might want to fetch many fewer rows that the default page + # size if we have to implement offset and limit in postprocessing. 
raw_page_size = min( + self._postprocessing_filter_factor + * (query.postprocessing.offset + query.postprocessing.limit), + self._raw_page_size, + ) + cursor = self._exit_stack.enter_context( + self.db.query(sql_select.execution_options(yield_per=raw_page_size)) + ) + raw_page_iter = cursor.partitions() + return self._process_page(raw_page_iter, result_spec, query.postprocessing) + + @overload + def fetch_next_page( + self, result_spec: DataCoordinateResultSpec, key: PageKey + ) -> DataCoordinateResultPage: ... + + @overload + def fetch_next_page( + self, result_spec: DimensionRecordResultSpec, key: PageKey + ) -> DimensionRecordResultPage: ... + + @overload + def fetch_next_page(self, result_spec: DatasetRefResultSpec, key: PageKey) -> DatasetRefResultPage: ... + + @overload + def fetch_next_page(self, result_spec: GeneralResultSpec, key: PageKey) -> GeneralResultPage: ... + + def fetch_next_page(self, result_spec: ResultSpec, key: PageKey) -> ResultPage: + raw_page_iter, postprocessing = self._active_pages.pop(key) + return self._process_page(raw_page_iter, result_spec, postprocessing) + + def materialize( + self, + tree: qt.QueryTree, + dimensions: DimensionGroup, + datasets: frozenset[str], + ) -> qt.MaterializationKey: + # Docstring inherited. + if self._exit_stack is None: + raise RuntimeError("QueryDriver context must be entered before 'materialize' is called.") + query, sql_builder = self.analyze_query(tree, qt.ColumnSet(dimensions)) + # Current implementation ignores 'datasets' because figuring out what + # to put in the temporary table for them is tricky, especially if + # calibration collections are involved. + sql_builder = self.build_query(query, sql_builder) + sql_select = sql_builder.select(query.final_columns, query.postprocessing) + table = self._exit_stack.enter_context( + self.db.temporary_table(sql_builder.make_table_spec(query.final_columns, query.postprocessing)) + ) + self.db.insert(table, select=sql_select) + key = uuid.uuid4() + self._materializations[key] = (table, query.postprocessing) + return key + + def upload_data_coordinates( + self, dimensions: DimensionGroup, rows: Iterable[tuple[DataIdValue, ...]] + ) -> qt.DataCoordinateUploadKey: + # Docstring inherited. + if self._exit_stack is None: + raise RuntimeError("QueryDriver context must be entered before 'upload_data_coordinates' is called.") + table_spec = ddl.TableSpec( + [ + self.universe.dimensions[name].primary_key.model_copy(update=dict(name=name)).to_sql_spec() + for name in dimensions.required + ] + ) + if not dimensions: + table_spec.fields.add( + ddl.FieldSpec( + SqlBuilder.EMPTY_COLUMNS_NAME, dtype=SqlBuilder.EMPTY_COLUMNS_TYPE, nullable=True + ) + ) + table = self._exit_stack.enter_context(self.db.temporary_table(table_spec)) + self.db.insert(table, *(dict(zip(dimensions.required, values)) for values in rows)) + key = uuid.uuid4() + self._upload_tables[key] = table + return key + + def count( + self, + tree: qt.QueryTree, + columns: qt.ColumnSet, + find_first_dataset: str | None, + *, + exact: bool, + discard: bool, + ) -> int: + # Docstring inherited.
query, sql_builder = self.analyze_query(tree, columns, find_first_dataset=find_first_dataset) + sql_builder = self.build_query(query, sql_builder) + if query.postprocessing and exact: + if not discard: + raise RuntimeError("Cannot count query rows exactly without discarding them.") + sql_select = sql_builder.select(columns, query.postprocessing) + n = 0 + with self.db.query(sql_select.execution_options(yield_per=self._raw_page_size)) as results: + for _ in query.postprocessing.apply(results): + n += 1 + return n + # Do COUNT(*) on the original query's FROM clause. + sql_builder.special["_ROWCOUNT"] = sqlalchemy.func.count() + sql_select = sql_builder.select(qt.ColumnSet(self._universe.empty.as_group())) + with self.db.query(sql_select) as result: + return cast(int, result.scalar()) + + def any(self, tree: qt.QueryTree, *, execute: bool, exact: bool) -> bool: + # Docstring inherited. + query, sql_builder = self.analyze_query(tree, qt.ColumnSet(tree.dimensions)) + if not all(d.collection_records for d in query.datasets.values()): + return False + if not execute: + if exact: + raise RuntimeError("Cannot obtain exact result for 'any' without executing.") + return True + sql_builder = self.build_query(query, sql_builder) + if query.postprocessing and exact: + sql_select = sql_builder.select(query.final_columns, query.postprocessing) + with self.db.query( + sql_select.execution_options(yield_per=self._postprocessing_filter_factor) + ) as result: + for _ in query.postprocessing.apply(result): + return True + return False + sql_select = sql_builder.select(query.final_columns).limit(1) + with self.db.query(sql_select) as result: + return result.first() is not None + + def explain_no_results(self, tree: qt.QueryTree, execute: bool) -> Iterable[str]: + # Docstring inherited. + query, _ = self.analyze_query(tree, qt.ColumnSet(tree.dimensions)) + if query.messages or not execute: + return query.messages + # TODO: guess at ways to split up query that might fail or succeed if + # run separately, execute them with LIMIT 1 and report the results. + return [] + + def get_dataset_dimensions(self, name: str) -> DimensionGroup: + # Docstring inherited. + return self.managers.datasets[name].datasetType.dimensions.as_group() + + def get_default_collections(self) -> tuple[str, ...]: + # Docstring inherited. + if not self._defaults.collections: + raise NoDefaultCollectionError("No collections provided and no default collections.") + return tuple(self._defaults.collections) + + def resolve_collection_path( + self, collections: Iterable[str] + ) -> list[tuple[CollectionRecord, CollectionSummary]]: + result: list[tuple[CollectionRecord, CollectionSummary]] = [] + done: set[str] = set() + + def recurse(collection_names: Iterable[str]) -> None: + for collection_name in collection_names: + if collection_name not in done: + done.add(collection_name) + record = self.managers.collections.find(collection_name) + + if record.type is CollectionType.CHAINED: + recurse(cast(ChainedCollectionRecord, record).children) + else: + result.append((record, self.managers.datasets.getCollectionSummary(record))) + + recurse(collections) + + return result + + def analyze_query( + self, + tree: qt.QueryTree, + final_columns: qt.ColumnSet, + order_by: Iterable[qt.OrderExpression] = (), + find_first_dataset: str | None = None, + ) -> tuple[AnalyzedQuery, SqlBuilder]: + # Delegate to the dimensions manager to rewrite the predicate and + # start a SqlBuilder and Postprocessing to cover any spatial overlap + # joins or constraints.
We'll return that SqlBuilder at the end. + ( + predicate, + sql_builder, + postprocessing, + ) = self.managers.dimensions.process_query_overlaps( + tree.dimensions, + tree.predicate, + tree.join_operand_dimensions, + ) + # Initialize the AnalyzedQuery instance we'll update throughout this + # method. + query = AnalyzedQuery( + predicate, + postprocessing, + base_columns=qt.ColumnSet(tree.dimensions), + projection_columns=final_columns.copy(), + final_columns=final_columns, + find_first_dataset=find_first_dataset, + ) + # The base query needs to join in all columns required by the + # predicate. + predicate.gather_required_columns(query.base_columns) + # The "projection" query differs from the final query by not omitting + # any dimension keys (since that makes it easier to reason about), + # including any columns needed by order_by terms, and including + # the dataset rank if there's a find-first search in play. + query.projection_columns.restore_dimension_keys() + for term in order_by: + term.gather_required_columns(query.projection_columns) + if query.find_first_dataset is not None: + query.projection_columns.dataset_fields[query.find_first_dataset].add("collection_key") + # The base query also needs to include all columns needed by the + # downstream projection query. + query.base_columns.update(query.projection_columns) + # Extract the data ID implied by the predicate; we can use the governor + # dimensions in that to constrain the collections we search for + # datasets later. + query.predicate.visit(DataIdExtractionVisitor(query.constraint_data_id, query.messages)) + # We also check that the predicate doesn't reference any dimensions + # without constraining their governor dimensions, since that's a + # particularly easy mistake to make and it's almost never intentional. + # We also use the registry default data ID values to provide governor values. + where_columns = qt.ColumnSet(query.universe.empty.as_group()) + query.predicate.gather_required_columns(where_columns) + for governor in where_columns.dimensions.governors: + if governor not in query.constraint_data_id: + if governor in self._defaults.dataId.dimensions: + query.constraint_data_id[governor] = self._defaults.dataId[governor] + else: + raise qt.InvalidQueryTreeError( + f"Query 'where' expression references a dimension dependent on {governor} without " + "constraining it directly." + ) + # Add materializations, which can also bring in more postprocessing. + for m_key, m_dimensions in tree.materializations.items(): + _, m_postprocessing = self._materializations[m_key] + query.materializations[m_key] = m_dimensions + # When a query is materialized, the new tree has an empty + # (trivially true) predicate, and the materialization prevents the + # creation of automatic spatial joins that are already included in + # the materialization, so we don't need to deduplicate these + # filters. It's possible for there to be duplicates, but only if + # the user explicitly adds a redundant constraint, and we'll still + # behave correctly (just less efficiently) if that happens. + postprocessing.spatial_join_filtering.extend(m_postprocessing.spatial_join_filtering) + postprocessing.spatial_where_filtering.extend(m_postprocessing.spatial_where_filtering) + # Add data coordinate uploads. + query.data_coordinate_uploads.update(tree.data_coordinate_uploads) + # Add dataset_searches and filter out collections that don't have the + # right dataset type or governor dimensions.
+ name_shrinker = make_dataset_name_shrinker(self.db.dialect) + for dataset_type_name, dataset_search in tree.datasets.items(): + dataset = AnalyzedDatasetSearch( + dataset_type_name, name_shrinker.shrink(dataset_type_name), dataset_search.dimensions + ) + for collection_record, collection_summary in self.resolve_collection_path( + dataset_search.collections + ): + rejected: bool = False + if dataset.name not in collection_summary.dataset_types.names: + dataset.messages.append( + f"No datasets of type {dataset.name!r} in collection {collection_record.name}." + ) + rejected = True + for governor in query.constraint_data_id.keys() & collection_summary.governors.keys(): + if query.constraint_data_id[governor] not in collection_summary.governors[governor]: + dataset.messages.append( + f"No datasets with {governor}={query.constraint_data_id[governor]!r} " + f"in collection {collection_record.name}." + ) + rejected = True + if not rejected: + if collection_record.type is CollectionType.CALIBRATION: + dataset.is_calibration_search = True + dataset.collection_records.append(collection_record) + if dataset.dimensions != self.get_dataset_type(dataset_type_name).dimensions.as_group(): + # This is really for server-side defensiveness; it's hard to + # imagine the query getting different dimensions for a dataset + # type in two calls to the same query driver. + raise qt.InvalidQueryTreeError( + f"Incorrect dimensions {dataset.dimensions} for dataset {dataset_type_name} " + f"in query (vs. {self.get_dataset_type(dataset_type_name).dimensions.as_group()})." + ) + query.datasets[dataset_type_name] = dataset + if not dataset.collection_records: + query.messages.append(f"Search for dataset type {dataset_type_name!r} is doomed to fail.") + query.messages.extend(dataset.messages) + # Set flags that indicate certain kinds of special processing the query + # will need, mostly in the "projection" stage, where we might do a + # GROUP BY or DISTINCT [ON]. + if query.find_first_dataset is not None: + # If we're doing a find-first search and there's a calibration + # collection in play, we need to make sure the rows coming out of + # the base query have only one timespan for each data ID + + # collection, and we can only do that with a GROUP BY and COUNT. + query.postprocessing.check_validity_match_count = query.datasets[ + query.find_first_dataset + ].is_calibration_search + # We only actually need to include the find-first resolution query + # logic if there's more than one collection. + query.needs_find_first_resolution = ( + len(query.datasets[query.find_first_dataset].collection_records) > 1 + ) + if query.projection_columns.dimensions != query.base_columns.dimensions: + # We're going from a larger set of dimensions to a smaller set, + # that means we'll be doing a SELECT DISTINCT [ON] or GROUP BY. + query.needs_dimension_distinct = True + # If there are any dataset fields being propagated through that + # projection and there is more than one collection, we need to + # include the collection_key column so we can use that as one of + # the DISTINCT ON or GROUP BY columns. + for dataset_type, fields_for_dataset in query.projection_columns.dataset_fields.items(): + if len(query.datasets[dataset_type].collection_records) > 1: + fields_for_dataset.add("collection_key") + # If there's a projection and we're doing postprocessing, we might + # be collapsing the dimensions of the postprocessing regions. 
When + # that happens, we want to apply an aggregate function to them that + # computes the union of the regions that are grouped together. + for element in query.postprocessing.iter_missing(query.projection_columns): + if element.name not in query.projection_columns.dimensions.elements: + query.projection_region_aggregates.append(element) + break + return query, sql_builder + + def build_query(self, query: AnalyzedQuery, sql_builder: SqlBuilder) -> SqlBuilder: + sql_builder = self._build_base_query(query, sql_builder) + if query.needs_projection: + sql_builder = self._project_query(query, sql_builder) + if query.needs_find_first_resolution: + sql_builder = self._apply_find_first(query, sql_builder) + elif query.needs_find_first_resolution: + sql_builder = self._apply_find_first( + query, sql_builder.cte(query.projection_columns, query.postprocessing) + ) + return sql_builder + + def _build_base_query(self, query: AnalyzedQuery, sql_builder: SqlBuilder) -> SqlBuilder: + # Process data coordinate upload joins. + for upload_key, upload_dimensions in query.data_coordinate_uploads.items(): + sql_builder = sql_builder.join( + SqlBuilder(self.db, self._upload_tables[upload_key]).extract_dimensions( + upload_dimensions.required + ) + ) + # Process materialization joins. + for materialization_key, materialization_spec in query.materializations.items(): + sql_builder = self._join_materialization(sql_builder, materialization_key, materialization_spec) + # Process dataset joins. + for dataset_type, dataset_search in query.datasets.items(): + sql_builder = self._join_dataset_search( + sql_builder, + dataset_type, + dataset_search, + query.base_columns, + ) + # Join in dimension element tables that we know we need relationships + # or columns from. + for element in query.iter_mandatory_base_elements(): + sql_builder = sql_builder.join( + self.managers.dimensions.make_sql_builder( + element, query.base_columns.dimension_fields[element.name] + ) + ) + # See if any dimension keys are still missing, and if so join in their + # tables. Note that we know there are no fields needed from these. + while not (sql_builder.dimension_keys.keys() >= query.base_columns.dimensions.names): + # Look for opportunities to join in multiple dimensions via a single + # table, to reduce the total number of tables joined in. + missing_dimension_names = query.base_columns.dimensions.names - sql_builder.dimension_keys.keys() + best = self._universe[ + max( + missing_dimension_names, + key=lambda name: len(self._universe[name].dimensions.names & missing_dimension_names), + ) + ] + sql_builder = sql_builder.join(self.managers.dimensions.make_sql_builder(best, frozenset())) + # Add the WHERE clause to the builder. + return sql_builder.where_sql(query.predicate.visit(SqlColumnVisitor(sql_builder, self))) + + def _project_query(self, query: AnalyzedQuery, sql_builder: SqlBuilder) -> SqlBuilder: + assert query.needs_projection + # This method generates a Common Table Expression (CTE) using either a + # SELECT DISTINCT [ON] or a SELECT with GROUP BY. + # We'll work out which as we go. + have_aggregates: bool = False + # Dimension key columns form at least most of our GROUP BY or DISTINCT + # ON clause; we'll work out which of those we'll use.
unique_keys: list[sqlalchemy.ColumnElement[Any]] = [ + sql_builder.dimension_keys[k][0] for k in query.projection_columns.dimensions.data_coordinate_keys + ] + # There are two reasons we might need an aggregate function: + # - to make sure temporal constraints and joins have resulted in at + # most one validity range match for each data ID and collection, + # when we're doing a find-first query. + # - to compute the unions of regions we need for postprocessing, when + # the data IDs for those regions are not wholly included in the + # results (i.e. we need to postprocess on + # visit_detector_region.region, but the output rows don't have + # detector, just visit - so we compute the union of the + # visit_detector region over all matched detectors). + if query.postprocessing.check_validity_match_count: + sql_builder.special[query.postprocessing.VALIDITY_MATCH_COUNT] = sqlalchemy.func.count().label( + query.postprocessing.VALIDITY_MATCH_COUNT + ) + have_aggregates = True + for element in query.projection_region_aggregates: + sql_builder.fields[element.name]["region"] = ddl.Base64Region.union_aggregate( + sql_builder.fields[element.name]["region"] + ) + have_aggregates = True + # Many of our fields derive their uniqueness from the unique_key + # fields: if rows are unique over the 'unique_key' fields, then they're + # automatically unique over these 'derived_fields'. We just remember + # these as pairs of (logical_table, field) for now. + derived_fields: list[tuple[str, str]] = [] + # All dimension record fields are derived fields. + for element_name, fields_for_element in query.projection_columns.dimension_fields.items(): + for element_field in fields_for_element: + derived_fields.append((element_name, element_field)) + # Some dataset fields are derived fields and some are unique keys, and + # it depends on the kinds of collection(s) we're searching and whether + # it's a find-first query. + for dataset_type, fields_for_dataset in query.projection_columns.dataset_fields.items(): + for dataset_field in fields_for_dataset: + if dataset_field == "collection_key": + # If the collection_key field is present, it's needed for + # uniqueness if we're looking in more than one collection. + # If not, it's a derived field. + if len(query.datasets[dataset_type].collection_records) > 1: + unique_keys.append(sql_builder.fields[dataset_type]["collection_key"]) + else: + derived_fields.append((dataset_type, "collection_key")) + elif dataset_field == "timespan" and query.datasets[dataset_type].is_calibration_search: + # If we're doing a non-find-first query against a + # CALIBRATION collection, the timespan is also a unique + # key... + if dataset_type == query.find_first_dataset: + # ...unless we're doing a find-first search on this + # dataset, in which case we need to use ANY_VALUE on + # the timespan and check that _VALIDITY_MATCH_COUNT + # (added earlier) is one, indicating that there was + # indeed only one timespan for each data ID in each + # collection that survived the base query's WHERE + # clauses and JOINs. + if not self.db.has_any_aggregate: + raise NotImplementedError( + f"Cannot generate query that returns {dataset_type}.timespan after a " + "find-first search, because this database does not support the ANY_VALUE " + "aggregate function (or equivalent)."
+ ) + sql_builder.timespans[dataset_type] = sql_builder.timespans[ + dataset_type + ].apply_any_aggregate(self.db.apply_any_aggregate) + else: + unique_keys.extend(sql_builder.timespans[dataset_type].flatten()) + else: + # Other dataset fields derive their uniqueness from key + # fields. + derived_fields.append((dataset_type, dataset_field)) + if not have_aggregates and not derived_fields: + # SELECT DISTINCT is sufficient. + return sql_builder.cte(query.projection_columns, query.postprocessing, distinct=True) + elif not have_aggregates and self.db.has_distinct_on: + # SELECT DISTINCT ON is sufficient and works. + return sql_builder.cte(query.projection_columns, query.postprocessing, distinct=unique_keys) + else: + # GROUP BY is the only option. + if derived_fields: + if self.db.has_any_aggregate: + for logical_table, field in derived_fields: + if field == "timespan": + sql_builder.timespans[logical_table] = sql_builder.timespans[ + logical_table + ].apply_any_aggregate(self.db.apply_any_aggregate) + else: + sql_builder.fields[logical_table][field] = self.db.apply_any_aggregate( + sql_builder.fields[logical_table][field] + ) + else: + _LOG.warning( + "Adding %d fields to GROUP BY because this database backend does not support the " + "ANY_VALUE aggregate function (or equivalent). This may result in a poor query " + "plan. Materializing the query first sometimes avoids this problem.", + len(derived_fields), + ) + for logical_table, field in derived_fields: + if field == "timespan": + unique_keys.extend(sql_builder.timespans[logical_table].flatten()) + else: + unique_keys.append(sql_builder.fields[logical_table][field]) + return sql_builder.cte(query.projection_columns, query.postprocessing, group_by=unique_keys) + + def _apply_find_first(self, query: AnalyzedQuery, sql_builder: SqlBuilder) -> SqlBuilder: + assert query.needs_find_first_resolution + assert query.find_first_dataset is not None + assert sql_builder.sql_from_clause is not None + # The query we're building looks like this: + # + # WITH {dst}_base AS ( + # {target} + # ... + # ) + # SELECT + # {dst}_window.*, + # FROM ( + # SELECT + # {dst}_base.*, + # ROW_NUMBER() OVER ( + # PARTITION BY {dst_base}.{dimensions} + # ORDER BY {rank} + # ) AS rownum + # ) {dst}_window + # WHERE + # {dst}_window.rownum = 1; + # + # The outermost SELECT will be represented by the SqlBuilder we return. + + # The sql_builder we're given corresponds to the Common Table + # Expression (CTE) at the top, and is guaranteed to have + # ``query.projected_columns`` (+ postprocessing columns). + # We start by filling out the "window" SELECT statement... + partition_by = [sql_builder.dimension_keys[d][0] for d in query.base_columns.dimensions.required] + rank_sql_column = sqlalchemy.case( + { + record.key: n + for n, record in enumerate(query.datasets[query.find_first_dataset].collection_records) + }, + value=sql_builder.fields[query.find_first_dataset]["collection_key"], + ) + if partition_by: + sql_builder.special["_ROWNUM"] = sqlalchemy.sql.func.row_number().over( + partition_by=partition_by, order_by=rank_sql_column + ) + else: + sql_builder.special["_ROWNUM"] = sqlalchemy.sql.func.row_number().over(order_by=rank_sql_column) + # ... and then turn that into a subquery with a constraint on rownum. 
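The window-function shape sketched in the comment above can be illustrated standalone with SQLAlchemy Core. This is a hedged, illustrative example only: the table, column names, and collection-key ranks are made up, and the real statement is assembled by SqlBuilder from the projection CTE.

import sqlalchemy

metadata = sqlalchemy.MetaData()
# Stand-in for the "{dst}_base" CTE: one row per candidate (data ID, collection) match.
base = sqlalchemy.Table(
    "base",
    metadata,
    sqlalchemy.Column("visit", sqlalchemy.Integer),
    sqlalchemy.Column("collection_key", sqlalchemy.Integer),
    sqlalchemy.Column("dataset_id", sqlalchemy.Uuid),
)
# Rank collections by their position in the search path (lower is better).
rank = sqlalchemy.case({10: 0, 11: 1, 12: 2}, value=base.c.collection_key)
window = sqlalchemy.select(
    base.c.visit,
    base.c.dataset_id,
    sqlalchemy.func.row_number().over(partition_by=[base.c.visit], order_by=rank).label("rownum"),
).subquery("window")
# Keep only the best-ranked match for each data ID.
find_first = sqlalchemy.select(window.c.visit, window.c.dataset_id).where(window.c.rownum == 1)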
sql_builder = sql_builder.subquery(query.base_columns, query.postprocessing) + sql_builder = sql_builder.where_sql(sql_builder.special["_ROWNUM"] == 1) + del sql_builder.special["_ROWNUM"] + return sql_builder + + def _join_materialization( + self, + sql_builder: SqlBuilder, + materialization_key: qt.MaterializationKey, + dimensions: DimensionGroup, + ) -> SqlBuilder: + columns = qt.ColumnSet(dimensions) + table, postprocessing = self._materializations[materialization_key] + return sql_builder.join(SqlBuilder(self.db, table).extract_columns(columns, postprocessing)) + + def _join_dataset_search( + self, + sql_builder: SqlBuilder, + dataset_type: str, + processed_dataset_search: AnalyzedDatasetSearch, + columns: qt.ColumnSet, + ) -> SqlBuilder: + storage = self.managers.datasets[dataset_type] + # The next two asserts will need to be dropped (and the implications + # dealt with instead) if materializations start having dataset fields. + assert ( + dataset_type not in sql_builder.fields + ), "Dataset fields have unexpectedly already been joined in." + assert ( + dataset_type not in sql_builder.timespans + ), "Dataset timespan has unexpectedly already been joined in." + return sql_builder.join( + storage.make_sql_builder( + processed_dataset_search.collection_records, columns.dataset_fields[dataset_type] + ) + ) + + def _process_page( + self, + raw_page_iter: Iterator[Sequence[sqlalchemy.Row]], + result_spec: ResultSpec, + postprocessing: Postprocessing, + ) -> ResultPage: + try: + raw_page = next(raw_page_iter) + except StopIteration: + raw_page = tuple() + if len(raw_page) == self._raw_page_size: + # There's some chance we got unlucky and this page exactly finishes + # off the query, and we won't know the next page does not exist + # until we try to fetch it. But that's better than always fetching + # the next page up front. + next_key = uuid.uuid4() + self._active_pages[next_key] = (raw_page_iter, postprocessing) + else: + next_key = None + match result_spec: + case DimensionRecordResultSpec(): + return convert_dimension_record_results( + postprocessing.apply(raw_page), + result_spec, + next_key, + self._name_shrinker, + ) + case _: + raise NotImplementedError("TODO") + + +def make_dataset_name_shrinker(dialect: sqlalchemy.Dialect) -> NameShrinker: + max_dataset_field_length = max(len(field) for field in qt.DATASET_FIELD_NAMES) + return NameShrinker(dialect.max_identifier_length - max_dataset_field_length - 1, 6) diff --git a/python/lsst/daf/butler/direct_query_driver/_postprocessing.py b/python/lsst/daf/butler/direct_query_driver/_postprocessing.py new file mode 100644 index 0000000000..2dc4ead742 --- /dev/null +++ b/python/lsst/daf/butler/direct_query_driver/_postprocessing.py @@ -0,0 +1,138 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively.
If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("Postprocessing", "ValidityRangeMatchError") + +from collections.abc import Iterable, Iterator +from typing import TYPE_CHECKING, ClassVar + +import sqlalchemy +from lsst.sphgeom import DISJOINT, Region + +from ..queries import tree as qt + +if TYPE_CHECKING: + from ..dimensions import DimensionElement + + +class ValidityRangeMatchError(RuntimeError): + pass + + +class Postprocessing: + def __init__(self) -> None: + self.spatial_join_filtering: list[tuple[DimensionElement, DimensionElement]] = [] + self.spatial_where_filtering: list[tuple[DimensionElement, Region]] = [] + self.check_validity_match_count: bool = False + self._offset: int = 0 + self._limit: int | None = None + + VALIDITY_MATCH_COUNT: ClassVar[str] = "_VALIDITY_MATCH_COUNT" + + @property + def offset(self) -> int: + return self._offset + + @offset.setter + def offset(self, value: int) -> None: + if value and not self: + raise RuntimeError( + "Postprocessing should only implement 'offset' if it needs to do spatial filtering." + ) + self._offset = value + + @property + def limit(self) -> int | None: + return self._limit + + @limit.setter + def limit(self, value: int | None) -> None: + if value and not self: + raise RuntimeError( + "Postprocessing should only implement 'limit' if it needs to do spatial filtering." 
) + self._limit = value + + def __bool__(self) -> bool: + return bool(self.spatial_join_filtering or self.spatial_where_filtering) + + def gather_columns_required(self, columns: qt.ColumnSet) -> None: + for element in self.iter_region_dimension_elements(): + columns.update_dimensions(element.minimal_group) + columns.dimension_fields[element.name].add("region") + + def iter_region_dimension_elements(self) -> Iterator[DimensionElement]: + for a, b in self.spatial_join_filtering: + yield a + yield b + for element, _ in self.spatial_where_filtering: + yield element + + def iter_missing(self, columns: qt.ColumnSet) -> Iterator[DimensionElement]: + done: set[DimensionElement] = set() + for element in self.iter_region_dimension_elements(): + if element not in done: + if "region" not in columns.dimension_fields.get(element.name, frozenset()): + yield element + done.add(element) + + def apply(self, rows: Iterable[sqlalchemy.Row]) -> Iterable[sqlalchemy.Row]: + if not self: + yield from rows + return + joins = [ + ( + qt.ColumnSet.get_qualified_name(a.name, "region"), + qt.ColumnSet.get_qualified_name(b.name, "region"), + ) + for a, b in self.spatial_join_filtering + ] + where = [ + (qt.ColumnSet.get_qualified_name(element.name, "region"), region) + for element, region in self.spatial_where_filtering + ] + for row in rows: + m = row._mapping + if any(m[a].relate(m[b]) & DISJOINT for a, b in joins) or any( + m[field].relate(region) & DISJOINT for field, region in where + ): + continue + if self.check_validity_match_count and m[self.VALIDITY_MATCH_COUNT] > 1: + raise ValidityRangeMatchError( + "Ambiguous calibration validity range match. This usually means a temporal join or " + "'where' needs to be added, but it could also mean that multiple validity ranges " + "overlap a single output data ID." + ) + if self._offset: + self._offset -= 1 + continue + if self._limit == 0: + break + yield row + if self._limit is not None: + self._limit -= 1 diff --git a/python/lsst/daf/butler/direct_query_driver/_sql_builder.py b/python/lsst/daf/butler/direct_query_driver/_sql_builder.py new file mode 100644 index 0000000000..84726f4bc5 --- /dev/null +++ b/python/lsst/daf/butler/direct_query_driver/_sql_builder.py @@ -0,0 +1,257 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see .
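As an aside on the Postprocessing.apply filter defined above: the relate/DISJOINT test it performs on each row can be illustrated standalone with lsst.sphgeom. This is a hedged sketch with made-up regions and values, not part of the patch.

from lsst.sphgeom import DISJOINT, Angle, Circle, LonLat, UnitVector3d

# Two small circular sky regions standing in for the region columns carried
# through the query (values are arbitrary).
region_a = Circle(UnitVector3d(LonLat.fromDegrees(45.0, 30.0)), Angle.fromDegrees(1.0))
region_b = Circle(UnitVector3d(LonLat.fromDegrees(45.5, 30.0)), Angle.fromDegrees(1.0))

# Postprocessing.apply drops a row when the relation bitmask flags the two
# regions as DISJOINT; otherwise the row survives the exact spatial check.
keep_row = not (region_a.relate(region_b) & DISJOINT)
print(keep_row)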
+ +from __future__ import annotations + +__all__ = ("SqlBuilder",) + +import dataclasses +import itertools +from collections.abc import Iterable, Sequence +from typing import TYPE_CHECKING, Any, ClassVar + +import sqlalchemy + +from .. import ddl +from ..nonempty_mapping import NonemptyMapping +from ..queries import tree as qt +from ._postprocessing import Postprocessing + +if TYPE_CHECKING: + from ..registry.interfaces import Database + from ..timespan_database_representation import TimespanDatabaseRepresentation + + +@dataclasses.dataclass +class SqlBuilder: + db: Database + sql_from_clause: sqlalchemy.FromClause | None = None + sql_where_terms: list[sqlalchemy.ColumnElement[bool]] = dataclasses.field(default_factory=list) + needs_distinct: bool = False + + dimension_keys: NonemptyMapping[str, list[sqlalchemy.ColumnElement]] = dataclasses.field( + default_factory=lambda: NonemptyMapping(list) + ) + + fields: NonemptyMapping[str, dict[str, sqlalchemy.ColumnElement[Any]]] = dataclasses.field( + default_factory=lambda: NonemptyMapping(dict) + ) + + timespans: dict[str, TimespanDatabaseRepresentation] = dataclasses.field(default_factory=dict) + + special: dict[str, sqlalchemy.ColumnElement[Any]] = dataclasses.field(default_factory=dict) + + EMPTY_COLUMNS_NAME: ClassVar[str] = "IGNORED" + """Name of the column added to a SQL ``SELECT`` query in order to represent + relations that have no real columns. + """ + + EMPTY_COLUMNS_TYPE: ClassVar[type] = sqlalchemy.Boolean + """Type of the column added to a SQL ``SELECT`` query in order to represent + relations that have no real columns. + """ + + @property + def sql_columns(self) -> sqlalchemy.ColumnCollection: + assert self.sql_from_clause is not None + return self.sql_from_clause.columns + + @classmethod + def handle_empty_columns( + cls, columns: list[sqlalchemy.sql.ColumnElement] + ) -> list[sqlalchemy.ColumnElement]: + """Handle the edge case where a SELECT statement has no columns, by + adding a literal column that should be ignored. + + Parameters + ---------- + columns : `list` [ `sqlalchemy.ColumnElement` ] + List of SQLAlchemy column objects. This may have no elements when + this method is called, and will always have at least one element + when it returns. + + Returns + ------- + columns : `list` [ `sqlalchemy.ColumnElement` ] + The same list that was passed in, after any modification. + """ + if not columns: + columns.append(sqlalchemy.sql.literal(True).label(cls.EMPTY_COLUMNS_NAME)) + return columns + + def select( + self, + columns: qt.ColumnSet, + postprocessing: Postprocessing | None = None, + *, + distinct: bool | Sequence[sqlalchemy.ColumnElement[Any]] = False, + group_by: Sequence[sqlalchemy.ColumnElement] = (), + ) -> sqlalchemy.Select: + sql_columns: list[sqlalchemy.ColumnElement[Any]] = [] + for logical_table, field in columns: + name = columns.get_qualified_name(logical_table, field) + if field is None: + sql_columns.append(self.dimension_keys[logical_table][0].label(name)) + elif columns.is_timespan(logical_table, field): + sql_columns.extend(self.timespans[logical_table].flatten(name)) + else: + sql_columns.append(self.fields[logical_table][field].label(name)) + if postprocessing is not None: + for element in postprocessing.iter_missing(columns): + assert ( + element.name in columns.dimensions.elements + ), "Region aggregates not handled by this method." 
+ sql_columns.append( + self.fields[element.name]["region"].label( + columns.get_qualified_name(element.name, "region") + ) + ) + for label, sql_column in self.special.items(): + sql_columns.append(sql_column.label(label)) + self.handle_empty_columns(sql_columns) + result = sqlalchemy.select(*sql_columns) + if self.sql_from_clause is not None: + result = result.select_from(self.sql_from_clause) + if self.needs_distinct or distinct: + if distinct is True or distinct is False: + result = result.distinct() + else: + result = result.distinct(*distinct) + if group_by: + result = result.group_by(*group_by) + if self.sql_where_terms: + result = result.where(*self.sql_where_terms) + return result + + def make_table_spec( + self, + columns: qt.ColumnSet, + postprocessing: Postprocessing | None = None, + ) -> ddl.TableSpec: + assert not self.special, "special columns not supported in make_table_spec" + results = ddl.TableSpec( + [columns.get_column_spec(logical_table, field).to_sql_spec() for logical_table, field in columns] + ) + if postprocessing: + for element in postprocessing.iter_missing(columns): + results.fields.add( + ddl.FieldSpec.for_region(columns.get_qualified_name(element.name, "region")) + ) + return results + + def extract_dimensions(self, dimensions: Iterable[str], **kwargs: str) -> SqlBuilder: + assert self.sql_from_clause is not None, "Cannot extract columns with no FROM clause." + for dimension_name in dimensions: + self.dimension_keys[dimension_name].append(self.sql_from_clause.columns[dimension_name]) + for k, v in kwargs.items(): + self.dimension_keys[v].append(self.sql_from_clause.columns[k]) + return self + + def extract_columns( + self, columns: qt.ColumnSet, postprocessing: Postprocessing | None = None + ) -> SqlBuilder: + assert self.sql_from_clause is not None, "Cannot extract columns with no FROM clause." 
+        for logical_table, field in columns:
+            name = columns.get_qualified_name(logical_table, field)
+            if field is None:
+                self.dimension_keys[logical_table].append(self.sql_from_clause.columns[name])
+            elif columns.is_timespan(logical_table, field):
+                self.timespans[logical_table] = self.db.getTimespanRepresentation().from_columns(
+                    self.sql_from_clause.columns, name
+                )
+            else:
+                self.fields[logical_table][field] = self.sql_from_clause.columns[name]
+        if postprocessing is not None:
+            for element in postprocessing.iter_missing(columns):
+                self.fields[element.name]["region"] = self.sql_from_clause.columns[
+                    columns.get_qualified_name(element.name, "region")
+                ]
+            if postprocessing.check_validity_match_count:
+                self.special[postprocessing.VALIDITY_MATCH_COUNT] = self.sql_from_clause.columns[
+                    postprocessing.VALIDITY_MATCH_COUNT
+                ]
+        return self
+
+    def join(self, other: SqlBuilder) -> SqlBuilder:
+        join_on: list[sqlalchemy.ColumnElement] = []
+        for dimension_name in self.dimension_keys.keys() & other.dimension_keys.keys():
+            for column1, column2 in itertools.product(
+                self.dimension_keys[dimension_name], other.dimension_keys[dimension_name]
+            ):
+                join_on.append(column1 == column2)
+            self.dimension_keys[dimension_name].extend(other.dimension_keys[dimension_name])
+        if self.sql_from_clause is None:
+            self.sql_from_clause = other.sql_from_clause
+        elif other.sql_from_clause is not None:
+            self.sql_from_clause = self.sql_from_clause.join(
+                other.sql_from_clause, onclause=sqlalchemy.and_(*join_on)
+            )
+        self.sql_where_terms += other.sql_where_terms
+        self.needs_distinct = self.needs_distinct or other.needs_distinct
+        self.special.update(other.special)
+        return self
+
+    def where_sql(self, *arg: sqlalchemy.ColumnElement[bool]) -> SqlBuilder:
+        self.sql_where_terms.extend(arg)
+        return self
+
+    def cte(
+        self,
+        columns: qt.ColumnSet,
+        postprocessing: Postprocessing | None = None,
+        *,
+        distinct: bool | Sequence[sqlalchemy.ColumnElement[Any]] = False,
+        group_by: Sequence[sqlalchemy.ColumnElement] = (),
+    ) -> SqlBuilder:
+        return SqlBuilder(
+            self.db,
+            self.select(columns, postprocessing, distinct=distinct, group_by=group_by).cte(),
+        ).extract_columns(columns, postprocessing)
+
+    def subquery(
+        self,
+        columns: qt.ColumnSet,
+        postprocessing: Postprocessing | None = None,
+        *,
+        distinct: bool | Sequence[sqlalchemy.ColumnElement[Any]] = False,
+        group_by: Sequence[sqlalchemy.ColumnElement] = (),
+    ) -> SqlBuilder:
+        return SqlBuilder(
+            self.db,
+            self.select(columns, postprocessing, distinct=distinct, group_by=group_by).subquery(),
+        ).extract_columns(columns, postprocessing)
+
+    def union_subquery(
+        self,
+        others: Iterable[SqlBuilder],
+        columns: qt.ColumnSet,
+        postprocessing: Postprocessing | None = None,
+    ) -> SqlBuilder:
+        select0 = self.select(columns, postprocessing)
+        other_selects = [other.select(columns, postprocessing) for other in others]
+        return SqlBuilder(
+            self.db,
+            select0.union(*other_selects).subquery(),
+        ).extract_columns(columns, postprocessing)
diff --git a/python/lsst/daf/butler/direct_query_driver/_sql_column_visitor.py b/python/lsst/daf/butler/direct_query_driver/_sql_column_visitor.py
new file mode 100644
index 0000000000..d3a6de201f
--- /dev/null
+++ b/python/lsst/daf/butler/direct_query_driver/_sql_column_visitor.py
@@ -0,0 +1,239 @@
+# This file is part of daf_butler.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
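# [Illustrative sketch, not part of the patch] SqlBuilder.join above builds its
# ON clause by equating every pair of columns that carry the same dimension key
# in the two operands. The same idea with hand-made tables; the table and
# column names here are assumptions for illustration, not the real butler
# schema:
import itertools

import sqlalchemy

metadata = sqlalchemy.MetaData()
visit = sqlalchemy.Table(
    "visit",
    metadata,
    sqlalchemy.Column("instrument", sqlalchemy.String),
    sqlalchemy.Column("id", sqlalchemy.BigInteger),
)
visit_detector_region = sqlalchemy.Table(
    "visit_detector_region",
    metadata,
    sqlalchemy.Column("instrument", sqlalchemy.String),
    sqlalchemy.Column("visit", sqlalchemy.BigInteger),
    sqlalchemy.Column("detector", sqlalchemy.BigInteger),
)
# Map dimension name -> columns that carry it, as SqlBuilder.dimension_keys does.
keys_a = {"instrument": [visit.c.instrument], "visit": [visit.c.id]}
keys_b = {"instrument": [visit_detector_region.c.instrument], "visit": [visit_detector_region.c.visit]}
join_on = [
    column1 == column2
    for name in keys_a.keys() & keys_b.keys()
    for column1, column2 in itertools.product(keys_a[name], keys_b[name])
]
joined = visit.join(visit_detector_region, onclause=sqlalchemy.and_(*join_on))
print(sqlalchemy.select(visit.c.id).select_from(joined))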
+# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("SqlColumnVisitor",) + +from typing import TYPE_CHECKING, Any, cast + +import sqlalchemy + +from .. import ddl +from ..queries.visitors import ColumnExpressionVisitor, PredicateVisitFlags, PredicateVisitor +from ..timespan_database_representation import TimespanDatabaseRepresentation + +if TYPE_CHECKING: + from ..queries import tree as qt + from ._driver import DirectQueryDriver + from ._sql_builder import SqlBuilder + + +class SqlColumnVisitor( + ColumnExpressionVisitor[sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation], + PredicateVisitor[ + sqlalchemy.ColumnElement[bool], sqlalchemy.ColumnElement[bool], sqlalchemy.ColumnElement[bool] + ], +): + def __init__(self, sql_builder: SqlBuilder, driver: DirectQueryDriver): + self._driver = driver + self._sql_builder = sql_builder + + def visit_literal( + self, expression: qt.ColumnLiteral + ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation: + # Docstring inherited. + if expression.column_type == "timespan": + return self._driver.db.getTimespanRepresentation().fromLiteral(expression.get_literal_value()) + return sqlalchemy.literal( + expression.get_literal_value(), type_=ddl.VALID_CONFIG_COLUMN_TYPES[expression.column_type] + ) + + def visit_dimension_key_reference( + self, expression: qt.DimensionKeyReference + ) -> sqlalchemy.ColumnElement[int | str]: + # Docstring inherited. + return self._sql_builder.dimension_keys[expression.dimension.name][0] + + def visit_dimension_field_reference( + self, expression: qt.DimensionFieldReference + ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation: + # Docstring inherited. + if expression.column_type == "timespan": + return self._sql_builder.timespans[expression.element.name] + return self._sql_builder.fields[expression.element.name][expression.field] + + def visit_dataset_field_reference( + self, expression: qt.DatasetFieldReference + ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation: + # Docstring inherited. + if expression.column_type == "timespan": + return self._sql_builder.timespans[expression.dataset_type] + return self._sql_builder.fields[expression.dataset_type][expression.field] + + def visit_unary_expression(self, expression: qt.UnaryExpression) -> sqlalchemy.ColumnElement[Any]: + # Docstring inherited. 
+ match expression.operator: + case "-": + return -self.expect_scalar(expression.operand) + case "begin_of": + return self.expect_timespan(expression.operand).lower() + case "end_of": + return self.expect_timespan(expression.operand).upper() + raise AssertionError(f"Invalid unary expression operator {expression.operator!r}.") + + def visit_binary_expression(self, expression: qt.BinaryExpression) -> sqlalchemy.ColumnElement[Any]: + # Docstring inherited. + a = self.expect_scalar(expression.a) + b = self.expect_scalar(expression.b) + match expression.operator: + case "+": + return a + b + case "-": + return a - b + case "*": + return a * b + case "/": + return a / b + case "%": + return a % b + raise AssertionError(f"Invalid binary expression operator {expression.operator!r}.") + + def visit_reversed(self, expression: qt.Reversed) -> sqlalchemy.ColumnElement[Any]: + # Docstring inherited. + return self.expect_scalar(expression.operand).desc() + + def visit_comparison( + self, + a: qt.ColumnExpression, + operator: qt.ComparisonOperator, + b: qt.ColumnExpression, + flags: PredicateVisitFlags, + ) -> sqlalchemy.ColumnElement[bool]: + # Docstring inherited. + if operator == "overlaps": + assert a.column_type == "timespan", "Spatial overlaps should be transformed away by now." + return self.expect_timespan(a).overlaps(self.expect_timespan(b)) + lhs = self.expect_scalar(a) + rhs = self.expect_scalar(b) + match operator: + case "==": + return lhs == rhs + case "!=": + return lhs != rhs + case "<": + return lhs < rhs + case ">": + return lhs > rhs + case "<=": + return lhs <= rhs + case ">=": + return lhs >= rhs + raise AssertionError(f"Invalid comparison operator {operator!r}.") + + def visit_is_null( + self, operand: qt.ColumnExpression, flags: PredicateVisitFlags + ) -> sqlalchemy.ColumnElement[bool]: + # Docstring inherited. + if operand.column_type == "timespan": + return self.expect_timespan(operand).isNull() + return self.expect_scalar(operand) == sqlalchemy.null() + + def visit_in_container( + self, + member: qt.ColumnExpression, + container: tuple[qt.ColumnExpression, ...], + flags: PredicateVisitFlags, + ) -> sqlalchemy.ColumnElement[bool]: + # Docstring inherited. + return self.expect_scalar(member).in_([self.expect_scalar(item) for item in container]) + + def visit_in_range( + self, member: qt.ColumnExpression, start: int, stop: int | None, step: int, flags: PredicateVisitFlags + ) -> sqlalchemy.ColumnElement[bool]: + # Docstring inherited. + sql_member = self.expect_scalar(member) + if stop is None: + target = sql_member >= sqlalchemy.literal(start) + else: + stop_inclusive = stop - 1 + if start == stop_inclusive: + return sql_member == sqlalchemy.literal(start) + else: + target = sqlalchemy.sql.between( + sql_member, + sqlalchemy.literal(start), + sqlalchemy.literal(stop_inclusive), + ) + if step != 1: + return sqlalchemy.sql.and_( + *[ + target, + sql_member % sqlalchemy.literal(step) == sqlalchemy.literal(start % step), + ] + ) + else: + return target + + def visit_in_query_tree( + self, + member: qt.ColumnExpression, + column: qt.ColumnExpression, + query_tree: qt.QueryTree, + flags: PredicateVisitFlags, + ) -> sqlalchemy.ColumnElement[bool]: + # Docstring inherited. 
+ columns = qt.ColumnSet(self._driver.universe.empty.as_group()) + column.gather_required_columns(columns) + query, sql_builder = self._driver.analyze_query(query_tree, columns) + self._driver.build_query(query, sql_builder) + if query.postprocessing: + raise NotImplementedError( + "Right-hand side subquery in IN expression would require postprocessing." + ) + subquery_visitor = SqlColumnVisitor(sql_builder, self._driver) + sql_builder.special["_MEMBER"] = subquery_visitor.expect_scalar(column) + subquery_select = sql_builder.select(qt.ColumnSet(self._driver.universe.empty.as_group())) + sql_member = self.expect_scalar(member) + return sql_member.in_(subquery_select) + + def apply_logical_and( + self, originals: qt.PredicateOperands, results: tuple[sqlalchemy.ColumnElement[bool], ...] + ) -> sqlalchemy.ColumnElement[bool]: + # Docstring inherited. + return sqlalchemy.and_(*results) + + def apply_logical_or( + self, + originals: tuple[qt.PredicateLeaf, ...], + results: tuple[sqlalchemy.ColumnElement[bool], ...], + flags: PredicateVisitFlags, + ) -> sqlalchemy.ColumnElement[bool]: + # Docstring inherited. + return sqlalchemy.or_(*results) + + def apply_logical_not( + self, original: qt.PredicateLeaf, result: sqlalchemy.ColumnElement[bool], flags: PredicateVisitFlags + ) -> sqlalchemy.ColumnElement[bool]: + # Docstring inherited. + return sqlalchemy.not_(result) + + def expect_scalar(self, expression: qt.OrderExpression) -> sqlalchemy.ColumnElement[Any]: + return cast(sqlalchemy.ColumnElement[Any], expression.visit(self)) + + def expect_timespan(self, expression: qt.ColumnExpression) -> TimespanDatabaseRepresentation: + return cast(TimespanDatabaseRepresentation, expression.visit(self)) diff --git a/python/lsst/daf/butler/queries/__init__.py b/python/lsst/daf/butler/queries/__init__.py new file mode 100644 index 0000000000..15743f291f --- /dev/null +++ b/python/lsst/daf/butler/queries/__init__.py @@ -0,0 +1,32 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
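# [Illustrative sketch, not part of the patch] SqlColumnVisitor.visit_in_range
# above turns a half-open start:stop:step range into BETWEEN plus a modulus
# term. A standalone version of the same translation, omitting the stop=None
# and single-value special cases handled by the visitor:
import sqlalchemy


def in_range_sql(
    member: sqlalchemy.ColumnElement, start: int, stop: int, step: int = 1
) -> sqlalchemy.ColumnElement:
    target = sqlalchemy.sql.between(member, sqlalchemy.literal(start), sqlalchemy.literal(stop - 1))
    if step != 1:
        # Keep only values congruent to `start` modulo `step`.
        return sqlalchemy.and_(
            target, member % sqlalchemy.literal(step) == sqlalchemy.literal(start % step)
        )
    return target


visit_column = sqlalchemy.column("visit", sqlalchemy.BigInteger)
print(in_range_sql(visit_column, 100, 200, step=10))  # prints the rendered SQL expression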
+ +from ._base import * +from ._data_coordinate_query_results import * +from ._dataset_query_results import * +from ._dimension_record_query_results import * +from ._query import * diff --git a/python/lsst/daf/butler/queries/_base.py b/python/lsst/daf/butler/queries/_base.py new file mode 100644 index 0000000000..a32a1b2967 --- /dev/null +++ b/python/lsst/daf/butler/queries/_base.py @@ -0,0 +1,195 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("QueryBase", "HomogeneousQueryBase", "CountableQueryBase", "QueryResultsBase") + +from abc import ABC, abstractmethod +from collections.abc import Iterable, Set +from typing import Any, Self + +from ..dimensions import DimensionGroup +from .convert_args import convert_order_by_args +from .driver import QueryDriver +from .expression_factory import ExpressionProxy +from .tree import OrderExpression, QueryTree + + +class QueryBase(ABC): + @abstractmethod + def any(self, *, execute: bool = True, exact: bool = True) -> bool: + """Test whether the query would return any rows. + + Parameters + ---------- + execute : `bool`, optional + If `True`, execute at least a ``LIMIT 1`` query if it cannot be + determined prior to execution that the query would return no rows. + exact : `bool`, optional + If `True`, run the full query and perform post-query filtering if + needed, until at least one result row is found. If `False`, the + returned result does not account for post-query filtering, and + hence may be `True` even when all result rows would be filtered + out. + + Returns + ------- + any : `bool` + `True` if the query would (or might, depending on arguments) yield + result rows. `False` if it definitely would not. + """ + raise NotImplementedError() + + @abstractmethod + def explain_no_results(self, execute: bool = True) -> Iterable[str]: + """Return human-readable messages that may help explain why the query + yields no results. + + Parameters + ---------- + execute : `bool`, optional + If `True` (default) execute simplified versions (e.g. ``LIMIT 1``) + of aspects of the tree to more precisely determine where rows were + filtered out. 
+ + Returns + ------- + messages : `~collections.abc.Iterable` [ `str` ] + String messages that describe reasons the query might not yield any + results. + """ + raise NotImplementedError() + + +class HomogeneousQueryBase(QueryBase): + def __init__(self, driver: QueryDriver, tree: QueryTree): + self._driver = driver + self._tree = tree + + @property + def dimensions(self) -> DimensionGroup: + """All dimensions included in the query's columns.""" + return self._tree.dimensions + + def any(self, *, execute: bool = True, exact: bool = True) -> bool: + # Docstring inherited. + return self._driver.any(self._tree, execute=execute, exact=exact) + + def explain_no_results(self, execute: bool = True) -> Iterable[str]: + # Docstring inherited. + return self._driver.explain_no_results(self._tree, execute=execute) + + +class CountableQueryBase(QueryBase): + @abstractmethod + def count(self, *, exact: bool = True, discard: bool = False) -> int: + """Count the number of rows this query would return. + + Parameters + ---------- + exact : `bool`, optional + If `True`, run the full query and perform post-query filtering if + needed to account for that filtering in the count. If `False`, the + result may be an upper bound. + discard : `bool`, optional + If `True`, compute the exact count even if it would require running + the full query and then throwing away the result rows after + counting them. If `False`, this is an error, as the user would + usually be better off executing the query first to fetch its rows + into a new query (or passing ``exact=False``). Ignored if + ``exact=False``. + + Returns + ------- + count : `int` + The number of rows the query would return, or an upper bound if + ``exact=False``. + """ + raise NotImplementedError() + + +class QueryResultsBase(HomogeneousQueryBase, CountableQueryBase): + def order_by(self, *args: str | OrderExpression | ExpressionProxy) -> Self: + """Return a new query that yields ordered results. + + Parameters + ---------- + *args : `str` + Names of the columns/dimensions to use for ordering. Column name + can be prefixed with minus (``-``) to use descending ordering. + + Returns + ------- + result : `QueryResultsBase` + An ordered version of this query results object. + + Notes + ----- + If this method is called multiple times, the new sort terms replace + the old ones. + """ + return self._copy( + self._tree, order_by=convert_order_by_args(self.dimensions, self._get_datasets(), *args) + ) + + def limit(self, limit: int | None = None, offset: int = 0) -> Self: + """Return a new query that slices its result rows positionally. + + Parameters + ---------- + limit : `int` or `None`, optional + Upper limit on the number of returned records. + offset : `int`, optional + The number of records to skip before returning at most ``limit`` + records. + + Returns + ------- + result : `QueryResultsBase` + A sliced version of this query results object. + + Notes + ----- + If this method is called multiple times, the new slice parameters + replace the old ones. Slicing always occurs after sorting, even if + `limit` is called before `order_by`. + """ + return self._copy(self._tree, limit=limit, offset=offset) + + @abstractmethod + def _get_datasets(self) -> Set[str]: + """Return all dataset types included in the query's result rows.""" + raise NotImplementedError() + + @abstractmethod + def _copy(self, tree: QueryTree, **kwargs: Any) -> Self: + """Return a modified copy of ``self``. + + Modifications should be validated, not assumed to be correct. 
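# [Illustrative sketch, not part of the patch] Hypothetical use of the
# order_by/limit chaining defined above; `butler` and the dimension names are
# assumptions. Slicing is applied after sorting regardless of call order:
with butler._query() as query:
    data_ids = (
        query.data_ids(["visit", "detector"])
        .order_by("-visit")  # a leading "-" requests descending order
        .limit(10, offset=20)
    )
    for data_id in data_ids:
        print(data_id)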
+ """ + raise NotImplementedError() diff --git a/python/lsst/daf/butler/queries/_data_coordinate_query_results.py b/python/lsst/daf/butler/queries/_data_coordinate_query_results.py new file mode 100644 index 0000000000..5e39ccc9b2 --- /dev/null +++ b/python/lsst/daf/butler/queries/_data_coordinate_query_results.py @@ -0,0 +1,142 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("DataCoordinateQueryResults",) + +from collections.abc import Iterable, Iterator +from typing import TYPE_CHECKING, Any + +from ..dimensions import DataCoordinate, DimensionGroup +from ._base import QueryResultsBase +from .driver import QueryDriver +from .tree import InvalidQueryTreeError, QueryTree + +if TYPE_CHECKING: + from .result_specs import DataCoordinateResultSpec + + +class DataCoordinateQueryResults(QueryResultsBase): + """A method-chaining builder for butler queries that return data IDs. + + Parameters + ---------- + driver : `QueryDriver` + Implementation object that knows how to actually execute queries. + tree : `QueryTree` + Description of the query as a tree of joins and column expressions. + The instance returned directly by the `Butler._query` entry point + should be constructed via `make_unit_query_tree`. + spec : `DataCoordinateResultSpec` + Specification of the query result rows, including output columns, + ordering, and slicing. + + Notes + ----- + This refines the `DataCoordinateQueryResults` ABC defined in + `lsst.daf.butler._query_results`, but the intent is to replace that ABC + with this concrete class, rather than inherit from it. 
+ """ + + def __init__(self, driver: QueryDriver, tree: QueryTree, spec: DataCoordinateResultSpec): + spec.validate_tree(tree) + super().__init__(driver, tree) + self._spec = spec + + def __iter__(self) -> Iterator[DataCoordinate]: + page = self._driver.execute(self._spec, self._tree) + yield from page.rows + while page.next_key is not None: + page = self._driver.fetch_next_page(self._spec, page.next_key) + yield from page.rows + + @property + def has_dimension_records(self) -> bool: + """Whether all data IDs in this iterable contain dimension records.""" + return self._spec.include_dimension_records + + def with_dimension_records(self) -> DataCoordinateQueryResults: + """Return a results object for which `has_dimension_records` is + `True`. + """ + if self.has_dimension_records: + return self + return self._copy(tree=self._tree, include_dimension_records=True) + + def subset( + self, + dimensions: DimensionGroup | Iterable[str] | None = None, + ) -> DataCoordinateQueryResults: + """Return a results object containing a subset of the dimensions of + this one. + + Parameters + ---------- + dimensions : `DimensionGroup` or \ + `~collections.abc.Iterable` [ `str`], optional + Dimensions to include in the new results object. If `None`, + ``self.dimensions`` is used. + + Returns + ------- + results : `DataCoordinateQueryResults` + A results object corresponding to the given criteria. May be + ``self`` if it already qualifies. + + Raises + ------ + InvalidQueryTreeError + Raised when ``dimensions`` is not a subset of the dimensions in + this result. + """ + if dimensions is None: + dimensions = self.dimensions + else: + dimensions = self._driver.universe.conform(dimensions) + if not dimensions <= self.dimensions: + raise InvalidQueryTreeError( + f"New dimensions {dimensions} are not a subset of the current " + f"dimensions {self.dimensions}." + ) + return self._copy(tree=self._tree, dimensions=dimensions) + + def count(self, *, exact: bool = True, discard: bool = False) -> int: + # Docstring inherited. + return self._driver.count( + self._tree, + self._spec.get_result_columns(), + find_first_dataset=None, + exact=exact, + discard=discard, + ) + + def _copy(self, tree: QueryTree, **kwargs: Any) -> DataCoordinateQueryResults: + return DataCoordinateQueryResults(self._driver, tree, spec=self._spec.model_copy(update=kwargs)) + + def _get_datasets(self) -> frozenset[str]: + return frozenset() diff --git a/python/lsst/daf/butler/queries/_dataset_query_results.py b/python/lsst/daf/butler/queries/_dataset_query_results.py new file mode 100644 index 0000000000..ecd33e99e1 --- /dev/null +++ b/python/lsst/daf/butler/queries/_dataset_query_results.py @@ -0,0 +1,232 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. 
If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ( + "DatasetQueryResults", + "ChainedDatasetQueryResults", + "SingleTypeDatasetQueryResults", +) + +import itertools +from abc import abstractmethod +from collections.abc import Iterable, Iterator +from typing import TYPE_CHECKING, Any + +from .._dataset_ref import DatasetRef +from .._dataset_type import DatasetType +from ._base import CountableQueryBase, QueryResultsBase +from .driver import QueryDriver +from .result_specs import DatasetRefResultSpec +from .tree import QueryTree + +if TYPE_CHECKING: + from ._data_coordinate_query_results import DataCoordinateQueryResults + + +class DatasetQueryResults(CountableQueryBase, Iterable[DatasetRef]): + """An interface for objects that represent the results of queries for + datasets. + """ + + @abstractmethod + def by_dataset_type(self) -> Iterator[SingleTypeDatasetQueryResults]: + """Group results by dataset type. + + Returns + ------- + iter : `~collections.abc.Iterator` [ `SingleTypeDatasetQueryResults` ] + An iterator over `DatasetQueryResults` instances that are each + responsible for a single dataset type. + """ + raise NotImplementedError() + + @property + @abstractmethod + def has_dimension_records(self) -> bool: + """Whether all data IDs in this iterable contain dimension records.""" + raise NotImplementedError() + + @abstractmethod + def with_dimension_records(self) -> DatasetQueryResults: + """Return a results object for which `has_dimension_records` is + `True`. + """ + raise NotImplementedError() + + +class SingleTypeDatasetQueryResults(DatasetQueryResults, QueryResultsBase): + """A method-chaining builder for butler queries that return `DatasetRef` + objects. + + Parameters + ---------- + driver : `QueryDriver` + Implementation object that knows how to actually execute queries. + tree : `QueryTree` + Description of the query as a tree of joins and column expressions. + The instance returned directly by the `Butler._query` entry point + should be constructed via `make_unit_query_tree`. + spec : `DatasetRefResultSpec` + Specification of the query result rows, including output columns, + ordering, and slicing. + + Notes + ----- + This refines the `SingleTypeDatasetQueryResults` ABC defined in + `lsst.daf.butler._query_results`, but the intent is to replace that ABC + with this concrete class, rather than inherit from it. 
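# [Illustrative sketch, not part of the patch] Querying several dataset types
# at once and walking the per-type results; `butler`, the dataset type names,
# and the collection name are assumptions:
with butler._query() as query:
    refs = query.datasets(["raw", "calexp"], collections=["HSC/runs/RC2"], find_first=True)
    for single_type_results in refs.by_dataset_type():
        print(single_type_results.dataset_type.name, single_type_results.count(exact=False))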
+ """ + + def __init__(self, driver: QueryDriver, tree: QueryTree, spec: DatasetRefResultSpec): + spec.validate_tree(tree) + super().__init__(driver, tree) + self._spec = spec + + def __iter__(self) -> Iterator[DatasetRef]: + page = self._driver.execute(self._spec, self._tree) + yield from page.rows + while page.next_key is not None: + page = self._driver.fetch_next_page(self._spec, page.next_key) + yield from page.rows + + @property + def dataset_type(self) -> DatasetType: + # Docstring inherited. + return DatasetType(self._spec.dataset_type_name, self._spec.dimensions, self._spec.storage_class_name) + + @property + def data_ids(self) -> DataCoordinateQueryResults: + # Docstring inherited. + from ._data_coordinate_query_results import DataCoordinateQueryResults, DataCoordinateResultSpec + + return DataCoordinateQueryResults( + self._driver, + tree=self._tree, + spec=DataCoordinateResultSpec.model_construct( + dimensions=self.dataset_type.dimensions.as_group(), + include_dimension_records=self._spec.include_dimension_records, + ), + ) + + @property + def has_dimension_records(self) -> bool: + # Docstring inherited. + return self._spec.include_dimension_records + + def with_dimension_records(self) -> SingleTypeDatasetQueryResults: + # Docstring inherited. + if self.has_dimension_records: + return self + return self._copy(tree=self._tree, include_dimension_records=True) + + def by_dataset_type(self) -> Iterator[SingleTypeDatasetQueryResults]: + # Docstring inherited. + return iter((self,)) + + def count(self, *, exact: bool = True, discard: bool = False) -> int: + # Docstring inherited. + return self._driver.count( + self._tree, + self._spec.get_result_columns(), + find_first_dataset=self._spec.find_first_dataset, + exact=exact, + discard=discard, + ) + + def _copy(self, tree: QueryTree, **kwargs: Any) -> SingleTypeDatasetQueryResults: + return SingleTypeDatasetQueryResults( + self._driver, + self._tree, + self._spec.model_copy(update=kwargs), + ) + + def _get_datasets(self) -> frozenset[str]: + return frozenset({self.dataset_type.name}) + + +class ChainedDatasetQueryResults(DatasetQueryResults): + """Implementation of `DatasetQueryResults` that delegates to a sequence + of `SingleTypeDatasetQueryResults`. + + Parameters + ---------- + by_dataset_type : `tuple` [ `SingleTypeDatasetQueryResults` ] + Tuple of single-dataset-type query result objects to combine. + + Notes + ----- + Ideally this will eventually just be "DatasetQueryResults", because we + won't need an ABC if this is the only implementation. + """ + + def __init__(self, by_dataset_type: tuple[SingleTypeDatasetQueryResults, ...]): + self._by_dataset_type = by_dataset_type + + def __iter__(self) -> Iterator[DatasetRef]: + return itertools.chain.from_iterable(self._by_dataset_type) + + def by_dataset_type(self) -> Iterator[SingleTypeDatasetQueryResults]: + # Docstring inherited. + return iter(self._by_dataset_type) + + @property + def has_dimension_records(self) -> bool: + # Docstring inherited. + return all(single_type_results.has_dimension_records for single_type_results in self._by_dataset_type) + + def with_dimension_records(self) -> ChainedDatasetQueryResults: + # Docstring inherited. + return ChainedDatasetQueryResults( + tuple( + [ + single_type_results.with_dimension_records() + for single_type_results in self._by_dataset_type + ] + ) + ) + + def any(self, *, execute: bool = True, exact: bool = True) -> bool: + # Docstring inherited. 
+ return any( + single_type_results.any(execute=execute, exact=exact) + for single_type_results in self._by_dataset_type + ) + + def explain_no_results(self, execute: bool = True) -> Iterable[str]: + # Docstring inherited. + messages: list[str] = [] + for single_type_results in self._by_dataset_type: + messages.extend(single_type_results.explain_no_results(execute=execute)) + return messages + + def count(self, *, exact: bool = True, discard: bool = False) -> int: + return sum( + single_type_results.count(exact=exact, discard=discard) + for single_type_results in self._by_dataset_type + ) diff --git a/python/lsst/daf/butler/queries/_dimension_record_query_results.py b/python/lsst/daf/butler/queries/_dimension_record_query_results.py new file mode 100644 index 0000000000..6663fcc0af --- /dev/null +++ b/python/lsst/daf/butler/queries/_dimension_record_query_results.py @@ -0,0 +1,109 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("DimensionRecordQueryResults",) + +from collections.abc import Iterator +from typing import Any + +from ..dimensions import DimensionElement, DimensionRecord, DimensionRecordSet, DimensionRecordTable +from ._base import QueryResultsBase +from .driver import QueryDriver +from .result_specs import DimensionRecordResultSpec +from .tree import QueryTree + + +class DimensionRecordQueryResults(QueryResultsBase): + """A method-chaining builder for butler queries that return data IDs. + + Parameters + ---------- + driver : `QueryDriver` + Implementation object that knows how to actually execute queries. + tree : `QueryTree` + Description of the query as a tree of joins and column expressions. + The instance returned directly by the `Butler._query` entry point + should be constructed via `make_unit_query_tree`. + spec : `DimensionRecordResultSpec` + Specification of the query result rows, including output columns, + ordering, and slicing. + + Notes + ----- + This refines the `DimensionRecordQueryResults` ABC defined in + `lsst.daf.butler._query_results`, but the intent is to replace that ABC + with this concrete class, rather than inherit from it. 
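# [Illustrative sketch, not part of the patch] Querying dimension records with
# a data-ID constraint; `butler` and the instrument name are assumptions, and
# `full_name` is a detector-specific field in the default dimension universe:
with butler._query() as query:
    records = query.where(instrument="HSC").dimension_records("detector")
    for record in records:
        print(record.id, record.full_name)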
+ """ + + def __init__(self, driver: QueryDriver, tree: QueryTree, spec: DimensionRecordResultSpec): + spec.validate_tree(tree) + super().__init__(driver, tree) + self._spec = spec + + def __iter__(self) -> Iterator[DimensionRecord]: + page = self._driver.execute(self._spec, self._tree) + yield from page.rows + while page.next_key is not None: + page = self._driver.fetch_next_page(self._spec, page.next_key) + yield from page.rows + + def iter_table_pages(self) -> Iterator[DimensionRecordTable]: + page = self._driver.execute(self._spec, self._tree) + yield page.as_table() + while page.next_key is not None: + page = self._driver.fetch_next_page(self._spec, page.next_key) + yield page.as_table() + + def iter_set_pages(self) -> Iterator[DimensionRecordSet]: + page = self._driver.execute(self._spec, self._tree) + yield page.as_set() + while page.next_key is not None: + page = self._driver.fetch_next_page(self._spec, page.next_key) + yield page.as_set() + + @property + def element(self) -> DimensionElement: + # Docstring inherited. + return self._spec.element + + def count(self, *, exact: bool = True, discard: bool = False) -> int: + # Docstring inherited. + return self._driver.count( + self._tree, + self._spec.get_result_columns(), + find_first_dataset=None, + exact=exact, + discard=discard, + ) + + def _copy(self, tree: QueryTree, **kwargs: Any) -> DimensionRecordQueryResults: + return DimensionRecordQueryResults(self._driver, tree, self._spec.model_copy(update=kwargs)) + + def _get_datasets(self) -> frozenset[str]: + return frozenset() diff --git a/python/lsst/daf/butler/queries/_query.py b/python/lsst/daf/butler/queries/_query.py new file mode 100644 index 0000000000..c8ee728a87 --- /dev/null +++ b/python/lsst/daf/butler/queries/_query.py @@ -0,0 +1,511 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
+ +from __future__ import annotations + +__all__ = ("Query",) + +from collections.abc import Iterable, Mapping, Set +from types import EllipsisType +from typing import Any, overload + +from lsst.utils.iteration import ensure_iterable + +from .._dataset_type import DatasetType +from ..dimensions import DataCoordinate, DataId, DataIdValue, DimensionGroup +from ..registry import DatasetTypeError, MissingDatasetTypeError +from ._base import HomogeneousQueryBase +from ._data_coordinate_query_results import DataCoordinateQueryResults +from ._dataset_query_results import ( + ChainedDatasetQueryResults, + DatasetQueryResults, + SingleTypeDatasetQueryResults, +) +from ._dimension_record_query_results import DimensionRecordQueryResults +from .convert_args import convert_where_args +from .driver import QueryDriver +from .expression_factory import ExpressionFactory +from .result_specs import DataCoordinateResultSpec, DatasetRefResultSpec, DimensionRecordResultSpec +from .tree import ( + DatasetSearch, + InvalidQueryTreeError, + Predicate, + QueryTree, + make_dimension_query_tree, + make_unit_query_tree, +) + + +class Query(HomogeneousQueryBase): + """A method-chaining builder for butler queries. + + Parameters + ---------- + driver : `QueryDriver` + Implementation object that knows how to actually execute queries. + tree : `QueryTree` + Description of the query as a tree of joins and column expressions. + The instance returned directly by the `Butler._query` entry point + should be constructed via `make_unit_query_tree`. + + Notes + ----- + This largely mimics and considerably expands the `Query` ABC defined in + `lsst.daf.butler._query`, but the intent is to replace that ABC with this + concrete class, rather than inherit from it. + """ + + def __init__(self, driver: QueryDriver, tree: QueryTree): + super().__init__(driver, tree) + + @property + def dataset_types(self) -> Set[str]: + """The names of all dataset types joined into the query. + + These dataset types are usable in 'where' expressions, but may or may + not be available to result rows. + """ + return self._tree.datasets.keys() + + @property + def expression_factory(self) -> ExpressionFactory: + """A factory for column expressions using overloaded operators. + + Notes + ----- + Typically this attribute will be assigned to a single-character local + variable, and then its (dynamic) attributes can be used to obtain + references to columns that can be included in a query:: + + with butler._query() as query: + x = query.expression_factory + query = query.where( + x.instrument == "LSSTCam", + x.visit.day_obs > 20240701, + x.any(x.band == 'u', x.band == 'y'), + ) + + As shown above, the returned object also has an `any` method to create + combine expressions with logical OR (as well as `not_` and `all`, + though the latter is rarely necessary since `where` already combines + its arguments with AND). + + Proxies for fields associated with dataset types (``dataset_id``, + ``ingest_date``, ``run``, ``collection``, as well as ``timespan`` for + `~CollectionType.CALIBRATION` collection searches) can be obtained with + dict-like access instead:: + + with butler._query() as query: + query = query.order_by(x["raw"].ingest_date) + + Expression proxy objects that correspond to scalar columns overload the + standard comparison operators (``==``, ``!=``, ``<``, ``>``, ``<=``, + ``>=``) and provide `~ScalarExpressionProxy.in_range`, + `~ScalarExpressionProxy.in_iterable`, and + `~ScalarExpressionProxy.in_query` methods for membership tests. 
For + `order_by` contexts, they also have a `~ScalarExpressionProxy.desc` + property to indicate that the sort order for that expression should be + reversed. + + Proxy objects for region and timespan fields have an `overlaps` method, + and timespans also have `~TimespanProxy.begin` and `~TimespanProxy.end` + properties to access scalar expression proxies for the bounds. + + All proxy objects also have a `~ExpressionProxy.is_null` property. + + Literal values can be created by calling `ExpressionFactory.literal`, + but can almost always be created implicitly via overloaded operators + instead. + """ + return ExpressionFactory(self._driver.universe) + + def data_ids( + self, + dimensions: DimensionGroup | Iterable[str] | str, + ) -> DataCoordinateQueryResults: + """Query for data IDs matching user-provided criteria. + + Parameters + ---------- + dimensions : `DimensionGroup`, `str`, or \ + `~collections.abc.Iterable` [`str`] + The dimensions of the data IDs to yield, as either `DimensionGroup` + instances or `str`. Will be automatically expanded to a complete + `DimensionGroup`. + + Returns + ------- + dataIds : `DataCoordinateQueryResults` + Data IDs matching the given query parameters. These are guaranteed + to identify all dimensions (`DataCoordinate.hasFull` returns + `True`), but will not contain `DimensionRecord` objects + (`DataCoordinate.hasRecords` returns `False`). Call + `~DataCoordinateQueryResults.with_dimension_records` on the + returned object to fetch those. + """ + dimensions = self._driver.universe.conform(dimensions) + tree = self._tree + if not dimensions >= self._tree.dimensions: + tree = tree.join(make_dimension_query_tree(dimensions)) + result_spec = DataCoordinateResultSpec(dimensions=dimensions, include_dimension_records=False) + return DataCoordinateQueryResults(self._driver, tree, result_spec) + + @overload + def datasets( + self, + dataset_type: str | DatasetType, + collections: str | Iterable[str] | None = None, + *, + find_first: bool = True, + ) -> SingleTypeDatasetQueryResults: ... + + @overload + def datasets( + self, + dataset_type: Iterable[str | DatasetType] | EllipsisType, + collections: str | Iterable[str] | None = None, + *, + find_first: bool = True, + ) -> DatasetQueryResults: ... + + def datasets( + self, + dataset_type: str | DatasetType | Iterable[str | DatasetType] | EllipsisType, + collections: str | Iterable[str] | None = None, + *, + find_first: bool = True, + ) -> DatasetQueryResults: + """Query for and iterate over dataset references matching user-provided + criteria. + + Parameters + ---------- + dataset_type : `str`, `DatasetType`, \ + `~collections.abc.Iterable` [ `str` or `DatasetType` ], \ + or ``...`` + The dataset type or types to search for. Passing ``...`` searches + for all datasets in the given collections. + collections : `str` or `~collections.abc.Iterable` [ `str` ], optional + The collection or collections to search, in order. If not provided + or `None`, and the dataset has not already been joined into the + query, the default collection search path for this butler is used. + find_first : `bool`, optional + If `True` (default), for each result data ID, only yield one + `DatasetRef` of each `DatasetType`, from the first collection in + which a dataset of that dataset type appears (according to the + order of ``collections`` passed in). If `True`, ``collections`` + must not contain regular expressions and may not be ``...``. 
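# [Illustrative sketch, not part of the patch] Using the expression factory
# described above to build a predicate. `butler` and the data-ID values are
# assumptions, and the exact in_range signature is assumed to be
# (start, stop) as suggested by the docstring, not verified API:
with butler._query() as query:
    x = query.expression_factory
    query = query.where(
        x.instrument == "HSC",
        x.detector.in_range(0, 104),
        x.any(x.band == "g", x.band == "r"),
    )
    data_ids = query.data_ids(["visit", "detector", "band"])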
+ + Returns + ------- + refs : `.queries.DatasetQueryResults` + Dataset references matching the given query criteria. Nested data + IDs are guaranteed to include values for all implied dimensions + (i.e. `DataCoordinate.hasFull` will return `True`), but will not + include dimension records (`DataCoordinate.hasRecords` will be + `False`) unless + `~.queries.DatasetQueryResults.with_dimension_records` is + called on the result object (which returns a new one). + + Raises + ------ + lsst.daf.butler.registry.DatasetTypeExpressionError + Raised when ``dataset_type`` expression is invalid. + TypeError + Raised when the arguments are incompatible, such as when a + collection wildcard is passed when ``find_first`` is `True`, or + when ``collections`` is `None` and default butler collections are + not defined. + + Notes + ----- + When multiple dataset types are queried in a single call, the + results of this operation are equivalent to querying for each dataset + type separately in turn, and no information about the relationships + between datasets of different types is included. + """ + if collections is None: + collections = self._driver.get_default_collections() + collections = tuple(ensure_iterable(collections)) + resolved_dataset_searches = self._driver.convert_dataset_search_args(dataset_type, collections) + single_type_results: list[SingleTypeDatasetQueryResults] = [] + for resolved_dataset_type in resolved_dataset_searches: + tree = self._tree + if resolved_dataset_type.name not in tree.datasets: + tree = tree.join_dataset( + resolved_dataset_type.name, + DatasetSearch.model_construct( + dimensions=resolved_dataset_type.dimensions.as_group(), + collections=collections, + ), + ) + elif collections is not None: + raise InvalidQueryTreeError( + f"Dataset type {resolved_dataset_type.name!r} was already joined into this query " + f"but new collections {collections!r} were still provided." + ) + spec = DatasetRefResultSpec.model_construct( + dataset_type_name=resolved_dataset_type.name, + dimensions=resolved_dataset_type.dimensions.as_group(), + storage_class_name=resolved_dataset_type.storageClass_name, + include_dimension_records=False, + find_first=find_first, + ) + single_type_results.append(SingleTypeDatasetQueryResults(self._driver, tree=tree, spec=spec)) + if len(single_type_results) == 1: + return single_type_results[0] + else: + return ChainedDatasetQueryResults(tuple(single_type_results)) + + def dimension_records(self, element: str) -> DimensionRecordQueryResults: + """Query for dimension information matching user-provided criteria. + + Parameters + ---------- + element : `str` + The name of a dimension element to obtain records for. + + Returns + ------- + records : `.queries.DimensionRecordQueryResults` + Data IDs matching the given query parameters. + """ + tree = self._tree + if element not in tree.dimensions.elements: + tree = tree.join(make_dimension_query_tree(self._driver.universe[element].minimal_group)) + result_spec = DimensionRecordResultSpec(element=self._driver.universe[element]) + return DimensionRecordQueryResults(self._driver, tree, result_spec) + + # TODO: add general, dict-row results method and QueryResults. + + def materialize( + self, + *, + dimensions: Iterable[str] | DimensionGroup | None = None, + datasets: Iterable[str] | None = None, + ) -> Query: + """Execute the query, save its results to a temporary location, and + return a new query that represents fetching or joining against those + saved results. 
+
+        Parameters
+        ----------
+        dimensions : `~collections.abc.Iterable` [ `str` ] or \
+                `DimensionGroup`, optional
+            Dimensions to include in the temporary results. Default is to
+            include all dimensions in the query.
+        datasets : `~collections.abc.Iterable` [ `str` ], optional
+            Names of dataset types that should be included in the new query;
+            default is to include `dataset_types`. Only resolved
+            dataset UUIDs will actually be materialized; datasets whose UUIDs
+            cannot be resolved will continue to be represented in the query via
+            a join on their dimensions.
+
+        Returns
+        -------
+        query : `Query`
+            A new query object that represents the materialized rows.
+        """
+        if datasets is None:
+            datasets = frozenset(self.dataset_types)
+        else:
+            datasets = frozenset(datasets)
+            if not (datasets <= self.dataset_types):
+                raise InvalidQueryTreeError(
+                    f"Dataset(s) {datasets - self.dataset_types} are not present in the query."
+                )
+        if dimensions is None:
+            dimensions = self._tree.dimensions
+        else:
+            dimensions = self._driver.universe.conform(dimensions)
+        key = self._driver.materialize(self._tree, dimensions, datasets)
+        tree = make_unit_query_tree(self._driver.universe).join_materialization(key, dimensions=dimensions)
+        for dataset_type_name in datasets:
+            tree = tree.join_dataset(dataset_type_name, self._tree.datasets[dataset_type_name])
+        return Query(self._driver, tree)
+
+    def join_dataset_search(
+        self,
+        dataset_type: str,
+        collections: Iterable[str] | None = None,
+        dimensions: DimensionGroup | None = None,
+    ) -> Query:
+        """Return a new query with a search for a dataset joined in.
+
+        Parameters
+        ----------
+        dataset_type : `str`
+            Name of the dataset type. May not refer to a dataset component.
+        collections : `~collections.abc.Iterable` [ `str` ], optional
+            Iterable of collections to search. Order is preserved, but will
+            not matter if the dataset search is only used as a constraint on
+            dimensions or if ``find_first=False`` when requesting results. If
+            not present or `None`, the default collection search path will be
+            used.
+        dimensions : `DimensionGroup`, optional
+            The dimensions to assume for the dataset type if it is not
+            registered, or check if it is. When the dataset is not registered
+            and this is not provided, `MissingDatasetTypeError` is raised,
+            since we cannot construct a query without knowing the dataset's
+            dimensions; providing this argument causes the returned query to
+            instead return no rows.
+
+        Returns
+        -------
+        query : `Query`
+            A new query object with dataset columns available and rows
+            restricted to those consistent with the found data IDs.
+
+        Raises
+        ------
+        DatasetTypeError
+            Raised if the dimensions were provided but they do not match the
+            registered dataset type.
+        MissingDatasetTypeError
+            Raised if the dimensions were not provided and the dataset type was
+            not registered.
+        """
+        if collections is None:
+            collections = self._driver.get_default_collections()
+        collections = tuple(ensure_iterable(collections))
+        assert isinstance(dataset_type, str), "DatasetType instances not supported here for simplicity."
+ try: + resolved_dimensions = self._driver.get_dataset_type(dataset_type).dimensions.as_group() + except MissingDatasetTypeError: + if dimensions is None: + raise + resolved_dimensions = dimensions + else: + if dimensions is not None and dimensions != resolved_dimensions: + raise DatasetTypeError( + f"Given dimensions {dimensions} for dataset type {dataset_type!r} do not match the " + f"registered dimensions {resolved_dimensions}." + ) + return Query( + tree=self._tree.join_dataset( + dataset_type, + DatasetSearch.model_construct(collections=collections, dimensions=resolved_dimensions), + ), + driver=self._driver, + ) + + def join_data_coordinates(self, iterable: Iterable[DataCoordinate]) -> Query: + """Return a new query that joins in an explicit table of data IDs. + + Parameters + ---------- + iterable : `~collections.abc.Iterable` [ `DataCoordinate` ] + Iterable of `DataCoordinate`. All items must have the same + dimensions. Must have at least one item. + + Returns + ------- + query : `Query` + A new query object with the data IDs joined in. + """ + rows: set[tuple[DataIdValue, ...]] = set() + dimensions: DimensionGroup | None = None + for data_coordinate in iterable: + if dimensions is None: + dimensions = data_coordinate.dimensions + elif dimensions != data_coordinate.dimensions: + raise RuntimeError(f"Inconsistent dimensions: {dimensions} != {data_coordinate.dimensions}.") + rows.add(data_coordinate.required_values) + if dimensions is None: + raise RuntimeError("Cannot upload an empty data coordinate set.") + key = self._driver.upload_data_coordinates(dimensions, rows) + return Query( + tree=self._tree.join_data_coordinate_upload(dimensions=dimensions, key=key), driver=self._driver + ) + + def join_dimensions(self, dimensions: Iterable[str] | DimensionGroup) -> Query: + """Return a new query that joins the logical tables for additional + dimensions. + + Parameters + ---------- + dimensions : `~collections.abc.Iterable` [ `str` ] or `DimensionGroup` + Names of dimensions to join in. + + Returns + ------- + query : `Query` + A new query object with the dimensions joined in. + """ + dimensions = self._driver.universe.conform(dimensions) + return Query( + tree=self._tree.join(make_dimension_query_tree(dimensions)), + driver=self._driver, + ) + + def where( + self, + *args: str | Predicate | DataId, + bind: Mapping[str, Any] | None = None, + **kwargs: Any, + ) -> Query: + """Return a query with a boolean-expression filter on its rows. + + Parameters + ---------- + *args + Constraints to apply, combined with logical AND. Arguments may be + `str` expressions to parse, `Predicate` objects (these are + typically constructed via `expression_factory`) or data IDs. + bind : `~collections.abc.Mapping` + Mapping from string identifier appearing in a string expression to + a literal value that should be substituted for it. This is + recommended instead of embedding literals directly into the + expression, especially for strings, timespans, or other types where + quoting or formatting is nontrivial. + **kwargs + Data ID key value pairs that extend and override any present in + ``*args``. + + Returns + ------- + query : `Query` + A new query object with the given row filters as well as any + already present in ``self`` (combined with logical AND). 
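# [Illustrative sketch, not part of the patch] Constraining a query with an
# explicit list of data IDs via join_data_coordinates, described above;
# `butler`, the data-ID values, and the collection name are assumptions:
from lsst.daf.butler import DataCoordinate

with butler._query() as query:
    data_ids = [
        DataCoordinate.standardize(instrument="HSC", exposure=903342, universe=butler.dimensions),
        DataCoordinate.standardize(instrument="HSC", exposure=903344, universe=butler.dimensions),
    ]
    constrained = query.join_data_coordinates(data_ids)
    refs = constrained.datasets("raw", collections=["HSC/raw/all"])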
+ + Notes + ----- + If an expression references a dimension or dimension element that is + not already present in the query, it will be joined in, but dataset + searches must already be joined into a query in order to reference + their fields in expressions. + + Data ID values are not checked for consistency; they are extracted from + ``args`` and then ``kwargs`` and combined, with later values overriding + earlier ones. + """ + return Query( + tree=self._tree.where( + *convert_where_args(self.dimensions, self.dataset_types, *args, bind=bind, **kwargs) + ), + driver=self._driver, + ) diff --git a/python/lsst/daf/butler/queries/convert_args.py b/python/lsst/daf/butler/queries/convert_args.py new file mode 100644 index 0000000000..b4f18e9045 --- /dev/null +++ b/python/lsst/daf/butler/queries/convert_args.py @@ -0,0 +1,244 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ( + "convert_where_args", + "convert_order_by_args", +) + +import itertools +from collections.abc import Mapping, Set +from typing import Any, cast + +from ..dimensions import DataCoordinate, DataId, Dimension, DimensionGroup +from .expression_factory import ExpressionProxy +from .tree import ( + DATASET_FIELD_NAMES, + ColumnExpression, + DatasetFieldName, + DatasetFieldReference, + DimensionFieldReference, + DimensionKeyReference, + InvalidQueryTreeError, + OrderExpression, + Predicate, + Reversed, + make_column_literal, +) + + +def convert_where_args( + dimensions: DimensionGroup, + datasets: Set[str], + *args: str | Predicate | DataId, + bind: Mapping[str, Any] | None = None, + **kwargs: Any, +) -> Predicate: + """Convert ``where`` arguments to a sequence of column expressions. + + Parameters + ---------- + dimensions : `DimensionGroup` + Dimensions already present in the query this filter is being applied + to. Returned predicates may reference dimensions outside this set. + datasets : `~collections.abc.Set` [ `str` ] + Dataset types already present in the query this filter is being applied + to. Returned predicates may still reference datasets outside this set; + this may be an error at a higher level, but it is not necessarily + checked here. 
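# [Editorial sketch, not part of the patch] How the where() arguments
# described above combine: data ID mappings, Predicate objects from the
# expression factory, and keyword data ID values are ANDed together, with
# keyword values taking precedence.  Names below are illustrative.
x = query.expression_factory
query = query.where(
    {"instrument": "HSC", "visit": 903342},  # data ID mapping
    x.detector.purpose == "SCIENCE",         # Predicate
    visit=903344,                            # overrides the visit given above
)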
+ *args : `str`, `Predicate`, `DataCoordinate`, or `~collections.abc.Mapping` + Expressions to convert into predicates. + bind : `~collections.abc.Mapping`, optional + Mapping from identifier to literal value used when parsing string + expressions. + **kwargs : `object` + Additional data ID key-value pairs. + + Returns + ------- + predicate : `Predicate` + Standardized predicate object. + + Notes + ----- + Data ID values are not checked for consistency; they are extracted from + args and then kwargs and combined, with later extractions taking + precedence. + """ + result = Predicate.from_bool(True) + data_id_dict: dict[str, Any] = {} + for arg in args: + match arg: + case str(): + raise NotImplementedError("TODO: plug in registry.queries.expressions.parser") + case Predicate(): + result = result.logical_and(arg) + case DataCoordinate(): + data_id_dict.update(arg.mapping) + case _: + data_id_dict.update(arg) + data_id_dict.update(kwargs) + for k, v in data_id_dict.items(): + result = result.logical_and( + Predicate.compare( + DimensionKeyReference.model_construct(dimension=dimensions.universe.dimensions[k]), + "==", + make_column_literal(v), + ) + ) + return result + + +def convert_order_by_args( + dimensions: DimensionGroup, datasets: Set[str], *args: str | OrderExpression | ExpressionProxy +) -> tuple[OrderExpression, ...]: + """Convert ``order_by`` arguments to a sequence of column expressions. + + Parameters + ---------- + dimensions : `DimensionGroup` + Dimensions already present in the query whose rows are being sorted. + Returned expressions may reference dimensions outside this set; this + may be an error at a higher level, but it is not necessarily checked + here. + datasets : `~collections.abc.Set` [ `str` ] + Dataset types already present in the query whose rows are being sorted. + Returned expressions may reference datasets outside this set; this may + be an error at a higher level, but it is not necessarily checked here. + *args : `OrderExpression`, `str`, or `ExpressionObject` + Expression or column names to sort by. + + Returns + ------- + expressions : `tuple` [ `OrderExpression`, ... ] + Standardized expression objects. + """ + result: list[OrderExpression] = [] + for arg in args: + match arg: + case str(): + reverse = False + if arg.startswith("-"): + reverse = True + arg = arg[1:] + arg = interpret_identifier(dimensions, datasets, arg, {}) + if reverse: + arg = Reversed.model_construct(operand=arg) + case ExpressionProxy(): + arg = arg._expression + if not hasattr(arg, "expression_type"): + raise TypeError(f"Unrecognized order-by argument: {arg!r}.") + result.append(arg) + return tuple(result) + + +def interpret_identifier( + dimensions: DimensionGroup, datasets: Set[str], identifier: str, bind: Mapping[str, Any] +) -> ColumnExpression: + """Associate an identifier in a ``where`` or ``order_by`` expression with + a query column or bind literal. + + Parameters + ---------- + dimensions : `DimensionGroup` + Dimensions already present in the query this filter is being applied + to. Returned expressions may reference dimensions outside this set. + datasets : `~collections.abc.Set` [ `str` ] + Dataset types already present in the query this filter is being applied + to. Returned expressions may still reference datasets outside this + set. + identifier : `str` + String identifier to process. + bind : `~collections.abc.Mapping` [ `str`, `object` ] + Dictionary of bind literals to match identifiers against first. 
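# [Editorial sketch, not part of the patch] The order_by argument forms
# handled by convert_order_by_args() above: a leading "-" on a string
# reverses the sort, and ExpressionProxy objects are unwrapped to their
# underlying expressions.  `dims` is an assumed DimensionGroup containing
# visit and detector.
terms = convert_order_by_args(dims, frozenset(), "visit", "-detector")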
+ + Returns + ------- + expression : `ColumnExpression` + Column expression corresponding to the identifier. + """ + if identifier in bind: + return make_column_literal(bind[identifier]) + first, _, second = identifier.partition(".") + if not second: + if first in dimensions.universe.dimensions: + return DimensionKeyReference.model_construct(dimension=dimensions.universe.dimensions[first]) + else: + element_matches: set[str] = set() + for element_name in dimensions.elements: + element = dimensions.universe[element_name] + if first in element.schema.names: + element_matches.add(element_name) + if first in DATASET_FIELD_NAMES: + dataset_matches = set(datasets) + else: + dataset_matches = set() + if len(element_matches) + len(dataset_matches) > 1: + match_str = ", ".join( + f"'{x}.{first}'" for x in sorted(itertools.chain(element_matches, dataset_matches)) + ) + raise InvalidQueryTreeError( + f"Ambiguous identifier {first!r} matches multiple fields: {match_str}." + ) + elif element_matches: + element = dimensions.universe[element_matches.pop()] + return DimensionFieldReference.model_construct(element=element, field=first) + elif dataset_matches: + return DatasetFieldReference.model_construct( + dataset_type=dataset_matches.pop(), field=cast(DatasetFieldName, first) + ) + else: + if first in dimensions.universe.elements: + element = dimensions.universe[first] + if second in element.schema.dimensions.names: + if isinstance(element, Dimension) and second == element.primary_key.name: + # Identifier is something like "visit.id" which we want to + # interpret the same way as just "visit". + return DimensionKeyReference.model_construct(dimension=element) + else: + # Identifier is something like "visit.instrument", which we + # want to interpret the same way as just "instrument". + dimension = dimensions.universe.dimensions[second] + return DimensionKeyReference.model_construct(dimension=dimension) + elif second in element.schema.remainder.names: + return DimensionFieldReference.model_construct(element=element, field=second) + else: + raise InvalidQueryTreeError(f"Unrecognized field {second!r} for {first}.") + elif second in DATASET_FIELD_NAMES: + # We just assume the dataset type is okay; it's the job of + # higher-level code to complain othewise. + return DatasetFieldReference.model_construct( + dataset_type=first, field=cast(DatasetFieldName, second) + ) + elif first in datasets: + raise InvalidQueryTreeError( + f"Identifier {identifier!r} references dataset type {first!r} but field " + f"{second!r} is not a valid for datasets." + ) + raise InvalidQueryTreeError(f"Unrecognized identifier {identifier!r}.") diff --git a/python/lsst/daf/butler/queries/driver.py b/python/lsst/daf/butler/queries/driver.py new file mode 100644 index 0000000000..3a07e2dc42 --- /dev/null +++ b/python/lsst/daf/butler/queries/driver.py @@ -0,0 +1,512 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. 
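# [Editorial sketch, not part of the patch] Examples of the
# identifier-resolution rules implemented by interpret_identifier() above;
# `dims` is an assumed DimensionGroup with visit and detector, and "raw" is
# an assumed dataset type name.
interpret_identifier(dims, {"raw"}, "visit", {})             # dimension key
interpret_identifier(dims, {"raw"}, "visit.id", {})          # same as "visit"
interpret_identifier(dims, {"raw"}, "detector.purpose", {})  # dimension field
interpret_identifier(dims, {"raw"}, "raw.run", {})           # dataset field
interpret_identifier(dims, {"raw"}, "visit", {"visit": 42})  # bind literal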
If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ( + "QueryDriver", + "PageKey", + "ResultPage", + "DataCoordinateResultPage", + "DimensionRecordResultPage", + "DatasetRefResultPage", + "GeneralResultPage", +) + +import dataclasses +import uuid +from abc import abstractmethod +from collections.abc import Iterable, Sequence +from contextlib import AbstractContextManager +from types import EllipsisType +from typing import Annotated, Any, TypeAlias, Union, overload + +import pydantic +from lsst.utils.iteration import ensure_iterable + +from .._dataset_ref import DatasetRef +from .._dataset_type import DatasetType +from ..dimensions import ( + DataCoordinate, + DataIdValue, + DimensionGroup, + DimensionRecord, + DimensionRecordSet, + DimensionRecordTable, + DimensionUniverse, +) +from ..registry import CollectionSummary, DatasetTypeError, DatasetTypeExpressionError +from ..registry.interfaces import CollectionRecord +from .result_specs import ( + DataCoordinateResultSpec, + DatasetRefResultSpec, + DimensionRecordResultSpec, + GeneralResultSpec, + ResultSpec, +) +from .tree import ColumnSet, DataCoordinateUploadKey, MaterializationKey, QueryTree + +PageKey: TypeAlias = uuid.UUID + + +class DataCoordinateResultPage(pydantic.BaseModel): + """A single page of results from a data coordinate query.""" + + spec: DataCoordinateResultSpec + next_key: PageKey | None + + # TODO: On DM-41114 this will become a custom container that normalizes out + # attached DimensionRecords and is Pydantic-friendly. Right now this model + # isn't actually serializable. + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) + rows: list[DataCoordinate] + + +@dataclasses.dataclass +class DimensionRecordResultPage: + """A single page of results from a dimension record query.""" + + spec: DimensionRecordResultSpec + next_key: PageKey | None + rows: Iterable[DimensionRecord] + + def as_table(self) -> DimensionRecordTable: + if isinstance(self.rows, DimensionRecordTable): + return self.rows + else: + return DimensionRecordTable(self.spec.element, self.rows) + + def as_set(self) -> DimensionRecordSet: + if isinstance(self.rows, DimensionRecordSet): + return self.rows + else: + return DimensionRecordSet(self.spec.element, self.rows) + + +class DatasetRefResultPage(pydantic.BaseModel): + """A single page of results from a dataset ref query.""" + + spec: DatasetRefResultSpec + next_key: PageKey | None + + # TODO: On DM-41115 this will become a custom container that normalizes out + # attached DimensionRecords and is Pydantic-friendly. Right now this model + # isn't actually serializable. 
+ model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) + rows: list[DatasetRef] + + +class GeneralResultPage(pydantic.BaseModel): + """A single page of results from a general query.""" + + spec: GeneralResultSpec + next_key: PageKey | None + + # Raw tabular data, with columns in the same order as spec.columns. + rows: list[tuple[Any, ...]] + + +ResultPage: TypeAlias = Annotated[ + Union[DataCoordinateResultPage, DimensionRecordResultPage, DatasetRefResultPage, GeneralResultPage], + pydantic.Field(discriminator=lambda x: x.spec.result_type), +] + + +class QueryDriver(AbstractContextManager[None]): + """Base class for the implementation object inside `Query2` objects + that is specialized for DirectButler vs. RemoteButler. + + Notes + ----- + Implementations should be context managers. This allows them to manage the + lifetime of server-side state, such as: + + - a SQL transaction, when necessary (DirectButler); + - SQL cursors for queries that were not fully iterated over (DirectButler); + - temporary database tables (DirectButler); + - result-page Parquet files that were never fetched (RemoteButler); + - uploaded Parquet files used to fill temporary database tables + (RemoteButler); + - cached content needed to construct query trees, like collection summaries + (potentially all Butlers). + + When possible, these sorts of things should be cleaned up earlier when they + are no longer needed, and the Butler server will still have to guard + against the context manager's ``__exit__`` signal never reaching it, but a + context manager will take care of these much more often than relying on + garbage collection and ``__del__`` would. + """ + + @property + @abstractmethod + def universe(self) -> DimensionUniverse: + """Object that defines all dimensions.""" + raise NotImplementedError() + + @overload + def execute(self, result_spec: DataCoordinateResultSpec, tree: QueryTree) -> DataCoordinateResultPage: ... + + @overload + def execute( + self, result_spec: DimensionRecordResultSpec, tree: QueryTree + ) -> DimensionRecordResultPage: ... + + @overload + def execute(self, result_spec: DatasetRefResultSpec, tree: QueryTree) -> DatasetRefResultPage: ... + + @overload + def execute(self, result_spec: GeneralResultSpec, tree: QueryTree) -> GeneralResultPage: ... + + @abstractmethod + def execute(self, result_spec: ResultSpec, tree: QueryTree) -> ResultPage: + """Execute a query and return the first result page. + + Parameters + ---------- + result_spec : `ResultSpec` + The kind of results the user wants from the query. This can affect + the actual query (i.e. SQL and Python postprocessing) that is run, + e.g. by changing what is in the SQL SELECT clause and even what + tables are joined in, but it never changes the number or order of + result rows. + tree : `QueryTree` + Query tree to evaluate. + + Returns + ------- + first_page : `ResultPage` + A page whose type corresponds to the type of ``result_spec``, with + at least the initial rows from the query. This should have an + empty ``rows`` attribute if the query returned no results, and a + ``next_key`` attribute that is not `None` if there were more + results than could be returned in a single page. + """ + raise NotImplementedError() + + @overload + def fetch_next_page( + self, result_spec: DataCoordinateResultSpec, key: PageKey + ) -> DataCoordinateResultPage: ... + + @overload + def fetch_next_page( + self, result_spec: DimensionRecordResultSpec, key: PageKey + ) -> DimensionRecordResultPage: ... 
+ + @overload + def fetch_next_page(self, result_spec: DatasetRefResultSpec, key: PageKey) -> DatasetRefResultPage: ... + + @overload + def fetch_next_page(self, result_spec: GeneralResultSpec, key: PageKey) -> GeneralResultPage: ... + + @abstractmethod + def fetch_next_page(self, result_spec: ResultSpec, key: PageKey) -> ResultPage: + """Fetch the next page of results from an already-executed query. + + Parameters + ---------- + result_spec : `ResultSpec` + The kind of results the user wants from the query. This must be + identical to the ``result_spec`` passed to `execute`, but + implementations are not *required* to check this. + key : `PageKey` + Key included in the previous page from this query. This key may + become unusable or even be reused after this call. + + Returns + ------- + next_page : `ResultPage` + The next page of query results. + """ + # We can put off dealing with pagination initially by just making an + # implementation of this method raise. + # + # In RemoteButler I expect this to work by having the call to execute + # continue to write Parquet files (or whatever) to some location until + # its cursor is exhausted, and then delete those files as they are + # fetched (or, failing that, when receiving a signal from + # ``__exit__``). + # + # In DirectButler I expect to have a dict[PageKey, Cursor], fetch a + # blocks of rows from it, and just reuse the page key for the next page + # until the cursor is exactly. + raise NotImplementedError() + + @abstractmethod + def materialize( + self, + tree: QueryTree, + dimensions: DimensionGroup, + datasets: frozenset[str], + ) -> MaterializationKey: + """Execute a query tree, saving results to temporary storage for use + in later queries. + + Parameters + ---------- + tree : `QueryTree` + Query tree to evaluate. + dimensions : `DimensionGroup` + Dimensions whose key columns should be preserved. + datasets : `frozenset` [ `str` ] + Names of dataset types whose ID columns may be materialized. It + is implementation-defined whether they actually are. + + Returns + ------- + key : `MaterializationKey` + Unique identifier for the result rows that allows them to be + referenced in a `QueryTree`. + """ + raise NotImplementedError() + + @abstractmethod + def upload_data_coordinates( + self, dimensions: DimensionGroup, rows: Iterable[tuple[DataIdValue, ...]] + ) -> DataCoordinateUploadKey: + """Upload a table of data coordinates for use in later queries. + + Parameters + ---------- + dimensions : `DimensionGroup` + Dimensions of the data coordinates. + rows : `Iterable` [ `tuple` ] + Tuples of data coordinate values, covering just the "required" + subset of ``dimensions``. + + Returns + ------- + key + Unique identifier for the upload that allows it to be referenced in + a `QueryTree`. + """ + raise NotImplementedError() + + @abstractmethod + def count( + self, + tree: QueryTree, + columns: ColumnSet, + find_first_dataset: str | None, + *, + exact: bool, + discard: bool, + ) -> int: + """Return the number of rows a query would return. + + Parameters + ---------- + tree : `QueryTree` + Query tree to evaluate. + columns : `ColumnSet` + Columns over which rows should have unique values before they are + counted. + find_first_dataset : `str` or `None` + Perform a search for this dataset type to reject all but the first + result in the collection search path for each data ID, before + counting the result rows. 
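# [Editorial sketch, not part of the patch] The pagination contract described
# above: execute() returns the first page, and fetch_next_page() is called
# with the previous page's next_key until it is exhausted.  `driver`, `spec`,
# `tree`, and `handle_row` are assumed names.
page = driver.execute(spec, tree)
while True:
    for row in page.rows:
        handle_row(row)
    if page.next_key is None:
        break
    page = driver.fetch_next_page(spec, page.next_key)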
+ exact : `bool`, optional + If `True`, run the full query and perform post-query filtering if + needed to account for that filtering in the count. If `False`, the + result may be an upper bound. + discard : `bool`, optional + If `True`, compute the exact count even if it would require running + the full query and then throwing away the result rows after + counting them. If `False`, this is an error, as the user would + usually be better off executing the query first to fetch its rows + into a new query (or passing ``exact=False``). Ignored if + ``exact=False``. + """ + raise NotImplementedError() + + @abstractmethod + def any(self, tree: QueryTree, *, execute: bool, exact: bool) -> bool: + """Test whether the query would return any rows. + + Parameters + ---------- + tree : `QueryTree` + Query tree to evaluate. + execute : `bool`, optional + If `True`, execute at least a ``LIMIT 1`` query if it cannot be + determined prior to execution that the query would return no rows. + exact : `bool`, optional + If `True`, run the full query and perform post-query filtering if + needed, until at least one result row is found. If `False`, the + returned result does not account for post-query filtering, and + hence may be `True` even when all result rows would be filtered + out. + + Returns + ------- + any : `bool` + `True` if the query would (or might, depending on arguments) yield + result rows. `False` if it definitely would not. + """ + raise NotImplementedError() + + @abstractmethod + def explain_no_results(self, tree: QueryTree, execute: bool) -> Iterable[str]: + """Return human-readable messages that may help explain why the query + yields no results. + + Parameters + ---------- + tree : `QueryTree` + Query tree to evaluate. + execute : `bool`, optional + If `True` (default) execute simplified versions (e.g. ``LIMIT 1``) + of aspects of the tree to more precisely determine where rows were + filtered out. + + Returns + ------- + messages : `~collections.abc.Iterable` [ `str` ] + String messages that describe reasons the query might not yield any + results. + """ + raise NotImplementedError() + + @abstractmethod + def get_default_collections(self) -> tuple[str, ...]: + """Return the default collection search path. + + Returns + ------- + collections : `tuple` [ `str`, ... ] + The default collection search path as a tuple of `str`. + + Raises + ------ + NoDefaultCollectionError + Raised if there are no default collections. + """ + raise NotImplementedError() + + @abstractmethod + def resolve_collection_path( + self, collections: Sequence[str] + ) -> list[tuple[CollectionRecord, CollectionSummary]]: + """Process a collection search path argument into a `list` of + collection records and summaries. + + Parameters + ---------- + collections : `~collections.abc.Sequence` [ `str` ] + The collection or collections to search. + + Returns + ------- + collection_info : `list` [ `tuple` [ `CollectionRecord`, \ + `CollectionSummary` ] ] + A `list` of pairs of `CollectionRecord` and `CollectionSummary` + that flattens out all `~CollectionType.CHAINED` collections into + their children while maintaining the same order and avoiding + duplicates. + + Raises + ------ + MissingCollectionError + Raised if any collection in ``collections`` does not exist. 
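# [Editorial sketch, not part of the patch] The count()/any() semantics
# documented above, with an assumed driver, tree, and column set.  With
# exact=False the result may be an upper bound; discard=True allows running
# the full query just to count its rows.
upper_bound = driver.count(tree, columns, find_first_dataset=None, exact=False, discard=False)
exact_count = driver.count(tree, columns, find_first_dataset=None, exact=True, discard=True)
if not driver.any(tree, execute=True, exact=True):
    print("\n".join(driver.explain_no_results(tree, execute=True)))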
+ + Notes + ----- + Implementations are generally expected to cache the collection records + and summaries they obtain (including the records for + `~CollectionType.CHAINED` collections that are not returned) in order + to optimize multiple calls with collections in common. + """ + raise NotImplementedError() + + @abstractmethod + def get_dataset_type(self, name: str) -> DatasetType: + """Return the dimensions for a dataset type. + + Parameters + ---------- + name : `str` + Name of the dataset type. + + Returns + ------- + dataset_type : `DatasetType` + Dimensions of the dataset type. + + Raises + ------ + MissingDatasetTypeError + Raised if the dataset type is not registered. + """ + raise NotImplementedError() + + def convert_dataset_search_args( + self, + dataset_type: str | DatasetType | Iterable[str | DatasetType] | EllipsisType, + collections: Sequence[str], + ) -> list[DatasetType]: + """Resolve dataset type and collections argument. + + Parameters + ---------- + dataset_type : `str`, `DatasetType`, \ + `~collections.abc.Iterable` [ `str` or `DatasetType` ], \ + or ``...`` + The dataset type or types to search for. Passing ``...`` searches + for all datasets in the given collections. + collections : `~collections.abc.Sequence` [ `str` ] + The collection or collections to search. + + Returns + ------- + resolved : `list` [ `DatasetType` ] + Matching dataset types. + """ + if dataset_type is ...: + dataset_type = set() + for _, summary in self.resolve_collection_path(collections): + dataset_type.update(summary.dataset_types.names) + result: list[DatasetType] = [] + for arg in ensure_iterable(dataset_type): + given_dataset_type: DatasetType | None + if isinstance(arg, str): + dataset_type_name = arg + given_dataset_type = None + elif isinstance(arg, DatasetType): + dataset_type_name = arg.name + given_dataset_type = arg + else: + raise DatasetTypeExpressionError(f"Unsupported object {arg} in dataset type expression.") + resolved_dataset_type: DatasetType = self.get_dataset_type(dataset_type_name) + if given_dataset_type is not None and not given_dataset_type.is_compatible_with( + resolved_dataset_type + ): + raise DatasetTypeError( + f"Given dataset type {given_dataset_type} is not compatible with the " + f"registered version {resolved_dataset_type}." + ) + result.append(resolved_dataset_type) + return result diff --git a/python/lsst/daf/butler/queries/expression_factory.py b/python/lsst/daf/butler/queries/expression_factory.py new file mode 100644 index 0000000000..32cec1e2c5 --- /dev/null +++ b/python/lsst/daf/butler/queries/expression_factory.py @@ -0,0 +1,428 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
+# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("ExpressionFactory", "ExpressionProxy", "ScalarExpressionProxy", "TimespanProxy", "RegionProxy") + +from collections.abc import Iterable +from typing import TYPE_CHECKING + +from lsst.sphgeom import Region + +from ..dimensions import DimensionElement, DimensionUniverse +from . import tree + +if TYPE_CHECKING: + from .._timespan import Timespan + from ._query import Query + +# This module uses ExpressionProxy and its subclasses to wrap ColumnExpression, +# but it just returns OrderExpression and Predicate objects directly, because +# we don't need to overload any operators or define any methods on those. + + +class ExpressionProxy: + """A wrapper for column expressions that overloads comparison operators + to return new expression proxies. + + Parameters + ---------- + expression : `tree.ColumnExpression` + Underlying expression object. + """ + + def __init__(self, expression: tree.ColumnExpression): + self._expression = expression + + def __repr__(self) -> str: + return str(self._expression) + + @property + def is_null(self) -> tree.Predicate: + """A boolean expression that tests whether this expression is NULL.""" + return tree.Predicate.is_null(self._expression) + + @staticmethod + def _make_expression(other: object) -> tree.ColumnExpression: + if isinstance(other, ExpressionProxy): + return other._expression + else: + return tree.make_column_literal(other) + + def _make_comparison(self, other: object, operator: tree.ComparisonOperator) -> tree.Predicate: + return tree.Predicate.compare(a=self._expression, b=self._make_expression(other), operator=operator) + + +class ScalarExpressionProxy(ExpressionProxy): + """An `ExpressionProxy` specialized for simple single-value columns.""" + + @property + def desc(self) -> tree.Reversed: + """An ordering expression that indicates that the sort on this + expression should be reversed. + """ + return tree.Reversed.model_construct(operand=self._expression) + + def __eq__(self, other: object) -> tree.Predicate: # type: ignore[override] + return self._make_comparison(other, "==") + + def __ne__(self, other: object) -> tree.Predicate: # type: ignore[override] + return self._make_comparison(other, "!=") + + def __lt__(self, other: object) -> tree.Predicate: # type: ignore[override] + return self._make_comparison(other, "<") + + def __le__(self, other: object) -> tree.Predicate: # type: ignore[override] + return self._make_comparison(other, "<=") + + def __gt__(self, other: object) -> tree.Predicate: # type: ignore[override] + return self._make_comparison(other, ">") + + def __ge__(self, other: object) -> tree.Predicate: # type: ignore[override] + return self._make_comparison(other, ">=") + + def in_range(self, start: int = 0, stop: int | None = None, step: int = 1) -> tree.Predicate: + """Return a boolean expression that tests whether this expression is + within a literal integer range. + + Parameters + ---------- + start : `int`, optional + Lower bound (inclusive) for the slice. + stop : `int` or `None`, optional + Upper bound (exclusive) for the slice, or `None` for no bound. + step : `int`, optional + Spacing between integers in the range. 
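# [Editorial sketch, not part of the patch] The comparison operators and
# in_range() helper defined on ScalarExpressionProxy above, used through an
# assumed ExpressionFactory `x` obtained from query.expression_factory.
query = query.where(
    x.detector >= 10,
    x.detector.purpose != "FOCUS",
    x.visit.in_range(903300, 903400),
)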
+ + Returns + ------- + predicate : `tree.Predicate` + Boolean expression object. + """ + return tree.Predicate.in_range(self._expression, start=start, stop=stop, step=step) + + def in_iterable(self, others: Iterable) -> tree.Predicate: + """Return a boolean expression that tests whether this expression + evaluates to a value that is in an iterable of other expressions. + + Parameters + ---------- + others : `collections.abc.Iterable` + An iterable of `ExpressionProxy` or values to be interpreted as + literals. + + Returns + ------- + predicate : `tree.Predicate` + Boolean expression object. + """ + return tree.Predicate.in_container(self._expression, [self._make_expression(item) for item in others]) + + def in_query(self, column: ExpressionProxy, query: Query) -> tree.Predicate: + """Return a boolean expression that test whether this expression + evaluates to a value that is in a single-column selection from another + query. + + Parameters + ---------- + column : `ExpressionProxy` + Proxy for the column to extract from ``query``. + query : `RelationQuery` + Query to select from. + + Returns + ------- + predicate : `tree.Predicate` + Boolean expression object. + """ + return tree.Predicate.in_query_tree(self._expression, column._expression, query._tree) + + +class TimespanProxy(ExpressionProxy): + """An `ExpressionProxy` specialized for timespan columns and literals.""" + + @property + def begin(self) -> ExpressionProxy: + """An expression representing the lower bound (inclusive).""" + return ExpressionProxy( + tree.UnaryExpression.model_construct(operand=self._expression, operator="begin_of") + ) + + @property + def end(self) -> ExpressionProxy: + """An expression representing the upper bound (exclusive).""" + return ExpressionProxy( + tree.UnaryExpression.model_construct(operand=self._expression, operator="end_of") + ) + + def overlaps(self, other: TimespanProxy | Timespan) -> tree.Predicate: + """Return a boolean expression representing an overlap test between + this timespan and another. + + Parameters + ---------- + other : `TimespanProxy` or `Timespan` + Expression or literal to compare to. + + Returns + ------- + predicate : `tree.Predicate` + Boolean expression object. + """ + return self._make_comparison(other, "overlaps") + + +class RegionProxy(ExpressionProxy): + """An `ExpressionProxy` specialized for region columns and literals.""" + + def overlaps(self, other: RegionProxy | Region) -> tree.Predicate: + """Return a boolean expression representing an overlap test between + this region and another. + + Parameters + ---------- + other : `RegionProxy` or `Region` + Expression or literal to compare to. + + Returns + ------- + predicate : `tree.Predicate` + Boolean expression object. + """ + return self._make_comparison(other, "overlaps") + + +class DimensionElementProxy: + """An expression-creation proxy for a dimension element logical table. + + Parameters + ---------- + element : `DimensionElement` + Element this object wraps. + + Notes + ----- + The (dynamic) attributes of this object are expression proxies for the + non-dimension fields of the element's records. 
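# [Editorial sketch, not part of the patch] The overlaps() comparisons
# provided by TimespanProxy and RegionProxy above; `x` is an assumed
# ExpressionFactory and `search_region` an assumed lsst.sphgeom.Region
# literal.
query = query.where(
    x.visit.timespan.overlaps(x.exposure.timespan),
    x.patch.region.overlaps(search_region),
)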
+ """ + + def __init__(self, element: DimensionElement): + self._element = element + + def __repr__(self) -> str: + return self._element.name + + def __getattr__(self, field: str) -> ExpressionProxy: + expression = tree.DimensionFieldReference(element=self._element.name, field=field) + match field: + case "region": + return RegionProxy(expression) + case "timespan": + return TimespanProxy(expression) + return ScalarExpressionProxy(expression) + + def __dir__(self) -> list[str]: + result = list(super().__dir__()) + result.extend(self._element.RecordClass.fields.facts.names) + if self._element.spatial: + result.append("region") + if self._element.temporal: + result.append("temporal") + return result + + +class DimensionProxy(ScalarExpressionProxy, DimensionElementProxy): + """An expression-creation proxy for a dimension logical table. + + Parameters + ---------- + dimension : `DimensionElement` + Element this object wraps. + + Notes + ----- + This class combines record-field attribute access from `DimensionElement` + proxy with direct interpretation as a dimension key column via + `ScalarExpressionProxy`. For example:: + + x = query.expression_factory + query.where( + x.detector.purpose == "SCIENCE", # field access + x.detector > 100, # direct usage as an expression + ) + """ + + def __init__(self, dimension: DimensionElement): + ScalarExpressionProxy.__init__(self, tree.DimensionKeyReference(dimension=dimension.name)) + DimensionElementProxy.__init__(self, dimension) + + +class DatasetTypeProxy: + """An expression-creation proxy for a dataset type's logical table. + + Parameters + ---------- + dataset_type : `str` + Dataset type name or wildcard. Wildcards are usable only when the + query contains exactly one dataset type or a wildcard. + + Notes + ----- + The attributes of this object are expression proxies for the fields + associated with datasets rather than their dimensions. + """ + + def __init__(self, dataset_type: str): + self._dataset_type = dataset_type + + def __repr__(self) -> str: + return self._dataset_type + + # Attributes are actually fixed, but we implement them with __getattr__ + # and __dir__ to avoid repeating the list. And someday they might expand + # to include Datastore record fields. + + def __getattr__(self, field: str) -> ExpressionProxy: + if field not in tree.DATASET_FIELD_NAMES: + raise AttributeError(field) + expression = tree.DatasetFieldReference(dataset_type=self._dataset_type, field=field) + if field == "timespan": + return TimespanProxy(expression) + return ScalarExpressionProxy(expression) + + def __dir__(self) -> list[str]: + result = list(super().__dir__()) + result.extend(tree.DATASET_FIELD_NAMES) + return result + + +class ExpressionFactory: + """A factory for creating column expressions that uses operator overloading + to form a mini-language. + + Instances of this class are usually obtained from + `RelationQuery.expression_factory`; see that property's documentation for + more information. + + Parameters + ---------- + universe : `DimensionUniverse` + Object that describes all dimensions. 
+ """ + + def __init__(self, universe: DimensionUniverse): + self._universe = universe + + def __getattr__(self, name: str) -> DimensionElementProxy: + element = self._universe.elements[name] + if element in self._universe.dimensions: + return DimensionProxy(element) + return DimensionElementProxy(element) + + def __getitem__(self, name: str) -> DatasetTypeProxy: + return DatasetTypeProxy(name) + + def not_(self, operand: tree.Predicate) -> tree.Predicate: + """Apply a logical NOT operation to a boolean expression. + + Parameters + ---------- + operand : `tree.Predicate` + Expression to invetree. + + Returns + ------- + logical_not : `tree.Predicate` + A boolean expression that evaluates to the opposite of ``operand``. + """ + return operand.logical_not() + + def all(self, first: tree.Predicate, /, *args: tree.Predicate) -> tree.Predicate: + """Combine a sequence of boolean expressions with logical AND. + + Parameters + ---------- + first : `tree.Predicate` + First operand (required). + *args + Additional operands. + + Returns + ------- + logical_and : `tree.Predicate` + A boolean expression that evaluates to `True` only if all operands + evaluate to `True. + """ + return first.logical_and(*args) + + def any(self, first: tree.Predicate, /, *args: tree.Predicate) -> tree.Predicate: + """Combine a sequence of boolean expressions with logical OR. + + Parameters + ---------- + first : `tree.Predicate` + First operand (required). + *args + Additional operands. + + Returns + ------- + logical_or : `tree.Predicate` + A boolean expression that evaluates to `True` if any operand + evaluates to `True. + """ + return first.logical_or(*args) + + @staticmethod + def literal(value: object) -> ExpressionProxy: + """Return an expression proxy that represents a literal value. + + Expression proxy objects obtained from this factory can generally be + compared directly to literals, so calling this method directly in user + code should rarely be necessary. + + Parameters + ---------- + value : `object` + Value to include as a literal in an expression tree. + + Returns + ------- + expression : `ExpressionProxy` + Expression wrapper for this literal. + """ + expression = tree.make_column_literal(value) + match expression.expression_type: + case "timespan": + return TimespanProxy(expression) + case "region": + return RegionProxy(expression) + case "bool": + raise NotImplementedError("Boolean literals are not supported.") + case _: + return ScalarExpressionProxy(expression) diff --git a/python/lsst/daf/butler/queries/overlaps.py b/python/lsst/daf/butler/queries/overlaps.py new file mode 100644 index 0000000000..7f92038cf6 --- /dev/null +++ b/python/lsst/daf/butler/queries/overlaps.py @@ -0,0 +1,466 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. 
If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("OverlapsVisitor",) + +import itertools +from collections.abc import Hashable, Iterable, Sequence, Set +from typing import Generic, Literal, TypeVar, cast + +from lsst.sphgeom import Region + +from .._topology import TopologicalFamily +from ..dimensions import DimensionElement, DimensionGroup +from . import tree +from .visitors import PredicateVisitFlags, SimplePredicateVisitor + +_T = TypeVar("_T", bound=Hashable) + + +class _NaiveDisjointSet(Generic[_T]): + """A very naive (but simple) implementation of a "disjoint set" data + structure for strings, with mostly O(N) performance. + + This class should not be used in any context where the number of elements + in the data structure is large. It intentionally implements a subset of + the interface of `scipy.cluster.DisJointSet` so that non-naive + implementation could be swapped in if desired. + + Parameters + ---------- + superset : `~collections.abc.Iterable` [ `str` ] + Elements to initialize the disjoint set, with each in its own + single-element subset. + """ + + def __init__(self, superset: Iterable[_T]): + self._subsets = [{k} for k in superset] + self._subsets.sort(key=len, reverse=True) + + def add(self, k: _T) -> bool: # numpydoc ignore=PR04 + """Add a new element as its own single-element subset unless it is + already present. + + Parameters + ---------- + k + Value to add. + + Returns + ------- + added : `bool`: + `True` if the value was actually added, `False` if it was already + present. + """ + for subset in self._subsets: + if k in subset: + return False + self._subsets.append({k}) + return True + + def merge(self, a: _T, b: _T) -> bool: # numpydoc ignore=PR04 + """Merge the subsets containing the given elements. + + Parameters + ---------- + a : + Element whose subset should be merged. + b : + Element whose subset should be merged. + + Returns + ------- + merged : `bool` + `True` if a merge occurred, `False` if the elements were already in + the same subset. 
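# [Editorial sketch, not part of the patch] Behavior of the _NaiveDisjointSet
# helper described above, using strings as elements.
s = _NaiveDisjointSet(["a", "b", "c"])
assert s.merge("a", "b") is True    # now {"a", "b"} and {"c"}
assert s.merge("a", "b") is False   # already in the same subset
assert s.n_subsets == 2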
+ """ + for i, subset in enumerate(self._subsets): + if a in subset: + break + else: + raise KeyError(f"Merge argument {a!r} not in disjoin set {self._subsets}.") + for j, subset in enumerate(self._subsets): + if b in subset: + break + else: + raise KeyError(f"Merge argument {b!r} not in disjoin set {self._subsets}.") + if i == j: + return False + i, j = sorted((i, j)) + self._subsets[i].update(self._subsets[j]) + del self._subsets[j] + self._subsets.sort(key=len, reverse=True) + return True + + def subsets(self) -> Sequence[Set[_T]]: + """Return the current subsets, ordered from largest to smallest.""" + return self._subsets + + @property + def n_subsets(self) -> int: + """The number of subsets.""" + return len(self._subsets) + + +class OverlapsVisitor(SimplePredicateVisitor): + """A helper class for dealing with spatial and temporal overlaps in a + query. + + Parameters + ---------- + dimensions : `DimensionGroup` + Dimensions of the query. + + Notes + ----- + This class includes logic for extracting explicit spatial and temporal + joins from a WHERE-clause predicate and computing automatic joins given the + dimensions of the query. It is designed to be subclassed by query driver + implementations that want to rewrite the predicate at the same time. + """ + + def __init__(self, dimensions: DimensionGroup): + self.dimensions = dimensions + self._spatial_connections = _NaiveDisjointSet(self.dimensions.spatial) + self._temporal_connections = _NaiveDisjointSet(self.dimensions.temporal) + + def run(self, predicate: tree.Predicate, join_operands: Iterable[DimensionGroup]) -> tree.Predicate: + """Process the given predicate to extract spatial and temporal + overlaps. + + Parameters + ---------- + predicate : `tree.Predicate` + Predicate to process. + join_operands : `~collections.abc.Iterable` [ `DimensionGroup` ] + The dimensions of logical tables being joined into this query; + these can included embedded spatial and temporal joins that can + make it unnecessary to add new ones. + + Returns + ------- + predicate : `tree.Predicate` + A possibly-modified predicate that should replace the original. + """ + result = predicate.visit(self) + if result is None: + result = predicate + for join_operand_dimensions in join_operands: + self.add_join_operand_connections(join_operand_dimensions) + for a, b in self.compute_automatic_spatial_joins(): + join_predicate = self.visit_spatial_join(a, b, PredicateVisitFlags.HAS_AND_SIBLINGS) + if join_predicate is None: + join_predicate = tree.Predicate.compare( + tree.DimensionFieldReference.model_construct(element=a, field="region"), + "overlaps", + tree.DimensionFieldReference.model_construct(element=b, field="region"), + ) + result = result.logical_and(join_predicate) + for a, b in self.compute_automatic_temporal_joins(): + join_predicate = self.visit_temporal_dimension_join(a, b, PredicateVisitFlags.HAS_AND_SIBLINGS) + if join_predicate is None: + join_predicate = tree.Predicate.compare( + tree.DimensionFieldReference.model_construct(element=a, field="timespan"), + "overlaps", + tree.DimensionFieldReference.model_construct(element=b, field="timespan"), + ) + result = result.logical_and(join_predicate) + return result + + def visit_comparison( + self, + a: tree.ColumnExpression, + operator: tree.ComparisonOperator, + b: tree.ColumnExpression, + flags: PredicateVisitFlags, + ) -> tree.Predicate | None: + # Docstring inherited. 
+        if operator == "overlaps":
+            if a.column_type == "region":
+                return self.visit_spatial_overlap(a, b, flags)
+            elif a.column_type == "timespan":
+                return self.visit_temporal_overlap(a, b, flags)
+            else:
+                raise AssertionError(f"Unexpected column type {a.column_type} for overlap.")
+        return None
+
+    def add_join_operand_connections(self, operand_dimensions: DimensionGroup) -> None:
+        """Add overlap connections implied by a table or subquery.
+
+        Parameters
+        ----------
+        operand_dimensions : `DimensionGroup`
+            Dimensions of the table or subquery.
+
+        Notes
+        -----
+        We assume each join operand to a `tree.Select` has its own complete
+        set of spatial and temporal joins that went into generating its rows.
+        That will naturally be true for relations originating from the butler
+        database, like dataset searches and materializations, and if it isn't
+        true for a data ID upload, that would represent an intentional
+        association between non-overlapping things that we'd want to respect
+        by *not* adding a more restrictive automatic join.
+        """
+        for a_family, b_family in itertools.pairwise(operand_dimensions.spatial):
+            self._spatial_connections.merge(a_family, b_family)
+        for a_family, b_family in itertools.pairwise(operand_dimensions.temporal):
+            self._temporal_connections.merge(a_family, b_family)
+
+    def compute_automatic_spatial_joins(self) -> list[tuple[DimensionElement, DimensionElement]]:
+        """Return pairs of dimension elements that should be spatially joined.
+
+        Returns
+        -------
+        joins : `list` [ `tuple` [ `DimensionElement`, `DimensionElement` ] ]
+            Automatic joins.
+
+        Notes
+        -----
+        This takes into account explicit joins extracted by `run` and
+        implicit joins added by `add_join_operand_connections`, and only
+        returns additional joins if there is an unambiguous way to spatially
+        connect any dimensions that are not already spatially connected.
+        Automatic joins are always the most fine-grained join between sets of
+        dimensions (i.e. ``visit_detector_region`` and ``patch`` instead of
+        ``visit`` and ``tract``), but explicitly adding a coarser join between
+        sets of elements will prevent the fine-grained join from being added.
+        """
+        return self._compute_automatic_joins("spatial", self._spatial_connections)
+
+    def compute_automatic_temporal_joins(self) -> list[tuple[DimensionElement, DimensionElement]]:
+        """Return pairs of dimension elements that should be temporally
+        joined.
+
+        Returns
+        -------
+        joins : `list` [ `tuple` [ `DimensionElement`, `DimensionElement` ] ]
+            Automatic joins.
+
+        Notes
+        -----
+        See `compute_automatic_spatial_joins` for information on how automatic
+        joins are determined.  Joins to dataset validity ranges are never
+        automatic.
+        """
+        return self._compute_automatic_joins("temporal", self._temporal_connections)
+
+    def _compute_automatic_joins(
+        self, kind: Literal["spatial", "temporal"], connections: _NaiveDisjointSet[TopologicalFamily]
+    ) -> list[tuple[DimensionElement, DimensionElement]]:
+        if connections.n_subsets == 1:
+            # All of the joins we need are already present.
+            return []
+        if connections.n_subsets > 2:
+            raise tree.InvalidQueryTreeError(
+                f"Too many disconnected sets of {kind} families for an automatic "
+                f"join: {connections.subsets()}.  Add explicit {kind} joins to avoid this error."
+            )
+        a_subset, b_subset = connections.subsets()
+        if len(a_subset) > 1 or len(b_subset) > 1:
+            raise tree.InvalidQueryTreeError(
+                f"A {kind} join is needed between {a_subset} and {b_subset}, but which join to "
+                f"add is ambiguous.  Add an explicit {kind} join to avoid this error."
+            )
+        # We have a pair of families that are not explicitly or implicitly
+        # connected to any other families; add an automatic join between their
+        # most fine-grained members.
+        (a_family,) = a_subset
+        (b_family,) = b_subset
+        return [
+            (
+                cast(DimensionElement, a_family.choose(self.dimensions.elements, self.dimensions.universe)),
+                cast(DimensionElement, b_family.choose(self.dimensions.elements, self.dimensions.universe)),
+            )
+        ]
+
+    def visit_spatial_overlap(
+        self, a: tree.ColumnExpression, b: tree.ColumnExpression, flags: PredicateVisitFlags
+    ) -> tree.Predicate | None:
+        """Dispatch a spatial overlap comparison predicate to handlers.
+
+        This method should rarely (if ever) need to be overridden.
+
+        Parameters
+        ----------
+        a : `tree.ColumnExpression`
+            First operand.
+        b : `tree.ColumnExpression`
+            Second operand.
+        flags : `tree.PredicateLeafFlags`
+            Information about where this overlap comparison appears in the
+            larger predicate tree.
+
+        Returns
+        -------
+        replaced : `tree.Predicate` or `None`
+            The predicate to be inserted instead in the processed tree, or
+            `None` if no substitution is needed.
+        """
+        match a, b:
+            case tree.DimensionFieldReference(element=a_element), tree.DimensionFieldReference(
+                element=b_element
+            ):
+                return self.visit_spatial_join(a_element, b_element, flags)
+            case tree.DimensionFieldReference(element=element), region_expression:
+                pass
+            case region_expression, tree.DimensionFieldReference(element=element):
+                pass
+            case _:
+                raise AssertionError(f"Unexpected arguments for spatial overlap: {a}, {b}.")
+        if (region := region_expression.get_literal_value()) is None:
+            raise AssertionError(f"Unexpected argument for spatial overlap: {region_expression}.")
+        return self.visit_spatial_constraint(element, region, flags)
+
+    def visit_temporal_overlap(
+        self, a: tree.ColumnExpression, b: tree.ColumnExpression, flags: PredicateVisitFlags
+    ) -> tree.Predicate | None:
+        """Dispatch a temporal overlap comparison predicate to handlers.
+
+        This method should rarely (if ever) need to be overridden.
+
+        Parameters
+        ----------
+        a : `tree.ColumnExpression`
+            First operand.
+        b : `tree.ColumnExpression`
+            Second operand.
+        flags : `tree.PredicateLeafFlags`
+            Information about where this overlap comparison appears in the
+            larger predicate tree.
+
+        Returns
+        -------
+        replaced : `tree.Predicate` or `None`
+            The predicate to be inserted instead in the processed tree, or
+            `None` if no substitution is needed.
+        """
+        match a, b:
+            case tree.DimensionFieldReference(element=a_element), tree.DimensionFieldReference(
+                element=b_element
+            ):
+                return self.visit_temporal_dimension_join(a_element, b_element, flags)
+            case _:
+                # We don't bother differentiating any other kind of temporal
+                # comparison, because in all foreseeable database schemas we
+                # wouldn't have to do anything special with them, since they
+                # don't participate in automatic join calculations and they
+                # should be straightforwardly convertible to SQL.
+                return None
+
+    def visit_spatial_join(
+        self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags
+    ) -> tree.Predicate | None:
+        """Handle a spatial overlap comparison between two dimension elements.
+
+        The default implementation updates the set of known spatial connections
+        (for use by `compute_automatic_spatial_joins`) and returns `None`.
+
+        Parameters
+        ----------
+        a : `DimensionElement`
+            One element in the join.
+        b : `DimensionElement`
+            The other element in the join.
+ flags : `tree.PredicateLeafFlags` + Information about where this overlap comparison appears in the + larger predicate tree. + + Returns + ------- + replaced : `tree.Predicate` or `None` + The predicate to be inserted instead in the processed tree, or + `None` if no substitution is needed. + """ + if a.spatial == b.spatial: + raise tree.InvalidQueryTreeError(f"Spatial join between {a} and {b} is not necessary.") + self._spatial_connections.merge( + cast(TopologicalFamily, a.spatial), cast(TopologicalFamily, b.spatial) + ) + return None + + def visit_spatial_constraint( + self, + element: DimensionElement, + region: Region, + flags: PredicateVisitFlags, + ) -> tree.Predicate | None: + """Handle a spatial overlap comparison between a dimension element and + a literal region. + + The default implementation just returns `None`. + + Parameters + ---------- + element : `DimensionElement` + The dimension element in the comparison. + region : `lsst.sphgeom.Region` + The literal region in the comparison. + flags : `tree.PredicateLeafFlags` + Information about where this overlap comparison appears in the + larger predicate tree. + + Returns + ------- + replaced : `tree.Predicate` or `None` + The predicate to be inserted instead in the processed tree, or + `None` if no substitution is needed. + """ + return None + + def visit_temporal_dimension_join( + self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags + ) -> tree.Predicate | None: + """Handle a temporal overlap comparison between two dimension elements. + + The default implementation updates the set of known temporal + connections (for use by `compute_automatic_temporal_joins`) and returns + `None`. + + Parameters + ---------- + a : `DimensionElement` + One element in the join. + b : `DimensionElement` + The other element in the join. + flags : `tree.PredicateLeafFlags` + Information about where this overlap comparison appears in the + larger predicate tree. + + Returns + ------- + replaced : `tree.Predicate` or `None` + The predicate to be inserted instead in the processed tree, or + `None` if no substitution is needed. + """ + if a.temporal == b.temporal: + raise tree.InvalidQueryTreeError(f"Temporal join between {a} and {b} is not necessary.") + self._temporal_connections.merge( + cast(TopologicalFamily, a.temporal), cast(TopologicalFamily, b.temporal) + ) + return None diff --git a/python/lsst/daf/butler/queries/result_specs.py b/python/lsst/daf/butler/queries/result_specs.py new file mode 100644 index 0000000000..df5e07423e --- /dev/null +++ b/python/lsst/daf/butler/queries/result_specs.py @@ -0,0 +1,232 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
+# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ( + "ResultSpecBase", + "DataCoordinateResultSpec", + "DimensionRecordResultSpec", + "DatasetRefResultSpec", +) + +from collections.abc import Mapping +from typing import Annotated, Literal, TypeAlias, Union, cast + +import pydantic + +from ..dimensions import DimensionElement, DimensionGroup +from .tree import ColumnSet, DatasetFieldName, InvalidQueryTreeError, OrderExpression, QueryTree + + +class ResultSpecBase(pydantic.BaseModel): + """Base class for all query-result specification objects.""" + + order_by: tuple[OrderExpression, ...] = () + """Expressions to sort the rows by.""" + + offset: int = 0 + """Index of the first row to return.""" + + limit: int | None = None + """Maximum number of rows to return, or `None` for no bound.""" + + def validate_tree(self, tree: QueryTree) -> None: + """Check that this result object is consistent with a query tree. + + Parameters + ---------- + tree : `QueryTree` + Query tree that defines the joins and row-filtering that these + results will come from. + """ + spec = cast(ResultSpec, self) + if not spec.dimensions <= tree.dimensions: + raise InvalidQueryTreeError( + f"Query result specification has dimensions {spec.dimensions} that are not a subset of the " + f"query's dimensions {tree.dimensions}." + ) + result_columns = spec.get_result_columns() + assert result_columns.dimensions == spec.dimensions, "enforced by ResultSpec implementations" + for dataset_type in result_columns.dataset_fields: + if dataset_type not in tree.datasets: + raise InvalidQueryTreeError(f"Dataset {dataset_type!r} is not available from this query.") + if not (tree.datasets[dataset_type].dimensions <= spec.dimensions): + raise InvalidQueryTreeError( + f"Result dataset type {dataset_type!r} has dimensions " + f"{tree.datasets[dataset_type].dimensions} that are not a subset of the result " + f"dimensions {spec.dimensions}." + ) + order_by_columns = ColumnSet(spec.dimensions) + for term in spec.order_by: + term.gather_required_columns(order_by_columns) + if not (order_by_columns.dimensions <= spec.dimensions): + raise InvalidQueryTreeError( + "Order-by expression may not reference columns that are not in the result dimensions." + ) + for dataset_type in order_by_columns.dataset_fields.keys(): + if dataset_type not in tree.datasets: + raise InvalidQueryTreeError( + f"Dataset type {dataset_type!r} in order-by expression is not part of the query." + ) + if not (tree.datasets[dataset_type].dimensions <= spec.dimensions): + raise InvalidQueryTreeError( + f"Dataset type {dataset_type!r} in order-by expression has dimensions " + f"{tree.datasets[dataset_type].dimensions} that are not a subset of the result " + f"dimensions {spec.dimensions}." + ) + + @property + def find_first_dataset(self) -> str | None: + return None + + +class DataCoordinateResultSpec(ResultSpecBase): + """Specification for a query that yields `DataCoordinate` objects.""" + + result_type: Literal["data_coordinate"] = "data_coordinate" + dimensions: DimensionGroup + include_dimension_records: bool + + def get_result_columns(self) -> ColumnSet: + """Return the columns included in the actual result rows. 
+ + This does not necessarily include all columns required by the + `order_by` terms that are also a part of this spec. + """ + result = ColumnSet(self.dimensions) + if self.include_dimension_records: + for element_name in self.dimensions.elements: + element = self.dimensions.universe[element_name] + if not element.is_cached: + result.dimension_fields[element_name].update(element.schema.remainder.names) + return result + + +class DimensionRecordResultSpec(ResultSpecBase): + """Specification for a query that yields `DimensionRecord` objects.""" + + result_type: Literal["dimension_record"] = "dimension_record" + element: DimensionElement + + @property + def dimensions(self) -> DimensionGroup: + return self.element.minimal_group + + def get_result_columns(self) -> ColumnSet: + """Return the columns included in the actual result rows. + + This does not necessarily include all columns required by the + `order_by` terms that are also a part of this spec. + """ + result = ColumnSet(self.element.minimal_group) + result.dimension_fields[self.element.name].update(self.element.schema.remainder.names) + return result + + +class DatasetRefResultSpec(ResultSpecBase): + """Specification for a query that yields `DatasetRef` objects.""" + + result_type: Literal["dataset_ref"] = "dataset_ref" + dataset_type_name: str + dimensions: DimensionGroup + storage_class_name: str + include_dimension_records: bool + find_first: bool + + @property + def find_first_dataset(self) -> str | None: + return self.dataset_type_name if self.find_first else None + + def get_result_columns(self) -> ColumnSet: + """Return the columns included in the actual result rows. + + This does not necessarily include all columns required by the + `order_by` terms that are also a part of this spec. + """ + result = ColumnSet(self.dimensions) + result.dataset_fields[self.dataset_type_name].update({"dataset_id", "run"}) + if self.include_dimension_records: + for element_name in self.dimensions.elements: + element = self.dimensions.universe[element_name] + if not element.is_cached: + result.dimension_fields[element_name].update(element.schema.remainder.names) + return result + + +class GeneralResultSpec(ResultSpecBase): + """Specification for a query that yields a table with + an explicit list of columns. + """ + + result_type: Literal["general"] = "general" + dimensions: DimensionGroup + dimension_fields: Mapping[str, set[str]] + dataset_fields: Mapping[str, set[DatasetFieldName]] + find_first: bool + + @property + def find_first_dataset(self) -> str | None: + if self.find_first: + (dataset_type,) = self.dataset_fields.keys() + return dataset_type + return None + + def get_result_columns(self) -> ColumnSet: + """Return the columns included in the actual result rows. + + This does not necessarily include all columns required by the + `order_by` terms that are also a part of this spec. 
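# Editorial sketch: building a result spec directly and asking which columns its
# rows will carry.  `universe` (a DimensionUniverse) and the default "detector"
# dimension are assumptions for illustration only.
from lsst.daf.butler.queries.result_specs import DataCoordinateResultSpec

spec = DataCoordinateResultSpec(
    dimensions=universe["detector"].minimal_group,
    include_dimension_records=False,
)
# With include_dimension_records=False only dimension key columns remain.
for logical_table, field in spec.get_result_columns():
    print(logical_table, field)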
+ """ + result = ColumnSet(self.dimensions) + for element_name, fields_for_element in self.dimension_fields.items(): + result.dimension_fields[element_name].update(fields_for_element) + for dataset_type, fields_for_dataset in self.dataset_fields.items(): + result.dataset_fields[dataset_type].update(fields_for_dataset) + return result + + @pydantic.model_validator(mode="after") + def _validate(self) -> GeneralResultSpec: + if self.find_first and len(self.dataset_fields) != 1: + raise InvalidQueryTreeError("find_first=True requires exactly one result dataset type.") + for element_name, fields_for_element in self.dimension_fields.items(): + if element_name not in self.dimensions.elements: + raise InvalidQueryTreeError(f"Dimension element {element_name} is not in {self.dimensions}.") + if not fields_for_element: + raise InvalidQueryTreeError( + f"Empty dimension element field set for {element_name!r} is not permitted." + ) + for dataset_type, fields_for_dataset in self.dataset_fields.items(): + if not fields_for_dataset: + raise InvalidQueryTreeError(f"Empty dataset field set for {dataset_type!r} is not permitted.") + return self + + +ResultSpec: TypeAlias = Annotated[ + Union[DataCoordinateResultSpec, DimensionRecordResultSpec, DatasetRefResultSpec, GeneralResultSpec], + pydantic.Field(discriminator="result_type"), +] diff --git a/python/lsst/daf/butler/queries/tree/__init__.py b/python/lsst/daf/butler/queries/tree/__init__.py new file mode 100644 index 0000000000..e320695f62 --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/__init__.py @@ -0,0 +1,40 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from ._base import * +from ._column_expression import * +from ._column_literal import * +from ._column_reference import * +from ._column_set import * +from ._predicate import * +from ._predicate import LogicalNot +from ._query_tree import * + +LogicalNot.model_rebuild() +del LogicalNot + +Predicate.model_rebuild() diff --git a/python/lsst/daf/butler/queries/tree/_base.py b/python/lsst/daf/butler/queries/tree/_base.py new file mode 100644 index 0000000000..e07de32b25 --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/_base.py @@ -0,0 +1,252 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. 
+# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ( + "QueryTreeBase", + "ColumnExpressionBase", + "DatasetFieldName", + "InvalidQueryTreeError", + "DATASET_FIELD_NAMES", +) + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeAlias, TypeVar, cast, get_args + +import pydantic + +if TYPE_CHECKING: + from ...column_spec import ColumnType + from ..visitors import ColumnExpressionVisitor, PredicateVisitFlags, PredicateVisitor + from ._column_literal import ColumnLiteral + from ._column_set import ColumnSet + from ._predicate import PredicateLeaf + + +DatasetFieldName: TypeAlias = Literal["dataset_id", "ingest_date", "run", "collection", "timespan"] + +DATASET_FIELD_NAMES: tuple[DatasetFieldName, ...] = tuple(get_args(DatasetFieldName)) + +_T = TypeVar("_T") +_L = TypeVar("_L") +_A = TypeVar("_A") +_O = TypeVar("_O") + + +class InvalidQueryTreeError(RuntimeError): + """Exception raised when a query tree is or would not be valid.""" + + +class QueryTreeBase(pydantic.BaseModel): + """Base class for all non-primitive types in a query tree.""" + + model_config = pydantic.ConfigDict(frozen=True, extra="forbid", strict=True) + + +class ColumnExpressionBase(QueryTreeBase, ABC): + """Base class for objects that represent non-boolean column expressions in + a query tree. + + Notes + ----- + This is a closed hierarchy whose concrete, `~typing.final` derived classes + are members of the `ColumnExpression` union. That union should generally + be used in type annotations rather than the technically-open base class. + """ + + expression_type: str + + is_literal: ClassVar[bool] = False + """Whether this expression wraps a literal Python value.""" + + @property + @abstractmethod + def precedence(self) -> int: + """Operator precedence for this operation. + + Lower values bind more tightly, so parentheses are needed when printing + an expression where an operand has a higher value than the expression + itself. + """ + raise NotImplementedError() + + @property + @abstractmethod + def column_type(self) -> ColumnType: + """A string enumeration value representing the type of the column + expression. + """ + raise NotImplementedError() + + def get_literal_value(self) -> Any | None: + """Return the literal value wrapped by this expression, or `None` if + it is not a literal. 
+ """ + return None + + @abstractmethod + def gather_required_columns(self, columns: ColumnSet) -> None: + """Add any columns required to evaluate this expression to the + given column set. + + Parameters + ---------- + columns : `ColumnSet` + Set of columns to modify in place. + """ + raise NotImplementedError() + + @abstractmethod + def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T: + """Invoke the visitor interface. + + Parameters + ---------- + visitor : `ColumnExpressionVisitor` + Visitor to invoke a method on. + + Returns + ------- + result : `object` + Forwarded result from the visitor. + """ + raise NotImplementedError() + + +class ColumnLiteralBase(ColumnExpressionBase): + """Base class for objects that represent literal values as column + expressions in a query tree. + + Notes + ----- + This is a closed hierarchy whose concrete, `~typing.final` derived classes + are members of the `ColumnLiteral` union. That union should generally be + used in type annotations rather than the technically-open base class. The + concrete members of that union are only semi-public; they appear in the + serialized form of a column expression tree, but should only be constructed + via the `make_column_literal` factory function. All concrete members of + the union are also guaranteed to have a read-only ``value`` attribute + holding the wrapped literal, but it is unspecified whether that is a + regular attribute or a `property`. + """ + + is_literal: ClassVar[bool] = True + """Whether this expression wraps a literal Python value.""" + + @property + def precedence(self) -> int: + # Docstring inherited. + return 0 + + def get_literal_value(self) -> Any: + # Docstring inherited. + return cast("ColumnLiteral", self).value + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + pass + + @property + def column_type(self) -> ColumnType: + # Docstring inherited. + return cast(ColumnType, self.expression_type) + + def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T: + # Docstring inherited + return visitor.visit_literal(cast("ColumnLiteral", self)) + + +class PredicateLeafBase(QueryTreeBase, ABC): + """Base class for leaf nodes of the `Predicate` tree. + + Notes + ----- + This is a closed hierarchy whose concrete, `~typing.final` derived classes + are members of the `PredicateLeaf` union. That union should generally be + used in type annotations rather than the technically-open base class. The + concrete members of that union are only semi-public; they appear in the + serialized form of a `Predicate`, but should only be constructed + via various `Predicate` factory methods. + """ + + @property + @abstractmethod + def precedence(self) -> int: + """Operator precedence for this operation. + + Lower values bind more tightly, so parentheses are needed when printing + an expression where an operand has a higher value than the expression + itself. + """ + raise NotImplementedError() + + @property + def column_type(self) -> Literal["bool"]: + """A string enumeration value representing the type of the column + expression. + """ + return "bool" + + @abstractmethod + def gather_required_columns(self, columns: ColumnSet) -> None: + """Add any columns required to evaluate this predicate leaf to the + given column set. + + Parameters + ---------- + columns : `ColumnSet` + Set of columns to modify in place. 
+ """ + raise NotImplementedError() + + def invert(self) -> PredicateLeaf: + """Return a new leaf that is the logical not of this one.""" + from ._predicate import LogicalNot, LogicalNotOperand + + # This implementation works for every subclass other than LogicalNot + # itself, which overrides this method. + return LogicalNot.model_construct(operand=cast(LogicalNotOperand, self)) + + @abstractmethod + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + """Invoke the visitor interface. + + Parameters + ---------- + visitor : `PredicateVisitor` + Visitor to invoke a method on. + flags : `PredicateVisitFlags` + Flags that provide information about where this leaf appears in the + larger predicate tree. + + Returns + ------- + result : `object` + Forwarded result from the visitor. + """ + raise NotImplementedError() diff --git a/python/lsst/daf/butler/queries/tree/_column_expression.py b/python/lsst/daf/butler/queries/tree/_column_expression.py new file mode 100644 index 0000000000..d792245aec --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/_column_expression.py @@ -0,0 +1,257 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
+
+from __future__ import annotations
+
+__all__ = (
+    "ColumnExpression",
+    "OrderExpression",
+    "UnaryExpression",
+    "BinaryExpression",
+    "Reversed",
+    "UnaryOperator",
+    "BinaryOperator",
+)
+
+from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias, TypeVar, Union, final
+
+import pydantic
+
+from ...column_spec import ColumnType
+from ._base import ColumnExpressionBase, InvalidQueryTreeError
+from ._column_literal import ColumnLiteral
+from ._column_reference import _ColumnReference
+from ._column_set import ColumnSet
+
+if TYPE_CHECKING:
+    from ..visitors import ColumnExpressionVisitor
+
+
+_T = TypeVar("_T")
+
+
+UnaryOperator: TypeAlias = Literal["-", "begin_of", "end_of"]
+BinaryOperator: TypeAlias = Literal["+", "-", "*", "/", "%"]
+
+
+@final
+class UnaryExpression(ColumnExpressionBase):
+    """A unary operation on a column expression that returns a non-bool."""
+
+    expression_type: Literal["unary"] = "unary"
+
+    operand: ColumnExpression
+    """Expression this one operates on."""
+
+    operator: UnaryOperator
+    """Operator this expression applies."""
+
+    def gather_required_columns(self, columns: ColumnSet) -> None:
+        # Docstring inherited.
+        self.operand.gather_required_columns(columns)
+
+    @property
+    def precedence(self) -> int:
+        # Docstring inherited.
+        return 1
+
+    @property
+    def column_type(self) -> ColumnType:
+        # Docstring inherited.
+        match self.operator:
+            case "-":
+                return self.operand.column_type
+            case "begin_of" | "end_of":
+                return "datetime"
+        raise AssertionError(f"Invalid unary expression operator {self.operator}.")
+
+    def __str__(self) -> str:
+        s = str(self.operand)
+        if self.operand.precedence >= self.precedence:
+            s = f"({s})"
+        match self.operator:
+            case "-":
+                return f"-{s}"
+            case "begin_of":
+                return f"{s}.begin"
+            case "end_of":
+                return f"{s}.end"
+
+    @pydantic.model_validator(mode="after")
+    def _validate_types(self) -> UnaryExpression:
+        match (self.operator, self.operand.column_type):
+            case ("-", "int" | "float"):
+                pass
+            case ("begin_of" | "end_of", "timespan"):
+                pass
+            case _:
+                raise InvalidQueryTreeError(
+                    f"Invalid column type {self.operand.column_type} for operator {self.operator!r}."
+                )
+        return self
+
+    def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T:
+        # Docstring inherited.
+        return visitor.visit_unary_expression(self)
+
+
+@final
+class BinaryExpression(ColumnExpressionBase):
+    """A binary operation on column expressions that returns a non-bool."""
+
+    expression_type: Literal["binary"] = "binary"
+
+    a: ColumnExpression
+    """Left-hand side expression this one operates on."""
+
+    b: ColumnExpression
+    """Right-hand side expression this one operates on."""
+
+    operator: BinaryOperator
+    """Operator this expression applies.
+
+    Integer '/' and '%' are defined as in SQL, not Python (though the
+    definitions are the same for positive arguments).
+    """
+
+    def gather_required_columns(self, columns: ColumnSet) -> None:
+        # Docstring inherited.
+        self.a.gather_required_columns(columns)
+        self.b.gather_required_columns(columns)
+
+    @property
+    def precedence(self) -> int:
+        # Docstring inherited.
+        match self.operator:
+            case "*" | "/" | "%":
+                return 2
+            case "+" | "-":
+                return 3
+
+    @property
+    def column_type(self) -> ColumnType:
+        # Docstring inherited.
+ return self.a.column_type + + def __str__(self) -> str: + a = str(self.a) + b = str(self.b) + match self.operator: + case "*" | "+": + if self.a.precedence > self.precedence: + a = f"({a})" + if self.b.precedence > self.precedence: + b = f"({b})" + case _: + if self.a.precedence >= self.precedence: + a = f"({a})" + if self.b.precedence >= self.precedence: + b = f"({b})" + return f"({a} {self.operator} {b})" + + @pydantic.model_validator(mode="after") + def _validate_types(self) -> BinaryExpression: + if self.a.column_type != self.b.column_type: + raise InvalidQueryTreeError( + f"Column types for operator {self.operator} do not agree " + f"({self.a.column_type}, {self.b.column_type})." + ) + match (self.operator, self.a.column_type): + case ("+" | "-" | "*" | "/", "int" | "float"): + pass + case ("%", "int"): + pass + case _: + raise InvalidQueryTreeError( + f"Invalid column type {self.a.column_type} for operator {self.operator!r}." + ) + return self + + def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T: + # Docstring inherited. + return visitor.visit_binary_expression(self) + + +# Union without Pydantic annotation for the discriminator, for use in nesting +# in other unions that will add that annotation. It's not clear whether it +# would work to just nest the annotated ones, but it seems safest not to rely +# on undocumented behavior. +_ColumnExpression: TypeAlias = Union[ + ColumnLiteral, + _ColumnReference, + UnaryExpression, + BinaryExpression, +] + + +ColumnExpression: TypeAlias = Annotated[_ColumnExpression, pydantic.Field(discriminator="expression_type")] + + +@final +class Reversed(ColumnExpressionBase): + """A tag wrapper for `AbstractExpression` that indicate sorting in + reverse order. + """ + + expression_type: Literal["reversed"] = "reversed" + + operand: ColumnExpression + """Expression to sort on in reverse.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + self.operand.gather_required_columns(columns) + + @property + def precedence(self) -> int: + # Docstring inherited. + return self.operand.precedence + + @property + def column_type(self) -> ColumnType: + # Docstring inherited. + return self.operand.column_type + + def __str__(self) -> str: + return f"{self.operand} DESC" + + def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T: + # Docstring inherited. + return visitor.visit_reversed(self) + + +def _validate_order_expression(expression: _ColumnExpression | Reversed) -> _ColumnExpression | Reversed: + if expression.column_type not in ("int", "string", "float", "datetime"): + raise InvalidQueryTreeError(f"Column type {expression.column_type} of {expression} is not ordered.") + return expression + + +OrderExpression: TypeAlias = Annotated[ + Union[_ColumnExpression, Reversed], + pydantic.Field(discriminator="expression_type"), + pydantic.AfterValidator(_validate_order_expression), +] diff --git a/python/lsst/daf/butler/queries/tree/_column_literal.py b/python/lsst/daf/butler/queries/tree/_column_literal.py new file mode 100644 index 0000000000..17ef812b14 --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/_column_literal.py @@ -0,0 +1,372 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. 
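# Editorial sketch: expression nodes validate their operand types on construction,
# so a mismatched BinaryExpression is rejected up front.
from lsst.daf.butler.queries.tree import BinaryExpression, InvalidQueryTreeError, make_column_literal

expr = BinaryExpression(a=make_column_literal(2), b=make_column_literal(3), operator="+")
assert expr.column_type == "int"
print(expr)  # "(2 + 3)"

try:
    BinaryExpression(a=make_column_literal(2), b=make_column_literal("two"), operator="+")
except InvalidQueryTreeError:
    pass  # int + string is rejected by the _validate_types model validator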
+# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ( + "ColumnLiteral", + "make_column_literal", +) + +import uuid +import warnings +from base64 import b64decode, b64encode +from functools import cached_property +from typing import Literal, TypeAlias, Union, final + +import astropy.time +import erfa +from lsst.sphgeom import Region + +from ..._timespan import Timespan +from ...time_utils import TimeConverter +from ._base import ColumnLiteralBase + +LiteralValue: TypeAlias = Union[int, str, float, bytes, uuid.UUID, astropy.time.Time, Timespan, Region] + + +@final +class IntColumnLiteral(ColumnLiteralBase): + """A literal `int` value in a column expression.""" + + expression_type: Literal["int"] = "int" + + value: int + """The wrapped value after base64 encoding.""" + + @classmethod + def from_value(cls, value: int) -> IntColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `int` + Value to wrap. + + Returns + ------- + expression : `IntColumnLiteral` + Literal expression object. + """ + return cls.model_construct(value=value) + + def __str__(self) -> str: + return repr(self.value) + + +@final +class StringColumnLiteral(ColumnLiteralBase): + """A literal `str` value in a column expression.""" + + expression_type: Literal["string"] = "string" + + value: str + """The wrapped value after base64 encoding.""" + + @classmethod + def from_value(cls, value: str) -> StringColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `str` + Value to wrap. + + Returns + ------- + expression : `StrColumnLiteral` + Literal expression object. + """ + return cls.model_construct(value=value) + + def __str__(self) -> str: + return repr(self.value) + + +@final +class FloatColumnLiteral(ColumnLiteralBase): + """A literal `float` value in a column expression.""" + + expression_type: Literal["float"] = "float" + + value: float + """The wrapped value after base64 encoding.""" + + @classmethod + def from_value(cls, value: float) -> FloatColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `float` + Value to wrap. + + Returns + ------- + expression : `FloatColumnLiteral` + Literal expression object. + """ + return cls.model_construct(value=value) + + def __str__(self) -> str: + return repr(self.value) + + +@final +class HashColumnLiteral(ColumnLiteralBase): + """A literal `bytes` value representing a hash in a column expression. + + The original value is base64-encoded when serialized and decoded on first + use. 
+ """ + + expression_type: Literal["hash"] = "hash" + + encoded: bytes + """The wrapped value after base64 encoding.""" + + @cached_property + def value(self) -> bytes: + """The wrapped value.""" + return b64decode(self.encoded) + + @classmethod + def from_value(cls, value: bytes) -> HashColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `bytes` + Value to wrap. + + Returns + ------- + expression : `HashColumnLiteral` + Literal expression object. + """ + return cls.model_construct(encoded=b64encode(value)) + + def __str__(self) -> str: + return "(bytes)" + + +@final +class UUIDColumnLiteral(ColumnLiteralBase): + """A literal `uuid.UUID` value in a column expression.""" + + expression_type: Literal["uuid"] = "uuid" + + value: uuid.UUID + + @classmethod + def from_value(cls, value: uuid.UUID) -> UUIDColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `uuid.UUID` + Value to wrap. + + Returns + ------- + expression : `UUIDColumnLiteral` + Literal expression object. + """ + return cls.model_construct(value=value) + + def __str__(self) -> str: + return str(self.value) + + +@final +class DateTimeColumnLiteral(ColumnLiteralBase): + """A literal `astropy.time.Time` value in a column expression. + + The time is converted into TAI nanoseconds since 1970-01-01 when serialized + and restored from that on first use. + """ + + expression_type: Literal["datetime"] = "datetime" + + nsec: int + """TAI nanoseconds since 1970-01-01.""" + + @cached_property + def value(self) -> astropy.time.Time: + """The wrapped value.""" + return TimeConverter().nsec_to_astropy(self.nsec) + + @classmethod + def from_value(cls, value: astropy.time.Time) -> DateTimeColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `astropy.time.Time` + Value to wrap. + + Returns + ------- + expression : `DateTimeColumnLiteral` + Literal expression object. + """ + return cls.model_construct(nsec=TimeConverter().astropy_to_nsec(value)) + + def __str__(self) -> str: + # Trap dubious year warnings in case we have timespans from + # simulated data in the future + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=erfa.ErfaWarning) + return self.value.tai.strftime("%Y-%m-%dT%H:%M:%S") + + +@final +class TimespanColumnLiteral(ColumnLiteralBase): + """A literal `Timespan` value in a column expression. + + The timespan bounds are converted into TAI nanoseconds since 1970-01-01 + when serialized and the timespan is restored from that on first use. + """ + + expression_type: Literal["timespan"] = "timespan" + + begin_nsec: int + """TAI nanoseconds since 1970-01-01 for the lower bound of the timespan + (inclusive). + """ + + end_nsec: int + """TAI nanoseconds since 1970-01-01 for the upper bound of the timespan + (exclusive). + """ + + @cached_property + def value(self) -> astropy.time.Time: + """The wrapped value.""" + return Timespan(None, None, _nsec=(self.begin_nsec, self.end_nsec)) + + @classmethod + def from_value(cls, value: Timespan) -> TimespanColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `..Timespan` + Value to wrap. + + Returns + ------- + expression : `TimespanColumnLiteral` + Literal expression object. 
+ """ + return cls.model_construct(begin_nsec=value._nsec[0], end_nsec=value._nsec[1]) + + def __str__(self) -> str: + return str(self.value) + + +@final +class RegionColumnLiteral(ColumnLiteralBase): + """A literal `lsst.sphgeom.Region` value in a column expression. + + The region is encoded to base64 `bytes` when serialized, and decoded on + first use. + """ + + expression_type: Literal["region"] = "region" + + encoded: bytes + """The wrapped value after base64 encoding.""" + + @cached_property + def value(self) -> bytes: + """The wrapped value.""" + return Region.decode(b64decode(self.encoded)) + + @classmethod + def from_value(cls, value: Region) -> RegionColumnLiteral: + """Construct from the wrapped value. + + Parameters + ---------- + value : `..Region` + Value to wrap. + + Returns + ------- + expression : `RegionColumnLiteral` + Literal expression object. + """ + return cls.model_construct(encoded=b64encode(value.encode())) + + def __str__(self) -> str: + return "(region)" + + +ColumnLiteral: TypeAlias = Union[ + IntColumnLiteral, + StringColumnLiteral, + FloatColumnLiteral, + HashColumnLiteral, + UUIDColumnLiteral, + DateTimeColumnLiteral, + TimespanColumnLiteral, + RegionColumnLiteral, +] + + +def make_column_literal(value: LiteralValue) -> ColumnLiteral: + """Construct a `ColumnLiteral` from the value it will wrap. + + Parameters + ---------- + value : `LiteralValue` + Value to wrap. + + Returns + ------- + expression : `ColumnLiteral` + Literal expression object. + """ + match value: + case int(): + return IntColumnLiteral.from_value(value) + case str(): + return StringColumnLiteral.from_value(value) + case float(): + return FloatColumnLiteral.from_value(value) + case uuid.UUID(): + return UUIDColumnLiteral.from_value(value) + case bytes(): + return HashColumnLiteral.from_value(value) + case astropy.time.Time(): + return DateTimeColumnLiteral.from_value(value) + case Timespan(): + return TimespanColumnLiteral.from_value(value) + case Region(): + return RegionColumnLiteral.from_value(value) + raise TypeError(f"Invalid type {type(value).__name__} of value {value!r} for column literal.") diff --git a/python/lsst/daf/butler/queries/tree/_column_reference.py b/python/lsst/daf/butler/queries/tree/_column_reference.py new file mode 100644 index 0000000000..41434a6bcd --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/_column_reference.py @@ -0,0 +1,179 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("ColumnReference", "DimensionKeyReference", "DimensionFieldReference", "DatasetFieldReference") + +from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias, TypeVar, Union, final + +import pydantic + +from ...column_spec import ColumnType +from ...dimensions import Dimension, DimensionElement +from ._base import ColumnExpressionBase, DatasetFieldName, InvalidQueryTreeError + +if TYPE_CHECKING: + from ..visitors import ColumnExpressionVisitor + from ._column_set import ColumnSet + + +_T = TypeVar("_T") + + +@final +class DimensionKeyReference(ColumnExpressionBase): + """A column expression that references a dimension primary key column.""" + + expression_type: Literal["dimension_key"] = "dimension_key" + + dimension: Dimension + """Definition of this dimension.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + columns.update_dimensions(self.dimension.minimal_group) + + @property + def precedence(self) -> int: + # Docstring inherited. + return 0 + + @property + def column_type(self) -> ColumnType: + # Docstring inherited. + return self.dimension.primary_key.type + + def __str__(self) -> str: + return self.dimension.name + + def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T: + # Docstring inherited. + return visitor.visit_dimension_key_reference(self) + + +@final +class DimensionFieldReference(ColumnExpressionBase): + """A column expression that references a dimension record column that is + not a primary key. + """ + + expression_type: Literal["dimension_field"] = "dimension_field" + + element: DimensionElement + """Definition of the dimension element.""" + + field: str + """Name of the field (i.e. column) in the element's logical table.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + columns.update_dimensions(self.element.minimal_group) + columns.dimension_fields[self.element.name].add(self.field) + + @property + def precedence(self) -> int: + # Docstring inherited. + return 0 + + @property + def column_type(self) -> ColumnType: + # Docstring inherited. + return self.element.schema.remainder[self.field].type + + def __str__(self) -> str: + return f"{self.element}.{self.field}" + + def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T: + # Docstring inherited. + return visitor.visit_dimension_field_reference(self) + + @pydantic.model_validator(mode="after") + def _validate_field(self) -> DimensionFieldReference: + if self.field not in self.element.schema.remainder.names: + raise InvalidQueryTreeError(f"Dimension field {self.element.name}.{self.field} does not exist.") + return self + + +@final +class DatasetFieldReference(ColumnExpressionBase): + """A column expression that references a column associated with a dataset + type. + """ + + expression_type: Literal["dataset_field"] = "dataset_field" + + dataset_type: str + """Name of the dataset type to match any dataset type.""" + + field: DatasetFieldName + """Name of the field (i.e. column) in the dataset's logical table.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + columns.dataset_fields[self.dataset_type].add(self.field) + + @property + def precedence(self) -> int: + # Docstring inherited. 
+ return 0 + + @property + def column_type(self) -> ColumnType: + # Docstring inherited. + match self.field: + case "dataset_id": + return "uuid" + case "ingest_date": + return "datetime" + case "run": + return "string" + case "collection": + return "string" + case "timespan": + return "timespan" + raise AssertionError(f"Invalid field {self.field!r} for dataset.") + + def __str__(self) -> str: + return f"{self.dataset_type}.{self.field}" + + def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T: + # Docstring inherited. + return visitor.visit_dataset_field_reference(self) + + +# Union without Pydantic annotation for the discriminator, for use in nesting +# in other unions that will add that annotation. It's not clear whether it +# would work to just nest the annotated ones, but it seems safest not to rely +# on undocumented behavior. +_ColumnReference: TypeAlias = Union[ + DimensionKeyReference, + DimensionFieldReference, + DatasetFieldReference, +] + +ColumnReference: TypeAlias = Annotated[_ColumnReference, pydantic.Field(discriminator="expression_type")] diff --git a/python/lsst/daf/butler/queries/tree/_column_set.py b/python/lsst/daf/butler/queries/tree/_column_set.py new file mode 100644 index 0000000000..582b68af28 --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/_column_set.py @@ -0,0 +1,186 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("ColumnSet",) + +from collections.abc import Iterable, Iterator, Mapping +from typing import Literal + +from ... 
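# Editorial sketch: a DimensionFieldReference knows which dimensions and record
# fields a query must join in to evaluate it.  `universe` (a DimensionUniverse
# with the default "visit" dimension and its "exposure_time" field) is an
# assumption for illustration.
from lsst.daf.butler.queries.tree import ColumnSet, DimensionFieldReference

ref = DimensionFieldReference(element=universe["visit"], field="exposure_time")
columns = ColumnSet(universe.empty.as_group())  # start from an empty column set
ref.gather_required_columns(columns)
assert "visit" in columns.dimensions.names
assert "exposure_time" in columns.dimension_fields["visit"]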
import column_spec
+from ...dimensions import DimensionGroup
+from ...nonempty_mapping import NonemptyMapping
+from ._base import DATASET_FIELD_NAMES, DatasetFieldName
+
+
+class ColumnSet:
+    def __init__(self, dimensions: DimensionGroup) -> None:
+        self._dimensions = dimensions
+        self._removed_dimension_keys: set[str] = set()
+        self._dimension_fields: dict[str, set[str]] = {name: set() for name in dimensions.elements}
+        self._dataset_fields = NonemptyMapping[str, set[DatasetFieldName | Literal["collection_key"]]](set)
+
+    @property
+    def dimensions(self) -> DimensionGroup:
+        return self._dimensions
+
+    @property
+    def dimension_fields(self) -> Mapping[str, set[str]]:
+        return self._dimension_fields
+
+    @property
+    def dataset_fields(self) -> Mapping[str, set[DatasetFieldName | Literal["collection_key"]]]:
+        return self._dataset_fields
+
+    def __bool__(self) -> bool:
+        return bool(self._dimensions) or any(self._dataset_fields.values())
+
+    def issubset(self, other: ColumnSet) -> bool:
+        return (
+            self._dimensions.issubset(other._dimensions)
+            and all(
+                fields.issubset(other._dimension_fields[element_name])
+                for element_name, fields in self._dimension_fields.items()
+            )
+            and all(
+                fields.issubset(other._dataset_fields.get(dataset_type, frozenset()))
+                for dataset_type, fields in self._dataset_fields.items()
+            )
+        )
+
+    def issuperset(self, other: ColumnSet) -> bool:
+        return other.issubset(self)
+
+    def isdisjoint(self, other: ColumnSet) -> bool:
+        # Note that if the dimensions are disjoint, the dimension fields are
+        # also disjoint, and if the dimensions are not disjoint, we already
+        # have our answer. The same is not true for dataset fields only for
+        # the edge case of dataset types with empty dimensions.
+        return self._dimensions.isdisjoint(other._dimensions) and (
+            self._dataset_fields.keys().isdisjoint(other._dataset_fields)
+            or all(
+                fields.isdisjoint(other._dataset_fields[dataset_type])
+                for dataset_type, fields in self._dataset_fields.items()
+            )
+        )
+
+    def copy(self) -> ColumnSet:
+        result = ColumnSet(self._dimensions)
+        for element_name, element_fields in self._dimension_fields.items():
+            result._dimension_fields[element_name].update(element_fields)
+        for dataset_type, dataset_fields in self._dataset_fields.items():
+            result._dataset_fields[dataset_type].update(dataset_fields)
+        return result
+
+    def update_dimensions(self, dimensions: DimensionGroup) -> None:
+        if not dimensions.issubset(self._dimensions):
+            self._dimensions = dimensions
+            self._dimension_fields = {
+                name: self._dimension_fields.get(name, set()) for name in self._dimensions.elements
+            }
+
+    def update(self, other: ColumnSet) -> None:
+        self.update_dimensions(other.dimensions)
+        self._removed_dimension_keys.intersection_update(other._removed_dimension_keys)
+        for element_name, element_fields in other._dimension_fields.items():
+            self._dimension_fields[element_name].update(element_fields)
+        for dataset_type, dataset_fields in other._dataset_fields.items():
+            self._dataset_fields[dataset_type].update(dataset_fields)
+
+    def drop_dimension_keys(self, names: Iterable[str]) -> ColumnSet:
+        self._removed_dimension_keys.update(names)
+        return self
+
+    def drop_implied_dimension_keys(self) -> ColumnSet:
+        self._removed_dimension_keys.update(self._dimensions.implied)
+        return self
+
+    def restore_dimension_keys(self) -> None:
+        self._removed_dimension_keys.clear()
+
+    def __iter__(self) -> Iterator[tuple[str, str | None]]:
+        for dimension_name in self._dimensions.data_coordinate_keys:
+            if dimension_name
not in self._removed_dimension_keys: + yield dimension_name, None + # We iterate over DimensionElements and their DimensionRecord columns + # in order to make sure that's predictable. We might want to extract + # these query results positionally in some contexts. + for element_name in self._dimensions.elements: + element = self._dimensions.universe[element_name] + fields_for_element = self._dimension_fields[element_name] + for spec in element.schema.remainder: + if spec.name in fields_for_element: + yield element_name, spec.name + # We sort dataset types and lexicographically just to keep our queries + # from having any dependence on set-iteration order. + for dataset_type in sorted(self._dataset_fields): + fields_for_dataset_type = self._dataset_fields[dataset_type] + for field in DATASET_FIELD_NAMES: + if field in fields_for_dataset_type: + yield dataset_type, field + + def is_timespan(self, logical_table: str, field: str | None) -> bool: + return field == "timespan" + + @staticmethod + def get_qualified_name(logical_table: str, field: str | None) -> str: + return logical_table if field is None else f"{logical_table}:{field}" + + def get_column_spec(self, logical_table: str, field: str | None) -> column_spec.ColumnSpec: + qualified_name = self.get_qualified_name(logical_table, field) + if field is None: + return self._dimensions.universe.dimensions[logical_table].primary_key.model_copy( + update=dict(name=qualified_name) + ) + if logical_table in self._dimension_fields: + return ( + self._dimensions.universe[logical_table] + .schema.all[field] + .model_copy(update=dict(name=qualified_name)) + ) + match field: + case "dataset_id": + return column_spec.UUIDColumnSpec.model_construct(name=qualified_name, nullable=False) + case "ingest_date": + return column_spec.DateTimeColumnSpec.model_construct(name=qualified_name) + case "run": + # TODO: string length matches the one defined in the + # CollectionManager implementations; we need to find a way to + # avoid hard-coding the value in multiple places. + return column_spec.StringColumnSpec.model_construct( + name=qualified_name, nullable=False, length=128 + ) + case "collection": + return column_spec.StringColumnSpec.model_construct( + name=qualified_name, nullable=False, length=128 + ) + case "rank": + return column_spec.IntColumnSpec.model_construct(name=qualified_name, nullable=False) + case "timespan": + return column_spec.TimespanColumnSpec.model_construct(name=qualified_name, nullable=False) + raise AssertionError(f"Unrecognized column identifiers: {logical_table}, {field}.") diff --git a/python/lsst/daf/butler/queries/tree/_predicate.py b/python/lsst/daf/butler/queries/tree/_predicate.py new file mode 100644 index 0000000000..b0d68bbbae --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/_predicate.py @@ -0,0 +1,678 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. 
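# Editorial sketch: iteration over a ColumnSet is deterministic -- dimension keys
# first, then dimension record fields, then dataset fields sorted by dataset type
# name -- and get_qualified_name gives the flat column label used for result rows.
# `universe` is an assumed DimensionUniverse; "raw" is a hypothetical dataset type.
from lsst.daf.butler.queries.tree import ColumnSet

columns = ColumnSet(universe["detector"].minimal_group)
columns.dataset_fields["raw"].update({"dataset_id", "run"})
for logical_table, field in columns:
    print(ColumnSet.get_qualified_name(logical_table, field))
# Expected output order: instrument, detector, raw:dataset_id, raw:run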
If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ( + "Predicate", + "PredicateLeaf", + "LogicalNotOperand", + "PredicateOperands", + "ComparisonOperator", +) + +import itertools +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Annotated, Iterable, Literal, TypeAlias, TypeVar, Union, cast, final + +import pydantic + +from ._base import InvalidQueryTreeError, QueryTreeBase +from ._column_expression import ColumnExpression + +if TYPE_CHECKING: + from ..visitors import PredicateVisitFlags, PredicateVisitor + from ._column_set import ColumnSet + from ._query_tree import QueryTree + +ComparisonOperator: TypeAlias = Literal["==", "!=", "<", ">", ">=", "<=", "overlaps"] + + +_L = TypeVar("_L") +_A = TypeVar("_A") +_O = TypeVar("_O") + + +class PredicateLeafBase(QueryTreeBase, ABC): + """Base class for leaf nodes of the `Predicate` tree. + + This is a closed hierarchy whose concrete, `~typing.final` derived classes + are members of the `PredicateLeaf` union. That union should generally + be used in type annotations rather than the technically-open base class. + """ + + @property + @abstractmethod + def precedence(self) -> int: + """Operator precedence for this operation. + + Lower values bind more tightly, so parentheses are needed when printing + an expression where an operand has a higher value than the expression + itself. + """ + raise NotImplementedError() + + @property + def column_type(self) -> Literal["bool"]: + """A string enumeration value representing the type of the column + expression. + """ + return "bool" + + @abstractmethod + def gather_required_columns(self, columns: ColumnSet) -> None: + """Add any columns required to evaluate this predicate leaf to the + given column set. + + Parameters + ---------- + columns : `ColumnSet` + Set of columns to modify in place. + """ + raise NotImplementedError() + + def invert(self) -> PredicateLeaf: + """Return a new leaf that is the logical not of this one.""" + return LogicalNot.model_construct(operand=cast("LogicalNotOperand", self)) + + @abstractmethod + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + """Invoke the visitor interface. + + Parameters + ---------- + visitor : `PredicateVisitor` + Visitor to invoke a method on. + flags : `PredicateVisitFlags` + Flags that provide information about where this leaf appears in the + larger predicate tree. + + Returns + ------- + result : `object` + Forwarded result from the visitor. + """ + raise NotImplementedError() + + +@final +class Predicate(QueryTreeBase): + """A boolean column expression. 
+ + Notes + ----- + Predicate is the only class representing a boolean column expression that + should be used outside of this module (though the objects it nests appear + in its serialized form and hence are not fully private). It provides + several `classmethod` factories for constructing those nested types inside + a `Predicate` instance, and `PredicateVisitor` subclasses should be used + to process them. + """ + + operands: PredicateOperands + """Nested tuple of operands, with outer items combined via AND and inner + items combined via OR. + """ + + @property + def column_type(self) -> Literal["bool"]: + """A string enumeration value representing the type of the column + expression. + """ + return "bool" + + @classmethod + def from_bool(cls, value: bool) -> Predicate: + """Construct a predicate that always evaluates to `True` or `False`. + + Parameters + ---------- + value : `bool` + Value the predicate should evaluate to. + + Returns + ------- + predicate : `Predicate` + Predicate that evaluates to the given boolean value. + """ + return cls.model_construct(operands=() if value else ((),)) + + @classmethod + def compare(cls, a: ColumnExpression, operator: ComparisonOperator, b: ColumnExpression) -> Predicate: + """Construct a predicate representing a binary comparison between + two non-boolean column expressions. + + Parameters + ---------- + a : `ColumnExpression` + First column expression in the comparison. + operator : `str` + Enumerated string representing the comparison operator to apply. + May be and of "==", "!=", "<", ">", "<=", ">=", or "overlaps". + b : `ColumnExpression` + Second column expression in the comparison. + + Returns + ------- + predicate : `Predicate` + Predicate representing the comparison. + """ + return cls._from_leaf(Comparison.model_construct(a=a, operator=operator, b=b)) + + @classmethod + def is_null(cls, operand: ColumnExpression) -> Predicate: + """Construct a predicate that tests whether a column expression is + NULL. + + Parameters + ---------- + operand : `ColumnExpression` + Column expression to test. + + Returns + ------- + predicate : `Predicate` + Predicate representing the NULL check. + """ + return cls._from_leaf(IsNull.model_construct(operand=operand)) + + @classmethod + def in_container(cls, member: ColumnExpression, container: Iterable[ColumnExpression]) -> Predicate: + """Construct a predicate that tests whether one column expression is + a member of a container of other column expressions. + + Parameters + ---------- + member : `ColumnExpression` + Column expression that may be a member of the container. + container : `~collections.abc.Iterable` [ `ColumnExpression` ] + Container of column expressions to test for membership in. + + Returns + ------- + predicate : `Predicate` + Predicate representing the membership test. + """ + return cls._from_leaf(InContainer.model_construct(member=member, container=tuple(container))) + + @classmethod + def in_range( + cls, member: ColumnExpression, start: int = 0, stop: int | None = None, step: int = 1 + ) -> Predicate: + """Construct a predicate that tests whether an integer column + expression is part of a strided range. + + Parameters + ---------- + member : `ColumnExpression` + Column expression that may be a member of the range. + start : `int`, optional + Beginning of the range, inclusive. + stop : `int` or `None`, optional + End of the range, exclusive. + step : `int`, optional + Offset between values in the range. 
+ + Returns + ------- + predicate : `Predicate` + Predicate representing the membership test. + """ + return cls._from_leaf(InRange.model_construct(member=member, start=start, stop=stop, step=step)) + + @classmethod + def in_query_tree( + cls, member: ColumnExpression, column: ColumnExpression, query_tree: QueryTree + ) -> Predicate: + """Construct a predicate that tests whether a column expression is + present in a single-column projection of a query tree. + + Parameters + ---------- + member : `ColumnExpression` + Column expression that may be present in the query. + column : `ColumnExpression` + Column to project from the query. + query_tree : `QueryTree` + Query tree to select from. + + Returns + ------- + predicate : `Predicate` + Predicate representing the membership test. + """ + return cls._from_leaf( + InQueryTree.model_construct(member=member, column=column, query_tree=query_tree) + ) + + def gather_required_columns(self, columns: ColumnSet) -> None: + """Add any columns required to evaluate this predicate to the given + column set. + + Parameters + ---------- + columns : `ColumnSet` + Set of columns to modify in place. + """ + for or_group in self.operands: + for operand in or_group: + operand.gather_required_columns(columns) + + def logical_and(self, *args: Predicate) -> Predicate: + """Construct a predicate representing the logical AND of this predicate + and one or more others. + + Parameters + ---------- + *args : `Predicate` + Other predicates. + + Returns + ------- + predicate : `Predicate` + Predicate representing the logical AND. + """ + operands = self.operands + for arg in args: + operands = self._impl_and(operands, arg.operands) + if not all(operands): + # If any item in operands is an empty tuple (i.e. False), simplify. + operands = ((),) + return Predicate.model_construct(operands=operands) + + def logical_or(self, *args: Predicate) -> Predicate: + """Construct a predicate representing the logical OR of this predicate + and one or more others. + + Parameters + ---------- + *args : `Predicate` + Other predicates. + + Returns + ------- + predicate : `Predicate` + Predicate representing the logical OR. + """ + operands = self.operands + for arg in args: + operands = self._impl_or(operands, arg.operands) + return Predicate.model_construct(operands=operands) + + def logical_not(self) -> Predicate: + """Construct a predicate representing the logical NOT of this + predicate. + + Returns + ------- + predicate : `Predicate` + Predicate representing the logical NOT. + """ + new_operands: PredicateOperands = ((),) + for or_group in self.operands: + new_group: PredicateOperands = () + for leaf in or_group: + new_group = self._impl_and(new_group, ((leaf.invert(),),)) + new_operands = self._impl_or(new_operands, new_group) + return Predicate.model_construct(operands=new_operands) + + def __str__(self) -> str: + and_terms = [] + for or_group in self.operands: + match len(or_group): + case 0: + and_terms.append("False") + case 1: + and_terms.append(str(or_group[0])) + case _: + and_terms.append(f"({' OR '.join(str(operand) for operand in or_group)})") + if not and_terms: + return "True" + return " AND ".join(and_terms) + + def visit(self, visitor: PredicateVisitor[_A, _O, _L]) -> _A: + """Invoke the visitor interface. + + Parameters + ---------- + visitor : `PredicateVisitor` + Visitor to invoke a method on. + + Returns + ------- + result : `object` + Forwarded result from the visitor. 
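# Editorial sketch: Predicate keeps its operands in conjunctive normal form (a
# tuple of OR groups combined with AND), so the factory methods and the logical_*
# combinators always yield a flat AND-of-ORs structure.
from lsst.daf.butler.queries.tree import Predicate, make_column_literal

x = make_column_literal(5)
lower = Predicate.compare(x, ">", make_column_literal(0))
upper = Predicate.compare(x, "<", make_column_literal(10))
both = lower.logical_and(upper)    # two OR groups of one leaf each
either = lower.logical_or(upper)   # one OR group with two leaves
assert len(both.operands) == 2 and all(len(group) == 1 for group in both.operands)
assert len(either.operands) == 1 and len(either.operands[0]) == 2
print(both)                  # 5 > 0 AND 5 < 10
print(either.logical_not())  # negation is pushed down to the leaves, staying in CNF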
+ """ + return visitor._visit_logical_and(self.operands) + + @classmethod + def _from_leaf(cls, leaf: PredicateLeaf) -> Predicate: + return cls._from_or_group((leaf,)) + + @classmethod + def _from_or_group(cls, or_group: tuple[PredicateLeaf, ...]) -> Predicate: + return Predicate.model_construct(operands=(or_group,)) + + @classmethod + def _impl_and(cls, a: PredicateOperands, b: PredicateOperands) -> PredicateOperands: + return a + b + + @classmethod + def _impl_or(cls, a: PredicateOperands, b: PredicateOperands) -> PredicateOperands: + return tuple([a_operand + b_operand for a_operand, b_operand in itertools.product(a, b)]) + + +@final +class LogicalNot(PredicateLeafBase): + """A boolean column expression that inverts its operand.""" + + predicate_type: Literal["not"] = "not" + + operand: LogicalNotOperand + """Upstream boolean expression to invert.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + self.operand.gather_required_columns(columns) + + @property + def precedence(self) -> int: + # Docstring inherited. + return 4 + + def __str__(self) -> str: + if self.operand.precedence <= self.precedence: + return f"NOT {self.operand}" + else: + return f"NOT ({self.operand})" + + def invert(self) -> LogicalNotOperand: + # Docstring inherited. + return self.operand + + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + # Docstring inherited. + return visitor._visit_logical_not(self.operand, flags) + + +@final +class IsNull(PredicateLeafBase): + """A boolean column expression that tests whether its operand is NULL.""" + + predicate_type: Literal["is_null"] = "is_null" + + operand: ColumnExpression + """Upstream expression to test.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + self.operand.gather_required_columns(columns) + + @property + def precedence(self) -> int: + # Docstring inherited. + return 5 + + def __str__(self) -> str: + if self.operand.precedence <= self.precedence: + return f"{self.operand} IS NULL" + else: + return f"({self.operand}) IS NULL" + + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + # Docstring inherited. + return visitor.visit_is_null(self.operand, flags) + + +@final +class Comparison(PredicateLeafBase): + """A boolean columns expression formed by comparing two non-boolean + expressions. + """ + + predicate_type: Literal["comparison"] = "comparison" + + a: ColumnExpression + """Left-hand side expression for the comparison.""" + + b: ColumnExpression + """Right-hand side expression for the comparison.""" + + operator: ComparisonOperator + """Comparison operator.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + self.a.gather_required_columns(columns) + self.b.gather_required_columns(columns) + + @property + def precedence(self) -> int: + # Docstring inherited. + return 5 + + def __str__(self) -> str: + a = str(self.a) if self.a.precedence <= self.precedence else f"({self.a})" + b = str(self.b) if self.b.precedence <= self.precedence else f"({self.b})" + return f"{a} {self.operator.upper()} {b}" + + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + # Docstring inherited. 
+ return visitor.visit_comparison(self.a, self.operator, self.b, flags) + + @pydantic.model_validator(mode="after") + def _validate_column_types(self) -> Comparison: + if self.a.column_type != self.b.column_type: + raise InvalidQueryTreeError( + f"Column types for comparison {self} do not agree " + f"({self.a.column_type}, {self.b.column_type})." + ) + match (self.operator, self.a.column_type): + case ("==" | "!=", _): + pass + case ("<" | ">" | ">=" | "<=", "int" | "string" | "float" | "datetime"): + pass + case ("overlaps", "region" | "timespan"): + pass + case _: + raise InvalidQueryTreeError( + f"Invalid column type {self.a.column_type} for operator {self.operator!r}." + ) + return self + + +@final +class InContainer(PredicateLeafBase): + """A boolean column expression that tests whether one expression is a + member of an explicit sequence of other expressions. + """ + + predicate_type: Literal["in_container"] = "in_container" + + member: ColumnExpression + """Expression to test for membership.""" + + container: tuple[ColumnExpression, ...] + """Expressions representing the elements of the container.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + self.member.gather_required_columns(columns) + for item in self.container: + item.gather_required_columns(columns) + + @property + def precedence(self) -> int: + # Docstring inherited. + return 5 + + def __str__(self) -> str: + m = str(self.member) if self.member.precedence <= self.precedence else f"({self.member})" + return f"{m} IN [{', '.join(str(item) for item in self.container)}]" + + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + # Docstring inherited. + return visitor.visit_in_container(self.member, self.container, flags) + + @pydantic.model_validator(mode="after") + def _validate(self) -> InContainer: + if self.member.column_type == "timespan" or self.member.column_type == "region": + raise InvalidQueryTreeError( + f"Timespan or region column {self.member} may not be used in IN expressions." + ) + if not all(item.column_type == self.member.column_type for item in self.container): + raise InvalidQueryTreeError(f"Column types for membership test {self} do not agree.") + return self + + +@final +class InRange(PredicateLeafBase): + """A boolean column expression that tests whether its expression is + included in an integer range. + """ + + predicate_type: Literal["in_range"] = "in_range" + + member: ColumnExpression + """Expression to test for membership.""" + + start: int = 0 + """Inclusive lower bound for the range.""" + + stop: int | None = None + """Exclusive upper bound for the range.""" + + step: int = 1 + """Difference between values in the range.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + self.member.gather_required_columns(columns) + + @property + def precedence(self) -> int: + # Docstring inherited. 
+ return 5 + + def __str__(self) -> str: + s = f"{self.start if self.start else ''}..{self.stop if self.stop is not None else ''}" + if self.step != 1: + s = f"{s}:{self.step}" + m = str(self.member) if self.member.precedence <= self.precedence else f"({self.member})" + return f"{m} IN {s}" + + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + return visitor.visit_in_range(self.member, self.start, self.stop, self.step, flags) + + @pydantic.model_validator(mode="after") + def _validate(self) -> InRange: + if self.member.column_type != "int": + raise InvalidQueryTreeError(f"Column {self.member} is not an integer.") + return self + + +@final +class InQueryTree(PredicateLeafBase): + """A boolean column expression that tests whether its expression is + included single-column projection of a relation. + + This is primarily intended to be used on dataset ID columns, but it may + be useful for other columns as well. + """ + + predicate_type: Literal["in_relation"] = "in_relation" + + member: ColumnExpression + """Expression to test for membership.""" + + column: ColumnExpression + """Expression to extract from `query_tree`.""" + + query_tree: QueryTree + """Relation whose rows from `column` represent the container.""" + + def gather_required_columns(self, columns: ColumnSet) -> None: + # Docstring inherited. + # We're only gathering columns from the query_tree this predicate is + # attached to, not `self.column`, which belongs to `self.query_tree`. + self.member.gather_required_columns(columns) + + @property + def precedence(self) -> int: + # Docstring inherited. + return 5 + + def __str__(self) -> str: + m = str(self.member) if self.member.precedence <= self.precedence else f"({self.member})" + c = str(self.column) if self.column.precedence <= self.precedence else f"({self.column})" + return f"{m} IN [{{{self.query_tree}}}.{c}]" + + def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L: + # Docstring inherited. + return visitor.visit_in_query_tree(self.member, self.column, self.query_tree, flags) + + @pydantic.model_validator(mode="after") + def _validate_column_types(self) -> InQueryTree: + if self.member.column_type == "timespan" or self.member.column_type == "region": + raise InvalidQueryTreeError( + f"Timespan or region column {self.member} may not be used in IN expressions." + ) + if self.member.column_type != self.column.column_type: + raise InvalidQueryTreeError( + f"Column types for membership test {self} do not agree " + f"({self.member.column_type}, {self.column.column_type})." + ) + + from ._column_set import ColumnSet + + columns_required_in_tree = ColumnSet(self.query_tree.dimensions) + self.column.gather_required_columns(columns_required_in_tree) + if columns_required_in_tree.dimensions != self.query_tree.dimensions: + raise InvalidQueryTreeError( + f"Column {self.column} requires dimensions {columns_required_in_tree.dimensions}, " + f"but query tree only has {self.query_tree.dimensions}." + ) + if not columns_required_in_tree.dataset_fields.keys() <= self.query_tree.datasets.keys(): + raise InvalidQueryTreeError( + f"Column {self.column} requires dataset types " + f"{set(columns_required_in_tree.dataset_fields.keys())} that are not present in query tree." 
+ ) + return self + + +LogicalNotOperand: TypeAlias = Union[ + IsNull, + Comparison, + InContainer, + InRange, + InQueryTree, +] +PredicateLeaf: TypeAlias = Annotated[ + Union[LogicalNotOperand, LogicalNot], pydantic.Field(discriminator="predicate_type") +] + +PredicateOperands: TypeAlias = tuple[tuple[PredicateLeaf, ...], ...] diff --git a/python/lsst/daf/butler/queries/tree/_query_tree.py b/python/lsst/daf/butler/queries/tree/_query_tree.py new file mode 100644 index 0000000000..1f18c49889 --- /dev/null +++ b/python/lsst/daf/butler/queries/tree/_query_tree.py @@ -0,0 +1,349 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ( + "QueryTree", + "make_unit_query_tree", + "make_dimension_query_tree", + "DataCoordinateUploadKey", + "MaterializationKey", + "DatasetSearch", + "DeferredValidationQueryTree", +) + +import uuid +from collections.abc import Mapping +from functools import cached_property +from typing import TypeAlias, final + +import pydantic + +from ...dimensions import DimensionGroup, DimensionUniverse +from ...pydantic_utils import DeferredValidation +from ._base import InvalidQueryTreeError, QueryTreeBase +from ._column_set import ColumnSet +from ._predicate import Predicate + +DataCoordinateUploadKey: TypeAlias = uuid.UUID + +MaterializationKey: TypeAlias = uuid.UUID + + +def make_unit_query_tree(universe: DimensionUniverse) -> QueryTree: + """Make an initial query tree with empty dimensions and a single logical + row. + + This method should be used by `Butler._query` to construct the initial + query tree. This tree is a useful initial state because it is the + identity for joins, in that joining any other query tree to this + query tree yields that query tree. + + Parameters + ---------- + universe : `..DimensionUniverse` + Definitions for all dimensions. + + Returns + ------- + tree : `QueryTree` + A tree with empty dimensions. + """ + return make_dimension_query_tree(universe.empty.as_group()) + + +def make_dimension_query_tree(dimensions: DimensionGroup) -> QueryTree: + """Make an initial query tree with the given dimensions. + + Parameters + ---------- + dimensions : `..DimensionGroup` + Definitions for all dimensions. + + Returns + ------- + tree : `QueryTree` + A tree with the given dimensions. 
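+
+    Notes
+    -----
+    The returned tree has no dataset searches, data coordinate uploads,
+    materializations, or row filters; those are added via the `QueryTree`
+    ``join*`` and `~QueryTree.where` methods. A minimal usage sketch
+    (illustrative only, assuming a `DimensionUniverse` named ``universe``
+    that defines a ``detector`` dimension)::
+
+        tree = make_dimension_query_tree(universe["detector"].minimal_group)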
+ """ + return QueryTree.model_construct(dimensions=dimensions) + + +@final +class DatasetSearch(QueryTreeBase): + """Information about a dataset search joined into a query tree. + + The dataset type name is the key of the dictionary (in `QueryTree`) where + this type is used as a value. + """ + + collections: tuple[str, ...] + """The collections to search. + + Order matters if this dataset type is later referenced by a `FindFirst` + operation. Collection wildcards are always resolved before being included + in a dataset search. + """ + + dimensions: DimensionGroup + """The dimensions of the dataset type. + + This must match the dimensions of the dataset type as already defined in + the butler database, but this cannot generally be verified when a relation + tree is validated (since it requires a database query) and hence must be + checked later. + """ + + +@final +class QueryTree(QueryTreeBase): + """A declarative, serializable description of a butler query. + + This class's attributes describe the columns that "available" to be + returned or used in ``where`` or ``order_by`` expressions, but it does not + carry information about the columns that are actually included in result + rows, or what kind of butler primitive (e.g. `DataCoordinate` or + `DatasetRef`) those rows might be transformed into. + """ + + dimensions: DimensionGroup + """The dimensions whose keys are joined into the query. + """ + + datasets: Mapping[str, DatasetSearch] = pydantic.Field(default_factory=dict) + """Dataset searches that have been joined into the query.""" + + data_coordinate_uploads: Mapping[DataCoordinateUploadKey, DimensionGroup] = pydantic.Field( + default_factory=dict + ) + """Uploaded tables of data ID values that have been joined into the query. + """ + + materializations: Mapping[MaterializationKey, DimensionGroup] = pydantic.Field(default_factory=dict) + """Tables of result rows from other queries that have been stored + temporarily on the server. + """ + + predicate: Predicate = Predicate.from_bool(True) + """Boolean expression trees whose logical AND defines a row filter.""" + + @cached_property + def join_operand_dimensions(self) -> frozenset[DimensionGroup]: + """A set of sets of the dimensions of all data coordinate uploads, + dataset searches, and materializations. + """ + result: set[DimensionGroup] = set(self.data_coordinate_uploads.values()) + result.update(self.materializations.values()) + for dataset_spec in self.datasets.values(): + result.add(dataset_spec.dimensions) + return frozenset(result) + + def join(self, other: QueryTree) -> QueryTree: + """Return a new tree that represents a join between ``self`` and + ``other``. + + Parameters + ---------- + other : `QueryTree` + Tree to join to this one. + + Returns + ------- + result : `QueryTree` + A new tree that joins ``self`` and ``other``. + + Raises + ------ + InvalidQueryTreeError + Raised if the join is ambiguous or otherwise invalid. + """ + if not self.datasets.keys().isdisjoint(other.datasets.keys()): + raise InvalidQueryTreeError( + "Cannot join when both sides include the same dataset type: " + f"{self.datasets.keys() & other.datasets.keys()}." 
+            )
+        return QueryTree.model_construct(
+            dimensions=self.dimensions | other.dimensions,
+            datasets={**self.datasets, **other.datasets},
+            data_coordinate_uploads={**self.data_coordinate_uploads, **other.data_coordinate_uploads},
+            materializations={**self.materializations, **other.materializations},
+            predicate=self.predicate.logical_and(other.predicate),
+        )
+
+    def join_data_coordinate_upload(
+        self, key: DataCoordinateUploadKey, dimensions: DimensionGroup
+    ) -> QueryTree:
+        """Return a new tree that joins in an uploaded table of data ID values.
+
+        Parameters
+        ----------
+        key : `DataCoordinateUploadKey`
+            Unique identifier for this upload, as assigned by a `QueryDriver`.
+        dimensions : `DimensionGroup`
+            Dimensions of the data IDs.
+
+        Returns
+        -------
+        result : `QueryTree`
+            A new tree that joins in the data ID table.
+        """
+        if key in self.data_coordinate_uploads:
+            assert (
+                dimensions == self.data_coordinate_uploads[key]
+            ), f"Different dimensions for the same data coordinate upload key {key}!"
+            return self
+        data_coordinate_uploads = dict(self.data_coordinate_uploads)
+        data_coordinate_uploads[key] = dimensions
+        return self.model_copy(
+            update=dict(
+                dimensions=self.dimensions | dimensions, data_coordinate_uploads=data_coordinate_uploads
+            )
+        )
+
+    def join_materialization(self, key: MaterializationKey, dimensions: DimensionGroup) -> QueryTree:
+        """Return a new tree that joins in temporarily stored results from
+        another query.
+
+        Parameters
+        ----------
+        key : `MaterializationKey`
+            Unique identifier for this materialization, as assigned by a
+            `QueryDriver`.
+        dimensions : `DimensionGroup`
+            The dimensions stored in the materialization.
+
+        Returns
+        -------
+        result : `QueryTree`
+            A new tree that joins in the materialization.
+        """
+        if key in self.materializations:
+            assert (
+                dimensions == self.materializations[key]
+            ), f"Different dimensions for the same materialization {key}!"
+            return self
+        materializations = dict(self.materializations)
+        materializations[key] = dimensions
+        return self.model_copy(
+            update=dict(dimensions=self.dimensions | dimensions, materializations=materializations)
+        )
+
+    def join_dataset(self, dataset_type: str, spec: DatasetSearch) -> QueryTree:
+        """Return a new tree that joins in a search for a dataset.
+
+        Parameters
+        ----------
+        dataset_type : `str`
+            Name of the dataset type to join in.
+        spec : `DatasetSearch`
+            Struct containing the collection search path and dataset type
+            dimensions.
+
+        Returns
+        -------
+        result : `QueryTree`
+            A new tree that joins in the dataset search.
+
+        Raises
+        ------
+        InvalidQueryTreeError
+            Raised if this dataset type is already present in the query tree
+            with different collections and/or dimensions.
+        """
+        if dataset_type in self.datasets:
+            if spec != self.datasets[dataset_type]:
+                raise InvalidQueryTreeError(
+                    f"Dataset type {dataset_type!r} is already present in the query, with different "
+                    "collections and/or dimensions."
+                )
+            return self
+        datasets = dict(self.datasets)
+        datasets[dataset_type] = spec
+        return self.model_copy(update=dict(dimensions=self.dimensions | spec.dimensions, datasets=datasets))
+
+    def where(self, *terms: Predicate) -> QueryTree:
+        """Return a new tree that adds row filtering via a boolean column
+        expression.
+
+        Parameters
+        ----------
+        *terms : `Predicate`
+            Boolean column expressions that filter rows. Arguments are
+            combined with logical AND.
+
+        Returns
+        -------
+        result : `QueryTree`
+            A new tree with row filtering.
+
+        Raises
+        ------
+        InvalidQueryTreeError
+            Raised if a column expression requires a dataset column that is not
+            already present in the query tree.
+
+        Notes
+        -----
+        If an expression references a dimension or dimension element that is
+        not already present in the query tree, it will be joined in, but
+        datasets must already be joined into a query tree in order to reference
+        their fields in expressions.
+        """
+        where_predicate = self.predicate
+        columns = ColumnSet(self.dimensions)
+        for where_term in terms:
+            where_term.gather_required_columns(columns)
+            where_predicate = where_predicate.logical_and(where_term)
+        if not (columns.dataset_fields.keys() <= self.datasets.keys()):
+            raise InvalidQueryTreeError(
+                f"Cannot reference dataset type(s) {columns.dataset_fields.keys() - self.datasets.keys()} "
+                "that have not been joined."
+            )
+        return self.model_copy(update=dict(dimensions=columns.dimensions, predicate=where_predicate))
+
+    @pydantic.model_validator(mode="after")
+    def _validate_join_operands(self) -> QueryTree:
+        for dimensions in self.join_operand_dimensions:
+            if not dimensions.issubset(self.dimensions):
+                raise InvalidQueryTreeError(
+                    f"Dimensions {dimensions} of join operand are not a "
+                    f"subset of the query tree's dimensions {self.dimensions}."
+                )
+        return self
+
+    @pydantic.model_validator(mode="after")
+    def _validate_required_columns(self) -> QueryTree:
+        columns = ColumnSet(self.dimensions)
+        self.predicate.gather_required_columns(columns)
+        if not columns.dimensions.issubset(self.dimensions):
+            raise InvalidQueryTreeError("Predicate requires dimensions beyond those in the query tree.")
+        if not columns.dataset_fields.keys() <= self.datasets.keys():
+            raise InvalidQueryTreeError("Predicate requires dataset columns that are not in the query tree.")
+        return self
+
+
+class DeferredValidationQueryTree(DeferredValidation[QueryTree]):
+    pass
diff --git a/python/lsst/daf/butler/queries/visitors.py b/python/lsst/daf/butler/queries/visitors.py
new file mode 100644
index 0000000000..8340389e19
--- /dev/null
+++ b/python/lsst/daf/butler/queries/visitors.py
@@ -0,0 +1,540 @@
+# This file is part of daf_butler.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
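+#
+# A typical concrete visitor subclasses `SimplePredicateVisitor` below and
+# overrides only the leaf methods it cares about. For example (an
+# illustrative sketch only, not used elsewhere in this package), a visitor
+# that replaces IS NULL tests with TRUE:
+#
+#     class DropIsNull(SimplePredicateVisitor):
+#         def visit_is_null(self, operand, flags):
+#             return tree.Predicate.from_bool(True)
+#
+# ``predicate.visit(DropIsNull())`` then returns either `None` (no leaf was
+# replaced, so the predicate is unchanged) or a rewritten `Predicate`.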
+ +from __future__ import annotations + +__all__ = ( + "ColumnExpressionVisitor", + "PredicateVisitor", + "SimplePredicateVisitor", + "PredicateVisitFlags", +) + +import enum +from abc import abstractmethod +from typing import Generic, TypeVar, final + +from . import tree + +_T = TypeVar("_T") +_L = TypeVar("_L") +_A = TypeVar("_A") +_O = TypeVar("_O") + + +class PredicateVisitFlags(enum.Flag): + """Flags that provide information about the location of a predicate term + in the larger tree. + """ + + HAS_AND_SIBLINGS = enum.auto() + HAS_OR_SIBLINGS = enum.auto() + INVERTED = enum.auto() + + +class ColumnExpressionVisitor(Generic[_T]): + """A visitor interface for traversing a `ColumnExpression` tree. + + Notes + ----- + Unlike `Predicate`, the concrete column expression types need to be + public for various reasons, and hence the visitor interface uses them + directly in its arguments. + + This interface includes `Reversed` (which is part of the `OrderExpression` + union but not the `ColumnExpression` union) because it is simpler to have + just one visitor interface and disable support for it at runtime as + appropriate. + """ + + @abstractmethod + def visit_literal(self, expression: tree.ColumnLiteral) -> _T: + """Visit a column expression that wraps a literal value. + + Parameters + ---------- + expression : `tree.ColumnLiteral` + Expression to visit. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_dimension_key_reference(self, expression: tree.DimensionKeyReference) -> _T: + """Visit a column expression that represents a dimension column. + + Parameters + ---------- + expression : `tree.DimensionKeyReference` + Expression to visit. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_dimension_field_reference(self, expression: tree.DimensionFieldReference) -> _T: + """Visit a column expression that represents a dimension record field. + + Parameters + ---------- + expression : `tree.DimensionFieldReference` + Expression to visit. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_dataset_field_reference(self, expression: tree.DatasetFieldReference) -> _T: + """Visit a column expression that represents a dataset field. + + Parameters + ---------- + expression : `tree.DatasetFieldReference` + Expression to visit. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_unary_expression(self, expression: tree.UnaryExpression) -> _T: + """Visit a column expression that represents a unary operation. + + Parameters + ---------- + expression : `tree.UnaryExpression` + Expression to visit. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_binary_expression(self, expression: tree.BinaryExpression) -> _T: + """Visit a column expression that wraps a binary operation. + + Parameters + ---------- + expression : `tree.BinaryExpression` + Expression to visit. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_reversed(self, expression: tree.Reversed) -> _T: + """Visit a column expression that switches sort order from ascending + to descending. 
+ + Parameters + ---------- + expression : `tree.Reversed` + Expression to visit. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + +class PredicateVisitor(Generic[_A, _O, _L]): + """A visitor interface for traversing a `Predicate`. + + Notes + ----- + The concrete `PredicateLeaf` types are only semi-public (they appear in + the serialized form of a `Predicate`, but their types should not generally + be referenced directly outside of the module in which they are defined. + As a result, visiting these objects unpacks their attributes into the + visit method arguments. + """ + + @abstractmethod + def visit_comparison( + self, + a: tree.ColumnExpression, + operator: tree.ComparisonOperator, + b: tree.ColumnExpression, + flags: PredicateVisitFlags, + ) -> _L: + """Visit a binary comparison between column expressions. + + Parameters + ---------- + a : `tree.ColumnExpression` + First column expression in the comparison. + operator : `str` + Enumerated string representing the comparison operator to apply. + May be and of "==", "!=", "<", ">", "<=", ">=", or "overlaps". + b : `tree.ColumnExpression` + Second column expression in the comparison. + flags : `PredicateVisitFlags` + Information about where this leaf appears in the larger predicate + tree. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_is_null(self, operand: tree.ColumnExpression, flags: PredicateVisitFlags) -> _L: + """Visit a predicate leaf that tests whether a column expression is + NULL. + + Parameters + ---------- + operand : `tree.ColumnExpression` + Column expression to test. + flags : `PredicateVisitFlags` + Information about where this leaf appears in the larger predicate + tree. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_in_container( + self, + member: tree.ColumnExpression, + container: tuple[tree.ColumnExpression, ...], + flags: PredicateVisitFlags, + ) -> _L: + """Visit a predicate leaf that tests whether a column expression is + a member of a container. + + Parameters + ---------- + member : `tree.ColumnExpression` + Column expression that may be a member of the container. + container : `~collections.abc.Iterable` [ `tree.ColumnExpression` ] + Container of column expressions to test for membership in. + flags : `PredicateVisitFlags` + Information about where this leaf appears in the larger predicate + tree. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def visit_in_range( + self, + member: tree.ColumnExpression, + start: int, + stop: int | None, + step: int, + flags: PredicateVisitFlags, + ) -> _L: + """Visit a predicate leaf that tests whether a column expression is + a member of an integer range. + + Parameters + ---------- + member : `tree.ColumnExpression` + Column expression that may be a member of the range. + start : `int`, optional + Beginning of the range, inclusive. + stop : `int` or `None`, optional + End of the range, exclusive. + step : `int`, optional + Offset between values in the range. + flags : `PredicateVisitFlags` + Information about where this leaf appears in the larger predicate + tree. + + Returns + ------- + result : `object` + Implementation-defined. 
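+
+        Notes
+        -----
+        The bounds have the same semantics as in `tree.Predicate.in_range`
+        and Python's built-in `range`: ``start`` is inclusive, ``stop`` is
+        exclusive (unbounded when `None`), and ``step`` is the spacing
+        between members.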
+ """ + raise NotImplementedError() + + @abstractmethod + def visit_in_query_tree( + self, + member: tree.ColumnExpression, + column: tree.ColumnExpression, + query_tree: tree.QueryTree, + flags: PredicateVisitFlags, + ) -> _L: + """Visit a predicate leaf that tests whether a column expression is + a member of a container. + + Parameters + ---------- + member : `tree.ColumnExpression` + Column expression that may be present in the query. + column : `tree.ColumnExpression` + Column to project from the query. + query_tree : `QueryTree` + Query tree to select from. + flags : `PredicateVisitFlags` + Information about where this leaf appears in the larger predicate + tree. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def apply_logical_not(self, original: tree.PredicateLeaf, result: _L, flags: PredicateVisitFlags) -> _L: + """Apply a logical NOT to the result of visiting an inverted predicate + leaf. + + Parameters + ---------- + original : `PredicateLeaf` + The original operand of the logical NOT operation. + result : `object` + Implementation-defined result of visiting the operand. + flags : `PredicateVisitFlags` + Information about where this leaf appears in the larger predicate + tree. Never has `PredicateVisitFlags.INVERTED` set. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def apply_logical_or( + self, + originals: tuple[tree.PredicateLeaf, ...], + results: tuple[_L, ...], + flags: PredicateVisitFlags, + ) -> _O: + """Apply a logical OR operation to the result of visiting a `tuple` of + predicate leaf objects. + + Parameters + ---------- + originals : `tuple` [ `PredicateLeaf`, ... ] + Original leaf objects in the logical OR. + results : `tuple` [ `object`, ... ] + Result of visiting the leaf objects. + flags : `PredicateVisitFlags` + Information about where this leaf appears in the larger predicate + tree. Never has `PredicateVisitFlags.INVERTED` or + `PredicateVisitFlags.HAS_OR_SIBLINGS` set. + + Returns + ------- + result : `object` + Implementation-defined. + """ + raise NotImplementedError() + + @abstractmethod + def apply_logical_and(self, originals: tree.PredicateOperands, results: tuple[_O, ...]) -> _A: + """Apply a logical AND operation to the result of visiting a nested + `tuple` of predicate leaf objects. + + Parameters + ---------- + originals : `tuple` [ `tuple` [ `PredicateLeaf`, ... ], ... ] + Nested tuple of predicate leaf objects, with inner tuples + corresponding to groups that should be combined with logical OR. + results : `tuple` [ `object`, ... ] + Result of visiting the leaf objects. + + Returns + ------- + result : `object` + Implementation-defined. 
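+
+        Notes
+        -----
+        This is the outermost call made when a `tree.Predicate` is visited,
+        so its return type is also the return type of `tree.Predicate.visit`.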
+ """ + raise NotImplementedError() + + @final + def _visit_logical_not(self, operand: tree.LogicalNotOperand, flags: PredicateVisitFlags) -> _L: + return self.apply_logical_not( + operand, operand.visit(self, flags | PredicateVisitFlags.INVERTED), flags + ) + + @final + def _visit_logical_or(self, operands: tuple[tree.PredicateLeaf, ...], flags: PredicateVisitFlags) -> _O: + nested_flags = flags + if len(operands) > 1: + nested_flags |= PredicateVisitFlags.HAS_OR_SIBLINGS + return self.apply_logical_or( + operands, tuple([operand.visit(self, nested_flags) for operand in operands]), flags + ) + + @final + def _visit_logical_and(self, operands: tree.PredicateOperands) -> _A: + if len(operands) > 1: + nested_flags = PredicateVisitFlags.HAS_AND_SIBLINGS + else: + nested_flags = PredicateVisitFlags(0) + return self.apply_logical_and( + operands, tuple([self._visit_logical_or(or_group, nested_flags) for or_group in operands]) + ) + + +class SimplePredicateVisitor( + PredicateVisitor[tree.Predicate | None, tree.Predicate | None, tree.Predicate | None] +): + """An intermediate base class for predicate visitor implementations that + either return `None` or a new `Predicate`. + + Notes + ----- + This class implements all leaf-node visitation methods to return `None`, + which is interpreted by the ``apply*`` method implementations as indicating + that the leaf is unmodified. Subclasses can thus override only certain + visitation methods and either return `None` if there is no result, or + return a replacement `Predicate` to construct a new tree. + """ + + def visit_comparison( + self, + a: tree.ColumnExpression, + operator: tree.ComparisonOperator, + b: tree.ColumnExpression, + flags: PredicateVisitFlags, + ) -> tree.Predicate | None: + # Docstring inherited. + return None + + def visit_is_null( + self, operand: tree.ColumnExpression, flags: PredicateVisitFlags + ) -> tree.Predicate | None: + # Docstring inherited. + return None + + def visit_in_container( + self, + member: tree.ColumnExpression, + container: tuple[tree.ColumnExpression, ...], + flags: PredicateVisitFlags, + ) -> tree.Predicate | None: + # Docstring inherited. + return None + + def visit_in_range( + self, + member: tree.ColumnExpression, + start: int, + stop: int | None, + step: int, + flags: PredicateVisitFlags, + ) -> tree.Predicate | None: + # Docstring inherited. + return None + + def visit_in_query_tree( + self, + member: tree.ColumnExpression, + column: tree.ColumnExpression, + query_tree: tree.QueryTree, + flags: PredicateVisitFlags, + ) -> tree.Predicate | None: + # Docstring inherited. + return None + + def apply_logical_not( + self, original: tree.PredicateLeaf, result: tree.Predicate | None, flags: PredicateVisitFlags + ) -> tree.Predicate | None: + # Docstring inherited. + if result is None: + return None + from . import tree + + return tree.Predicate._from_leaf(original).logical_not() + + def apply_logical_or( + self, + originals: tuple[tree.PredicateLeaf, ...], + results: tuple[tree.Predicate | None, ...], + flags: PredicateVisitFlags, + ) -> tree.Predicate | None: + # Docstring inherited. + if all(result is None for result in results): + return None + from . 
import tree + + return tree.Predicate.from_bool(False).logical_or( + *[ + tree.Predicate._from_leaf(original) if result is None else result + for original, result in zip(originals, results) + ] + ) + + def apply_logical_and( + self, + originals: tree.PredicateOperands, + results: tuple[tree.Predicate | None, ...], + ) -> tree.Predicate | None: + # Docstring inherited. + if all(result is None for result in results): + return None + from . import tree + + return tree.Predicate.from_bool(True).logical_and( + *[ + tree.Predicate._from_or_group(original) if result is None else result + for original, result in zip(originals, results) + ] + ) diff --git a/python/lsst/daf/butler/registry/collections/nameKey.py b/python/lsst/daf/butler/registry/collections/nameKey.py index ccd7d26b6a..7da4390656 100644 --- a/python/lsst/daf/butler/registry/collections/nameKey.py +++ b/python/lsst/daf/butler/registry/collections/nameKey.py @@ -179,6 +179,12 @@ def getParentChains(self, key: str) -> set[str]: parent_names = set(sql_result.scalars().all()) return parent_names + def lookup_name_sql( + self, sql_key: sqlalchemy.ColumnElement[str], sql_from_clause: sqlalchemy.FromClause + ) -> tuple[sqlalchemy.ColumnElement[str], sqlalchemy.FromClause]: + # Docstring inherited. + return sql_key, sql_from_clause + def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[str]]: # Docstring inherited from base class. return self._fetch_by_key(names) diff --git a/python/lsst/daf/butler/registry/collections/synthIntKey.py b/python/lsst/daf/butler/registry/collections/synthIntKey.py index b96a42f0fc..38605a6ad9 100644 --- a/python/lsst/daf/butler/registry/collections/synthIntKey.py +++ b/python/lsst/daf/butler/registry/collections/synthIntKey.py @@ -180,6 +180,17 @@ def getParentChains(self, key: int) -> set[str]: parent_names = set(sql_result.scalars().all()) return parent_names + def lookup_name_sql( + self, sql_key: sqlalchemy.ColumnElement[int], sql_from_clause: sqlalchemy.FromClause + ) -> tuple[sqlalchemy.ColumnElement[str], sqlalchemy.FromClause]: + # Docstring inherited. + return ( + self._tables.collection.c.name, + sql_from_clause.join( + self._tables.collection, onclause=self._tables.collection.c[_KEY_FIELD_SPEC.name] == sql_key + ), + ) + def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[int]]: # Docstring inherited from base class. 
_LOG.debug("Fetching collection records using names %s.", names) diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py b/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py index c67d0b6fb8..7a7a7ae7fb 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py @@ -34,7 +34,7 @@ import datetime from collections.abc import Callable, Iterable, Iterator, Sequence, Set -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal import astropy.time import sqlalchemy @@ -46,11 +46,13 @@ from ...._dataset_type import DatasetType from ...._timespan import Timespan from ....dimensions import DataCoordinate +from ....direct_query_driver import SqlBuilder # new query system, server+direct only +from ....queries import tree as qt # new query system, both clients + server from ..._collection_summary import CollectionSummary from ..._collection_type import CollectionType from ..._exceptions import CollectionTypeError, ConflictingDefinitionError from ...interfaces import DatasetRecordStorage -from ...queries import SqlQueryContext +from ...queries import SqlQueryContext # old registry query system from .tables import makeTagTableSpec if TYPE_CHECKING: @@ -552,6 +554,212 @@ def _finish_single_relation( ) return leaf + def make_sql_builder( + self, + collections: Sequence[CollectionRecord], + fields: Set[qt.DatasetFieldName | Literal["collection_key"]], + ) -> SqlBuilder: + # This method largely mimics `make_relation`, but it uses the new query + # system primitives instead of the old one. In terms of the SQL + # queries it builds, there are two more main differences: + # + # - Collection and run columns are now string names rather than IDs. + # This insulates the query result-processing code from collection + # caching and the collection manager subclass details. + # + # - The subquery always has unique rows, which is achieved by using + # SELECT DISTINCT when necessary. + # + collection_types = {collection.type for collection in collections} + assert CollectionType.CHAINED not in collection_types, "CHAINED collections must be flattened." + # + # There are two kinds of table in play here: + # + # - the static dataset table (with the dataset ID, dataset type ID, + # run ID/name, and ingest date); + # + # - the dynamic tags/calibs table (with the dataset ID, dataset type + # type ID, collection ID/name, data ID, and possibly validity + # range). + # + # That means that we might want to return a query against either table + # or a JOIN of both, depending on which quantities the caller wants. + # But the data ID is always included, which means we'll always include + # the tags/calibs table and join in the static dataset table only if we + # need things from it that we can't get from the tags/calibs table. + # + # Note that it's important that we include a WHERE constraint on both + # tables for any column (e.g. dataset_type_id) that is in both when + # it's given explicitly; not doing can prevent the query planner from + # using very important indexes. At present, we don't include those + # redundant columns in the JOIN ON expression, however, because the + # FOREIGN KEY (and its index) are defined only on dataset_id. + tag_sql_builder: SqlBuilder | None = None + if collection_types != {CollectionType.CALIBRATION}: + # We'll need a subquery for the tags table if any of the given + # collections are not a CALIBRATION collection. 
This intentionally + # also fires when the list of collections is empty as a way to + # create a dummy subquery that we know will fail. + # We give the table an alias because it might appear multiple times + # in the same query, for different dataset types. + tag_sql_builder = SqlBuilder(self._db, self._tags.alias(f"{self.datasetType.name}_tags")) + if "timespan" in fields: + tag_sql_builder.timespans[self.datasetType.name] = ( + self._db.getTimespanRepresentation().fromLiteral(Timespan(None, None)) + ) + tag_sql_builder = self._finish_sql_builder( + tag_sql_builder, + [ + (record, rank) + for rank, record in enumerate(collections) + if record.type is not CollectionType.CALIBRATION + ], + fields, + ) + calib_sql_builder: SqlBuilder | None = None + if CollectionType.CALIBRATION in collection_types: + # If at least one collection is a CALIBRATION collection, we'll + # need a subquery for the calibs table, and could include the + # timespan as a result or constraint. + assert ( + self._calibs is not None + ), "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection." + calib_sql_builder = SqlBuilder(self._db, self._calibs.alias(f"{self.datasetType.name}_calibs")) + if "timespan" in fields: + calib_sql_builder.timespans[self.datasetType.name] = ( + self._db.getTimespanRepresentation().from_columns(self._calibs.columns) + ) + calib_sql_builder = self._finish_sql_builder( + calib_sql_builder, + [ + (record, rank) + for rank, record in enumerate(collections) + if record.type is CollectionType.CALIBRATION + ], + fields, + ) + # In calibration collections, we need timespan as well as data ID + # to ensure unique rows. + calib_sql_builder.needs_distinct = calib_sql_builder.needs_distinct and "timespan" not in fields + columns = qt.ColumnSet(self.datasetType.dimensions.as_group()) + columns.dataset_fields[self.datasetType.name].update(fields) + columns.drop_implied_dimension_keys() + if tag_sql_builder is not None: + if calib_sql_builder is not None: + # Need a UNION subquery. + return tag_sql_builder.union_subquery([calib_sql_builder], columns) + elif tag_sql_builder.needs_distinct: + # Need a SELECT DISTINCT subquery. + return tag_sql_builder.subquery(columns) + else: + return tag_sql_builder + elif calib_sql_builder is not None: + if calib_sql_builder.needs_distinct: + return calib_sql_builder.subquery(columns) + else: + return calib_sql_builder + else: + raise AssertionError("Branch should be unreachable.") + + def _finish_sql_builder( + self, + sql_builder: SqlBuilder, + collections: Sequence[tuple[CollectionRecord, int]], + fields: Set[qt.DatasetFieldName | Literal["collection_key"]], + ) -> SqlBuilder: + # This method plays the same role as _finish_single_relation in the new + # query system. It is called exactly one or two times by + # make_sql_builder, just as _finish_single_relation is called exactly + # one or two times by make_relation. See make_sql_builder comments for + # what's different. + assert sql_builder.sql_from_clause is not None + run_collections_only = all(record.type is CollectionType.RUN for record, _ in collections) + sql_builder.where_sql(sql_builder.sql_from_clause.c.dataset_type_id == self._dataset_type_id) + dataset_id_col = sql_builder.sql_from_clause.c.dataset_id + collection_col = sql_builder.sql_from_clause.c[self._collections.getCollectionForeignKeyName()] + fields_provided = sql_builder.fields[self.datasetType.name] + # We always constrain and optionally retrieve the collection(s) via the + # tags/calibs table. 
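+        # The three branches below depend on how many collections are being
+        # searched: exactly one (constrain to its key and emit its name as a
+        # SQL literal), none (emit an always-false constraint so the query
+        # returns no rows), or several (constrain with IN and recover the
+        # name via a CASE expression over the known keys).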
+        if "collection_key" in fields:
+            sql_builder.fields[self.datasetType.name]["collection_key"] = collection_col
+        if len(collections) == 1:
+            only_collection_record, _ = collections[0]
+            sql_builder.where_sql(collection_col == only_collection_record.key)
+            if "collection" in fields:
+                fields_provided["collection"] = sqlalchemy.literal(only_collection_record.name)
+        elif not collections:
+            sql_builder.where_sql(sqlalchemy.literal(False))
+            if "collection" in fields:
+                fields_provided["collection"] = sqlalchemy.literal("NO COLLECTIONS")
+        else:
+            sql_builder.where_sql(collection_col.in_([collection.key for collection, _ in collections]))
+            if "collection" in fields:
+                # Avoid a join to the collection table to get the name by using
+                # a CASE statement. The SQL will be a bit more verbose but
+                # more efficient.
+                fields_provided["collection"] = sqlalchemy.case(
+                    {record.key: record.name for record, _ in collections}, value=collection_col
+                )
+        # Add more column definitions, starting with the data ID.
+        sql_builder.extract_dimensions(self.datasetType.dimensions.required.names)
+        # We can always get the dataset_id from the tags/calibs table, even if
+        # we could also get it from the 'static' dataset table.
+        if "dataset_id" in fields:
+            fields_provided["dataset_id"] = dataset_id_col
+
+        # It's possible we now have everything we need from just the
+        # tags/calibs table. The things we might need to get from the static
+        # dataset table are the run key and the ingest date.
+        need_static_table = False
+        if "run" in fields:
+            if len(collections) == 1 and run_collections_only:
+                # If we are searching exactly one RUN collection, we
+                # know that if we find the dataset in that collection,
+                # then that's the dataset's run; we don't need to
+                # query for it.
+                fields_provided["run"] = sqlalchemy.literal(only_collection_record.name)
+            elif run_collections_only:
+                # Once again we can avoid joining to the collection table by
+                # adding a CASE statement.
+                fields_provided["run"] = sqlalchemy.case(
+                    {record.key: record.name for record, _ in collections},
+                    value=self._static.dataset.c[self._runKeyColumn],
+                )
+                need_static_table = True
+            else:
+                # Here we can't avoid a join to the collection table, because
+                # we might find a dataset via something other than its RUN
+                # collection.
+                fields_provided["run"], sql_builder.sql_from_clause = self._collections.lookup_name_sql(
+                    self._static.dataset.c[self._runKeyColumn],
+                    sql_builder.sql_from_clause,
+                )
+                need_static_table = True
+        # Ingest date can only come from the static table.
+        if "ingest_date" in fields:
+            fields_provided["ingest_date"] = self._static.dataset.c.ingest_date
+            need_static_table = True
+        if need_static_table:
+            # If we need the static table, join it in via dataset_id.
+            sql_builder.sql_from_clause = sql_builder.sql_from_clause.join(
+                self._static.dataset, onclause=(dataset_id_col == self._static.dataset.c.id)
+            )
+            # Also constrain dataset_type_id in static table in case that helps
+            # generate a better plan. We could also include this in the JOIN ON
+            # clause, but that's probably only a good idea if it's part of the
+            # foreign key, and right now it isn't.
+            sql_builder.where_sql(self._static.dataset.c.dataset_type_id == self._dataset_type_id)
+        sql_builder.needs_distinct = (
+            # If there are multiple collections and we're searching any non-RUN
+            # collection, we could find the same dataset twice, which would
+            # yield duplicate rows unless "collection" or "rank" is there to
+            # make those rows unique.
+ len(collections) > 1 + and not run_collections_only + and ("collection_key" not in fields) + ) + return sql_builder + def getDataId(self, id: DatasetId) -> DataCoordinate: """Return DataId for a dataset. diff --git a/python/lsst/daf/butler/registry/dimensions/static.py b/python/lsst/daf/butler/registry/dimensions/static.py index 3a3903e855..ceda38b95a 100644 --- a/python/lsst/daf/butler/registry/dimensions/static.py +++ b/python/lsst/daf/butler/registry/dimensions/static.py @@ -30,17 +30,19 @@ import itertools import logging from collections import defaultdict -from collections.abc import Mapping, Sequence, Set +from collections.abc import Iterable, Mapping, Sequence, Set from typing import TYPE_CHECKING, Any import sqlalchemy from lsst.daf.relation import Calculation, ColumnExpression, Join, Relation, sql +from lsst.sphgeom import Region from ... import ddl from ..._column_tags import DimensionKeyColumnTag, DimensionRecordColumnTag from ..._column_type_info import LogicalColumn from ..._named import NamedKeyDict from ...dimensions import ( + DatabaseDimensionElement, DatabaseTopologicalFamily, DataCoordinate, Dimension, @@ -53,11 +55,15 @@ addDimensionForeignKey, ) from ...dimensions.record_cache import DimensionRecordCache +from ...direct_query_driver import Postprocessing, SqlBuilder # Future query system (direct,server). +from ...queries import tree as qt # Future query system (direct,client,server) +from ...queries.overlaps import OverlapsVisitor +from ...queries.visitors import PredicateVisitFlags from .._exceptions import MissingSpatialOverlapError from ..interfaces import Database, DimensionRecordStorageManager, StaticTablesContext, VersionTuple if TYPE_CHECKING: - from .. import queries + from .. import queries # Current Registry.query* system. # This has to be updated on every schema change @@ -426,6 +432,38 @@ def make_spatial_join_relation( ) return overlaps, needs_refinement + def make_sql_builder(self, element: DimensionElement, fields: Set[str]) -> SqlBuilder: + if element.implied_union_target is not None: + assert not fields, "Dimensions with implied-union storage never have fields." 
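+            # Dimensions with implied-union storage (e.g. ``band``, whose
+            # values are the union of the ``band`` values of all
+            # ``physical_filter`` records) are handled by delegating to the
+            # union target and wrapping the result in a SELECT DISTINCT
+            # subquery that keeps only this dimension's key column.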
+ return self.make_sql_builder(element.implied_union_target, fields).subquery( + qt.ColumnSet(element.minimal_group).drop_implied_dimension_keys(), distinct=True + ) + if not element.has_own_table: + raise NotImplementedError(f"Cannot join dimension element {element} with no table.") + table = self._tables[element.name] + result = SqlBuilder(self._db, table) + for dimension_name, column_name in zip(element.required.names, element.schema.required.names): + result.dimension_keys[dimension_name].append(table.columns[column_name]) + result.extract_dimensions(element.implied.names) + for field in fields: + if field == "timespan": + result.timespans[element.name] = self._db.getTimespanRepresentation().from_columns( + table.columns + ) + else: + result.fields[element.name][field] = table.columns[field] + return result + + def process_query_overlaps( + self, + dimensions: DimensionGroup, + predicate: qt.Predicate, + join_operands: Iterable[DimensionGroup], + ) -> tuple[qt.Predicate, SqlBuilder, Postprocessing]: + overlaps_visitor = _CommonSkyPixMediatedOverlapsVisitor(self._db, dimensions, self._overlap_tables) + new_predicate = overlaps_visitor.run(predicate, join_operands) + return new_predicate, overlaps_visitor.sql_builder, overlaps_visitor.postprocessing + def _make_relation( self, element: DimensionElement, @@ -948,3 +986,159 @@ def load(self, key: int) -> DimensionGroup: self.refresh() graph = self._groupsByKey[key] return graph + + +class _CommonSkyPixMediatedOverlapsVisitor(OverlapsVisitor): + def __init__( + self, + db: Database, + dimensions: DimensionGroup, + overlap_tables: Mapping[str, tuple[sqlalchemy.Table, sqlalchemy.Table]], + ): + super().__init__(dimensions) + self.sql_builder: SqlBuilder = SqlBuilder(db) + self.postprocessing = Postprocessing() + self.common_skypix = dimensions.universe.commonSkyPix + self.overlap_tables: Mapping[str, tuple[sqlalchemy.Table, sqlalchemy.Table]] = overlap_tables + self.common_skypix_overlaps_done: set[DatabaseDimensionElement] = set() + + def visit_spatial_constraint( + self, + element: DimensionElement, + region: Region, + flags: PredicateVisitFlags, + ) -> qt.Predicate | None: + # Reject spatial constraints that are nested inside OR or NOT, because + # the postprocessing needed for those would be a lot harder. + if flags & PredicateVisitFlags.INVERTED or flags & PredicateVisitFlags.HAS_OR_SIBLINGS: + raise NotImplementedError( + "Spatial overlap constraints nested inside OR or NOT are not supported." + ) + # Delegate to super just because that's good practice with + # OverlapVisitor. + super().visit_spatial_constraint(element, region, flags) + match element: + case DatabaseDimensionElement(): + # If this is a database dimension element like tract, patch, or + # visit, we need to: + # - join in the common skypix overlap table for this element; + # - constrain the common skypix index to be inside the + # ranges that overlap the region as a SQL where clause; + # - add postprocessing to reject rows where the database + # dimension element's region doesn't actually overlap the + # region. + self.postprocessing.spatial_where_filtering.append((element, region)) + if self.common_skypix.name in self.dimensions: + # The common skypix dimension should be part of the query + # as a first-class dimension, so we can join in the overlap + # table directly, and fall through to the end of this + # function to construct a Predicate that will turn into the + # SQL WHERE clause we want. 
+ self._join_common_skypix_overlap(element) + skypix = self.common_skypix + else: + # We need to hide the common skypix dimension from the + # larger query, so we make a subquery out of the overlap + # table that embeds the SQL WHERE clause we want and then + # projects out that dimension (with SELECT DISTINCT, to + # avoid introducing duplicate rows into the larger query). + overlap_sql_builder = self._make_common_skypix_overlap_sql_builder(element) + sql_where_or: list[sqlalchemy.ColumnElement[bool]] = [] + sql_skypix_col = overlap_sql_builder.dimension_keys[self.common_skypix.name][0] + for begin, end in self.common_skypix.pixelization.envelope(region): + sql_where_or.append(sqlalchemy.and_(sql_skypix_col >= begin, sql_skypix_col < end)) + overlap_sql_builder.where_sql(sqlalchemy.or_(*sql_where_or)) + self.sql_builder = self.sql_builder.join( + overlap_sql_builder.subquery( + qt.ColumnSet(element.minimal_group).drop_implied_dimension_keys(), distinct=True + ) + ) + # Short circuit here since the SQL WHERE clause has already + # been embedded in the subquery. + return qt.Predicate.from_bool(True) + case SkyPixDimension(): + # If this is a skypix dimension, we can do a index-in-ranges + # test directly on that dimension. Note that this doesn't on + # its own guarantee the skypix dimension column will be in the + # query; that'll be the job of the DirectQueryDriver to sort + # out (generally this will require a dataset using that skypix + # dimension to be joined in, unless this is the common skypix + # system). + assert ( + element.name in self.dimensions + ), "QueryTree guarantees dimensions are expanded when constraints are added." + skypix = element + case _: + raise NotImplementedError( + f"Spatial overlap constraint for dimension {element} not supported." + ) + # Convert the region-overlap constraint into a skypix + # index range-membership constraint in SQL... + result = qt.Predicate.from_bool(False) + skypix_col_ref = qt.DimensionKeyReference.model_construct(dimension=skypix) + for begin, end in skypix.pixelization.envelope(region): + result = result.logical_or(qt.Predicate.in_range(skypix_col_ref, start=begin, stop=end)) + return result + + def visit_spatial_join( + self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags + ) -> qt.Predicate | None: + # Reject spatial joins that are nested inside OR or NOT, because the + # postprocessing needed for those would be a lot harder. + if flags & PredicateVisitFlags.INVERTED or flags & PredicateVisitFlags.HAS_OR_SIBLINGS: + raise NotImplementedError("Spatial overlap joins nested inside OR or NOT are not supported.") + # Delegate to super to check for invalid joins and record this + # "connection" for use when seeing whether to add an automatic join + # later. + super().visit_spatial_join(a, b, flags) + match (a, b): + case (self.common_skypix, DatabaseDimensionElement() as b): + self._join_common_skypix_overlap(b) + case (DatabaseDimensionElement() as a, self.common_skypix): + self._join_common_skypix_overlap(a) + case (DatabaseDimensionElement() as a, DatabaseDimensionElement() as b): + if self.common_skypix.name in self.dimensions: + # We want the common skypix dimension to appear in the + # query as a first-class dimension, so just join in the + # two overlap tables directly. 
+                    self._join_common_skypix_overlap(a)
+                    self._join_common_skypix_overlap(b)
+                else:
+                    # We do not want the common skypix system to appear in the
+                    # query or cause duplicate rows, so we join the two overlap
+                    # tables in a subquery that projects out the common skypix
+                    # index column with SELECT DISTINCT.
+
+                    self.sql_builder = self.sql_builder.join(
+                        self._make_common_skypix_overlap_sql_builder(a)
+                        .join(self._make_common_skypix_overlap_sql_builder(b))
+                        .subquery(
+                            qt.ColumnSet(a.minimal_group | b.minimal_group).drop_implied_dimension_keys(),
+                            distinct=True,
+                        )
+                    )
+                # In both cases we add postprocessing to check that the regions
+                # really do overlap, since overlapping the same common skypix
+                # tile is necessary but not sufficient for that.
+                self.postprocessing.spatial_join_filtering.append((a, b))
+            case _:
+                raise NotImplementedError(f"Unsupported combination for spatial join: {a, b}.")
+        return qt.Predicate.from_bool(True)
+
+    def _join_common_skypix_overlap(self, element: DatabaseDimensionElement) -> None:
+        if element not in self.common_skypix_overlaps_done:
+            self.sql_builder = self.sql_builder.join(self._make_common_skypix_overlap_sql_builder(element))
+            self.common_skypix_overlaps_done.add(element)
+
+    def _make_common_skypix_overlap_sql_builder(self, element: DatabaseDimensionElement) -> SqlBuilder:
+        _, overlap_table = self.overlap_tables[element.name]
+        return self.sql_builder.join(
+            SqlBuilder(self.sql_builder.db, overlap_table)
+            .extract_dimensions(element.required.names, skypix_index=self.common_skypix.name)
+            .where_sql(
+                sqlalchemy.and_(
+                    overlap_table.c.skypix_system == self.common_skypix.system.name,
+                    overlap_table.c.skypix_level == self.common_skypix.level,
+                )
+            )
+        )
diff --git a/python/lsst/daf/butler/registry/interfaces/_collections.py b/python/lsst/daf/butler/registry/interfaces/_collections.py
index cef7b9741f..418d264460 100644
--- a/python/lsst/daf/butler/registry/interfaces/_collections.py
+++ b/python/lsst/daf/butler/registry/interfaces/_collections.py
@@ -39,6 +39,8 @@
 from collections.abc import Iterable, Set
 from typing import TYPE_CHECKING, Any, Generic, Self, TypeVar

+import sqlalchemy
+
 from ..._timespan import Timespan
 from .._collection_type import CollectionType
 from ..wildcards import CollectionWildcard
@@ -621,3 +623,27 @@ def update_chain(
             `~CollectionType.CHAINED` collections in ``children`` first.
         """
         raise NotImplementedError()
+
+    @abstractmethod
+    def lookup_name_sql(
+        self, sql_key: sqlalchemy.ColumnElement[_Key], sql_from_clause: sqlalchemy.FromClause
+    ) -> tuple[sqlalchemy.ColumnElement[str], sqlalchemy.FromClause]:
+        """Return a SQLAlchemy column and FROM clause that enable a query
+        to look up a collection name from the key.
+
+        Parameters
+        ----------
+        sql_key : `sqlalchemy.ColumnElement`
+            SQL column expression that evaluates to the collection key.
+        sql_from_clause : `sqlalchemy.FromClause`
+            SQL FROM clause from which ``sql_key`` was obtained.
+
+        Returns
+        -------
+        sql_name : `sqlalchemy.ColumnElement` [ `str` ]
+            SQL column expression that evaluates to the collection name.
+        sql_from_clause : `sqlalchemy.FromClause`
+            SQL FROM clause that includes the given ``sql_from_clause`` and
+            any table needed to provide ``sql_name``.
+ """ + raise NotImplementedError() diff --git a/python/lsst/daf/butler/registry/interfaces/_datasets.py b/python/lsst/daf/butler/registry/interfaces/_datasets.py index abc85d4a05..6c401a9e47 100644 --- a/python/lsst/daf/butler/registry/interfaces/_datasets.py +++ b/python/lsst/daf/butler/registry/interfaces/_datasets.py @@ -32,8 +32,8 @@ __all__ = ("DatasetRecordStorageManager", "DatasetRecordStorage") from abc import ABC, abstractmethod -from collections.abc import Iterable, Iterator, Mapping, Set -from typing import TYPE_CHECKING, Any +from collections.abc import Iterable, Iterator, Mapping, Sequence, Set +from typing import TYPE_CHECKING, Any, Literal from lsst.daf.relation import Relation @@ -45,9 +45,11 @@ from ._versioning import VersionedExtension, VersionTuple if TYPE_CHECKING: + from ...direct_query_driver import SqlBuilder # new query system, server+direct only + from ...queries import tree as qt # new query system, both clients + server from .._caching_context import CachingContext from .._collection_summary import CollectionSummary - from ..queries import SqlQueryContext + from ..queries import SqlQueryContext # old registry query system from ._collections import CollectionManager, CollectionRecord, RunRecord from ._database import Database, StaticTablesContext from ._dimensions import DimensionRecordStorageManager @@ -311,6 +313,14 @@ def make_relation( """ raise NotImplementedError() + @abstractmethod + def make_sql_builder( + self, + collections: Sequence[CollectionRecord], + fields: Set[qt.DatasetFieldName | Literal["collection_key"]], + ) -> SqlBuilder: + raise NotImplementedError() + datasetType: DatasetType """Dataset type whose records this object manages (`DatasetType`). """ diff --git a/python/lsst/daf/butler/registry/interfaces/_dimensions.py b/python/lsst/daf/butler/registry/interfaces/_dimensions.py index c14e19ca29..9f6aff6b6d 100644 --- a/python/lsst/daf/butler/registry/interfaces/_dimensions.py +++ b/python/lsst/daf/butler/registry/interfaces/_dimensions.py @@ -29,7 +29,7 @@ __all__ = ("DimensionRecordStorageManager",) from abc import abstractmethod -from collections.abc import Set +from collections.abc import Iterable, Set from typing import TYPE_CHECKING, Any from lsst.daf.relation import Join, Relation @@ -46,7 +46,9 @@ from ._versioning import VersionedExtension, VersionTuple if TYPE_CHECKING: - from .. import queries + from ...direct_query_driver import Postprocessing, SqlBuilder # Future query system (direct,server). + from ...queries.tree import Predicate # Future query system (direct,client,server). + from .. import queries # Old Registry.query* system. from ._database import Database, StaticTablesContext @@ -357,6 +359,19 @@ def make_spatial_join_relation( """ raise NotImplementedError() + @abstractmethod + def make_sql_builder(self, element: DimensionElement, fields: Set[str]) -> SqlBuilder: + raise NotImplementedError() + + @abstractmethod + def process_query_overlaps( + self, + dimensions: DimensionGroup, + predicate: Predicate, + join_operands: Iterable[DimensionGroup], + ) -> tuple[Predicate, SqlBuilder, Postprocessing]: + raise NotImplementedError() + universe: DimensionUniverse """Universe of all dimensions and dimension elements known to the `Registry` (`DimensionUniverse`).