From 5903c60979dc37c0767784ecc7abbc023c326551 Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Tue, 28 Jan 2025 14:29:57 -0800 Subject: [PATCH] Add Query option to suppress DISTINCT in skypix overlaps. This is a non-public API for now, solely for graph builder use. --- .../daf/butler/direct_query_driver/_driver.py | 20 +++++++++++++---- python/lsst/daf/butler/queries/_query.py | 22 ++++++++++++++++--- python/lsst/daf/butler/queries/driver.py | 4 ++++ .../lsst/daf/butler/queries/result_specs.py | 5 +++++ .../daf/butler/registry/dimensions/static.py | 10 ++++++--- .../butler/registry/interfaces/_dimensions.py | 4 ++++ .../daf/butler/remote_butler/_query_driver.py | 2 ++ .../server/handlers/_external_query.py | 1 + .../daf/butler/remote_butler/server_models.py | 1 + .../lsst/daf/butler/tests/butler_queries.py | 20 +++++++++++++++++ tests/test_query_interface.py | 1 + 11 files changed, 80 insertions(+), 10 deletions(-) diff --git a/python/lsst/daf/butler/direct_query_driver/_driver.py b/python/lsst/daf/butler/direct_query_driver/_driver.py index faf1586e17..486e85e8ea 100644 --- a/python/lsst/daf/butler/direct_query_driver/_driver.py +++ b/python/lsst/daf/butler/direct_query_driver/_driver.py @@ -217,6 +217,7 @@ def execute(self, result_spec: ResultSpec, tree: qt.QueryTree) -> Iterator[Resul final_columns=result_spec.get_result_columns(), order_by=result_spec.order_by, find_first_dataset=result_spec.find_first_dataset, + allow_duplicate_overlaps=result_spec.allow_duplicate_overlaps, ) sql_select, sql_columns = builder.finish_select() if result_spec.order_by: @@ -290,12 +291,15 @@ def materialize( tree: qt.QueryTree, dimensions: DimensionGroup, datasets: frozenset[str], + allow_duplicate_overlaps: bool = False, key: qt.MaterializationKey | None = None, ) -> qt.MaterializationKey: # Docstring inherited. 
if self._exit_stack is None: raise RuntimeError("QueryDriver context must be entered before 'materialize' is called.") - plan = self.build_query(tree, qt.ColumnSet(dimensions)) + plan = self.build_query( + tree, qt.ColumnSet(dimensions), allow_duplicate_overlaps=allow_duplicate_overlaps + ) # Current implementation ignores 'datasets' aside from remembering # them, because figuring out what to put in the temporary table for # them is tricky, especially if calibration collections are involved. @@ -403,7 +407,7 @@ def count( def any(self, tree: qt.QueryTree, *, execute: bool, exact: bool) -> bool: # Docstring inherited. - builder = self.build_query(tree, qt.ColumnSet(tree.dimensions)) + builder = self.build_query(tree, qt.ColumnSet(tree.dimensions), allow_duplicate_overlaps=True) if not all(d.collection_records for d in builder.joins_analysis.datasets.values()): return False if not execute: @@ -449,6 +453,7 @@ def build_query( order_by: Iterable[qt.OrderExpression] = (), find_first_dataset: str | qt.AnyDatasetType | None = None, analyze_only: bool = False, + allow_duplicate_overlaps: bool = False, ) -> QueryBuilder: """Convert a query description into a nearly-complete builder object for the SQL version of that query. @@ -472,6 +477,9 @@ def build_query( builder, but do not call methods that build its SQL form. This can be useful for obtaining diagnostic information about the query that would be generated. + allow_duplicate_overlaps : `bool`, optional + If set to `True` then query will be allowed to generate + non-distinct rows for spatial overlaps. Returns ------- @@ -544,7 +552,7 @@ def build_query( # SqlSelectBuilder and Postprocessing with spatial/temporal constraints # potentially transformed by the dimensions manager (but none of the # rest of the analysis reflected in that SqlSelectBuilder). 
- query_tree_analysis = self._analyze_query_tree(tree) + query_tree_analysis = self._analyze_query_tree(tree, allow_duplicate_overlaps) # The "projection" columns differ from the final columns by not # omitting any dimension keys (this keeps queries for different result # types more similar during construction), including any columns needed @@ -591,7 +599,7 @@ def build_query( builder.apply_find_first(self) return builder - def _analyze_query_tree(self, tree: qt.QueryTree) -> QueryTreeAnalysis: + def _analyze_query_tree(self, tree: qt.QueryTree, allow_duplicate_overlaps: bool) -> QueryTreeAnalysis: """Analyze a `.queries.tree.QueryTree` as the first step in building a SQL query. @@ -605,6 +613,9 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> QueryTreeAnalysis: tree_analysis : `QueryTreeAnalysis` Struct containing additional information need to build the joins stage of a query. + allow_duplicate_overlaps : `bool`, optional + If set to `True` then query will be allowed to generate + non-distinct rows for spatial overlaps. Notes ----- @@ -634,6 +645,7 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> QueryTreeAnalysis: tree.predicate, tree.get_joined_dimension_groups(), collection_analysis.calibration_dataset_types, + allow_duplicate_overlaps, ) # Extract the data ID implied by the predicate; we can use the governor # dimensions in that to constrain the collections we search for diff --git a/python/lsst/daf/butler/queries/_query.py b/python/lsst/daf/butler/queries/_query.py index 086ad7b3cf..0471b3efc9 100644 --- a/python/lsst/daf/butler/queries/_query.py +++ b/python/lsst/daf/butler/queries/_query.py @@ -106,6 +106,12 @@ def __init__(self, driver: QueryDriver, tree: QueryTree | None = None): tree = make_identity_query_tree(driver.universe) super().__init__(driver, tree) + # If ``_allow_duplicate_overlaps`` is set to `True` then query will be + # allowed to generate non-distinct rows for spatial overlaps. 
This is + # not a part of public API for now, to be used by graph builder as + # optimization. + self._allow_duplicate_overlaps: bool = False + @property def constraint_dataset_types(self) -> Set[str]: """The names of all dataset types joined into the query. @@ -218,7 +224,11 @@ def data_ids( dimensions = self._driver.universe.conform(dimensions) if not dimensions <= self._tree.dimensions: tree = tree.join_dimensions(dimensions) - result_spec = DataCoordinateResultSpec(dimensions=dimensions, include_dimension_records=False) + result_spec = DataCoordinateResultSpec( + dimensions=dimensions, + include_dimension_records=False, + allow_duplicate_overlaps=self._allow_duplicate_overlaps, + ) return DataCoordinateQueryResults(self._driver, tree, result_spec) def datasets( @@ -284,6 +294,7 @@ def datasets( storage_class_name=storage_class_name, include_dimension_records=False, find_first=find_first, + allow_duplicate_overlaps=self._allow_duplicate_overlaps, ) return DatasetRefQueryResults(self._driver, tree=query._tree, spec=spec) @@ -308,7 +319,9 @@ def dimension_records(self, element: str) -> DimensionRecordQueryResults: tree = self._tree if element not in tree.dimensions.elements: tree = tree.join_dimensions(self._driver.universe[element].minimal_group) - result_spec = DimensionRecordResultSpec(element=self._driver.universe[element]) + result_spec = DimensionRecordResultSpec( + element=self._driver.universe[element], allow_duplicate_overlaps=self._allow_duplicate_overlaps + ) return DimensionRecordQueryResults(self._driver, tree, result_spec) def general( @@ -445,6 +458,7 @@ def general( dimension_fields=dimension_fields_dict, dataset_fields=dataset_fields_dict, find_first=find_first, + allow_duplicate_overlaps=self._allow_duplicate_overlaps, ) return GeneralQueryResults(self._driver, tree=tree, spec=result_spec) @@ -495,7 +509,9 @@ def materialize( dimensions = self._tree.dimensions else: dimensions = self._driver.universe.conform(dimensions) - key = 
self._driver.materialize(self._tree, dimensions, datasets) + key = self._driver.materialize( + self._tree, dimensions, datasets, allow_duplicate_overlaps=self._allow_duplicate_overlaps + ) tree = make_identity_query_tree(self._driver.universe).join_materialization( key, dimensions=dimensions ) diff --git a/python/lsst/daf/butler/queries/driver.py b/python/lsst/daf/butler/queries/driver.py index 22703c73c1..a27381a4a1 100644 --- a/python/lsst/daf/butler/queries/driver.py +++ b/python/lsst/daf/butler/queries/driver.py @@ -209,6 +209,7 @@ def materialize( tree: QueryTree, dimensions: DimensionGroup, datasets: frozenset[str], + allow_duplicate_overlaps: bool = False, ) -> MaterializationKey: """Execute a query tree, saving results to temporary storage for use in later queries. @@ -222,6 +223,9 @@ def materialize( datasets : `frozenset` [ `str` ] Names of dataset types whose ID columns may be materialized. It is implementation-defined whether they actually are. + allow_duplicate_overlaps : `bool`, optional + If set to `True` then query will be allowed to generate + non-distinct rows for spatial overlaps. Returns ------- diff --git a/python/lsst/daf/butler/queries/result_specs.py b/python/lsst/daf/butler/queries/result_specs.py index baf131d865..462fc12bf5 100644 --- a/python/lsst/daf/butler/queries/result_specs.py +++ b/python/lsst/daf/butler/queries/result_specs.py @@ -62,6 +62,11 @@ class ResultSpecBase(pydantic.BaseModel, ABC): limit: int | None = None """Maximum number of rows to return, or `None` for no bound.""" + allow_duplicate_overlaps: bool = False + """If set to `True` the queries are allowed to return duplicate rows for + spatial overlaps. + """ + def validate_tree(self, tree: QueryTree) -> None: """Check that this result object is consistent with a query tree. 
diff --git a/python/lsst/daf/butler/registry/dimensions/static.py b/python/lsst/daf/butler/registry/dimensions/static.py index e18c1cb94b..bfd6cbcf2e 100644 --- a/python/lsst/daf/butler/registry/dimensions/static.py +++ b/python/lsst/daf/butler/registry/dimensions/static.py @@ -487,9 +487,10 @@ def process_query_overlaps( predicate: qt.Predicate, join_operands: Iterable[DimensionGroup], calibration_dataset_types: Set[str | qt.AnyDatasetType], + allow_duplicates: bool = False, ) -> tuple[qt.Predicate, SqlSelectBuilder, Postprocessing]: overlaps_visitor = _CommonSkyPixMediatedOverlapsVisitor( - self._db, dimensions, calibration_dataset_types, self._overlap_tables + self._db, dimensions, calibration_dataset_types, self._overlap_tables, allow_duplicates ) new_predicate = overlaps_visitor.run(predicate, join_operands) return new_predicate, overlaps_visitor.builder, overlaps_visitor.postprocessing @@ -1025,6 +1026,7 @@ def __init__( dimensions: DimensionGroup, calibration_dataset_types: Set[str | qt.AnyDatasetType], overlap_tables: Mapping[str, tuple[sqlalchemy.Table, sqlalchemy.Table]], + allow_duplicates: bool, ): super().__init__(dimensions, calibration_dataset_types) self.builder: SqlSelectBuilder = SqlJoinsBuilder(db=db).to_select_builder(qt.ColumnSet(dimensions)) @@ -1032,6 +1034,7 @@ def __init__( self.common_skypix = dimensions.universe.commonSkyPix self.overlap_tables: Mapping[str, tuple[sqlalchemy.Table, sqlalchemy.Table]] = overlap_tables self.common_skypix_overlaps_done: set[DatabaseDimensionElement] = set() + self.allow_duplicates = allow_duplicates def visit_spatial_constraint( self, @@ -1081,7 +1084,8 @@ def visit_spatial_constraint( joins_builder.where(sqlalchemy.or_(*sql_where_or)) self.builder.join( joins_builder.to_select_builder( - qt.ColumnSet(element.minimal_group).drop_implied_dimension_keys(), distinct=True + qt.ColumnSet(element.minimal_group).drop_implied_dimension_keys(), + distinct=not self.allow_duplicates, 
).into_joins_builder(postprocessing=None) ) # Short circuit here since the SQL WHERE clause has already @@ -1145,7 +1149,7 @@ def visit_spatial_join( .join(self._make_common_skypix_overlap_joins_builder(b)) .to_select_builder( qt.ColumnSet(a.minimal_group | b.minimal_group).drop_implied_dimension_keys(), - distinct=True, + distinct=not self.allow_duplicates, ) .into_joins_builder(postprocessing=None) ) diff --git a/python/lsst/daf/butler/registry/interfaces/_dimensions.py b/python/lsst/daf/butler/registry/interfaces/_dimensions.py index a6cffd5f2b..27450af8eb 100644 --- a/python/lsst/daf/butler/registry/interfaces/_dimensions.py +++ b/python/lsst/daf/butler/registry/interfaces/_dimensions.py @@ -402,6 +402,7 @@ def process_query_overlaps( predicate: Predicate, join_operands: Iterable[DimensionGroup], calibration_dataset_types: Set[str | AnyDatasetType], + allow_duplicates: bool = False, ) -> tuple[Predicate, SqlSelectBuilder, Postprocessing]: """Process a query's WHERE predicate and dimensions to handle spatial and temporal overlaps. @@ -424,6 +425,9 @@ def process_query_overlaps( `..queries.tree.AnyDatasetType` ] The names of dataset types that have been joined into the query via a search that includes at least one calibration collection. + allow_duplicates : `bool` + If set to `True` then query will be allowed to return non-distinct + rows. 
Returns ------- diff --git a/python/lsst/daf/butler/remote_butler/_query_driver.py b/python/lsst/daf/butler/remote_butler/_query_driver.py index 47cec7c248..0e6aa4d91b 100644 --- a/python/lsst/daf/butler/remote_butler/_query_driver.py +++ b/python/lsst/daf/butler/remote_butler/_query_driver.py @@ -163,6 +163,7 @@ def materialize( tree: QueryTree, dimensions: DimensionGroup, datasets: frozenset[str], + allow_duplicate_overlaps: bool = False, ) -> MaterializationKey: key = uuid4() self._stored_query_inputs.append( @@ -171,6 +172,7 @@ def materialize( tree=SerializedQueryTree(tree.model_copy(deep=True)), dimensions=dimensions.to_simple(), datasets=datasets, + allow_duplicate_overlaps=allow_duplicate_overlaps, ), ) return key diff --git a/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py b/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py index 159994446c..72f03856cd 100644 --- a/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py +++ b/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py @@ -195,6 +195,7 @@ def _get_query_context(factory: Factory, query: QueryInputs) -> Iterator[_QueryC DimensionGroup.from_simple(input.dimensions, butler.dimensions), frozenset(input.datasets), key=input.key, + allow_duplicate_overlaps=input.allow_duplicate_overlaps, ) elif input.type == "upload": driver.upload_data_coordinates( diff --git a/python/lsst/daf/butler/remote_butler/server_models.py b/python/lsst/daf/butler/remote_butler/server_models.py index 4f76e605bc..b04c935efc 100644 --- a/python/lsst/daf/butler/remote_butler/server_models.py +++ b/python/lsst/daf/butler/remote_butler/server_models.py @@ -243,6 +243,7 @@ class MaterializedQuery(pydantic.BaseModel): tree: SerializedQueryTree dimensions: SerializedDimensionGroup datasets: list[str] + allow_duplicate_overlaps: bool = False class DataCoordinateUpload(pydantic.BaseModel): diff --git a/python/lsst/daf/butler/tests/butler_queries.py 
b/python/lsst/daf/butler/tests/butler_queries.py index 08007e4d3a..5faaedfae7 100644 --- a/python/lsst/daf/butler/tests/butler_queries.py +++ b/python/lsst/daf/butler/tests/butler_queries.py @@ -690,6 +690,26 @@ def test_dataset_constrained_record_query(self) -> None: doomed=True, ) + def test_duplicate_overlaps(self) -> None: + """Test for query option that enables duplicate rows in queries that + use skypix overlaps. + """ + butler = self.make_butler("base.yaml", "spatial.yaml") + butler.registry.defaults = RegistryDefaults(instrument="Cam1", skymap="SkyMap1") + with butler.query() as query: + + data_ids = list(query.data_ids(["visit", "detector", "patch"]).where(visit=1, detector=1)) + self.assertCountEqual( + [(data_id["tract"], data_id["patch"]) for data_id in data_ids], [(0, 0), (0, 2)] + ) + + query._allow_duplicate_overlaps = True + data_ids = list(query.data_ids(["visit", "detector", "patch"]).where(visit=1, detector=1)) + self.assertCountEqual( + [(data_id["tract"], data_id["patch"]) for data_id in data_ids], + [(0, 0), (0, 0), (0, 2), (0, 2)], + ) + def test_spatial_overlaps(self) -> None: """Test queries for dimension records with spatial overlaps. diff --git a/tests/test_query_interface.py b/tests/test_query_interface.py index 7ca492fbcb..a9cb4eba30 100644 --- a/tests/test_query_interface.py +++ b/tests/test_query_interface.py @@ -369,6 +369,7 @@ def materialize( tree: qt.QueryTree, dimensions: DimensionGroup, datasets: frozenset[str], + allow_duplicate_overlaps: bool = False, ) -> qd.MaterializationKey: key = uuid.uuid4() self.materializations[key] = (tree, dimensions, datasets)