diff --git a/doc/changes/DM-47947.bugfix.md b/doc/changes/DM-47947.bugfix.md
new file mode 100644
index 0000000000..f6b436dd60
--- /dev/null
+++ b/doc/changes/DM-47947.bugfix.md
@@ -0,0 +1 @@
+Fixed a bug in which projected spatial-join queries (particularly those where the dimensions of the actual regions being compared are not in the query result rows) could return additional records where there was actually no overlap.
diff --git a/python/lsst/daf/butler/direct_query_driver/_driver.py b/python/lsst/daf/butler/direct_query_driver/_driver.py
index b563898a4f..1121327360 100644
--- a/python/lsst/daf/butler/direct_query_driver/_driver.py
+++ b/python/lsst/daf/butler/direct_query_driver/_driver.py
@@ -45,7 +45,7 @@
 from .._collection_type import CollectionType
 from .._dataset_type import DatasetType
 from .._exceptions import InvalidQueryError
-from ..dimensions import DataCoordinate, DataIdValue, DimensionGroup, DimensionUniverse
+from ..dimensions import DataCoordinate, DataIdValue, DimensionElement, DimensionGroup, DimensionUniverse
 from ..dimensions.record_cache import DimensionRecordCache
 from ..queries import tree as qt
 from ..queries.driver import (
@@ -388,6 +388,7 @@ def count(
         select_builder = builder.finish_nested()
         # Replace the columns of the query with just COUNT(*).
         select_builder.columns = qt.ColumnSet(self._universe.empty)
+        select_builder.joins.special.clear()
         count_func: sqlalchemy.ColumnElement[int] = sqlalchemy.func.count()
         select_builder.joins.special["_ROWCOUNT"] = count_func
         # Render and run the query.
@@ -655,6 +656,9 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> QueryTreeAnalysis:
             # it here.
             postprocessing.spatial_join_filtering.extend(m_state.postprocessing.spatial_join_filtering)
             postprocessing.spatial_where_filtering.extend(m_state.postprocessing.spatial_where_filtering)
+            postprocessing.spatial_expression_filtering.extend(
+                m_state.postprocessing.spatial_expression_filtering
+            )
         # Add data coordinate uploads.
         joins.data_coordinate_uploads.update(tree.data_coordinate_uploads)
         # Add dataset_searches and filter out collections that don't have the
@@ -849,6 +853,48 @@ def apply_missing_dimension_joins(
             joins_analysis.predicate.visit(SqlColumnVisitor(select_builder.joins, self))
         )
 
+    def project_spatial_join_filtering(
+        self,
+        columns: qt.ColumnSet,
+        postprocessing: Postprocessing,
+        select_builders: Iterable[SqlSelectBuilder],
+    ) -> None:
+        """Transform spatial join postprocessing into expressions that can be
+        OR'd together via an aggregate function in a GROUP BY.
+
+        This only affects spatial join constraints involving region columns
+        whose dimensions are being projected away.
+
+        Parameters
+        ----------
+        columns : `.queries.tree.ColumnSet`
+            Columns that will be included in the final query.
+        postprocessing : `Postprocessing`
+            Object that describes post-query processing; modified in place.
+        select_builders : `~collections.abc.Iterable` [ `SqlSelectBuilder` ]
+            SQL builder objects to be modified in place.
+        """
+        kept: list[tuple[DimensionElement, DimensionElement]] = []
+        for a, b in postprocessing.spatial_join_filtering:
+            if a.name not in columns.dimensions.elements or b.name not in columns.dimensions.elements:
+                expr_name = f"_{a}_OVERLAPS_{b}"
+                postprocessing.spatial_expression_filtering.append(expr_name)
+                for select_builder in select_builders:
+                    expr = sqlalchemy.cast(
+                        sqlalchemy.cast(
+                            select_builder.joins.fields[a.name]["region"], type_=sqlalchemy.String
+                        )
+                        + sqlalchemy.literal("&", type_=sqlalchemy.String)
+                        + sqlalchemy.cast(
+                            select_builder.joins.fields[b.name]["region"], type_=sqlalchemy.String
+                        ),
+                        type_=sqlalchemy.LargeBinary,
+                    )
+                    select_builder.joins.special[expr_name] = expr
+            else:
+                kept.append((a, b))
+        postprocessing.spatial_join_filtering = kept
+
     def apply_query_projection(
         self,
         select_builder: SqlSelectBuilder,
         postprocessing: Postprocessing,
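To make the packing step above concrete, here is a standalone sketch (not butler code) of the SQL expression that `project_spatial_join_filtering` attaches to each select builder: the two base64-encoded region blobs are cast to strings, concatenated with an `&` separator, and cast back to a binary value. The table, column names, and blob contents below are made up for illustration.

```python
import sqlalchemy

metadata = sqlalchemy.MetaData()
overlaps = sqlalchemy.Table(
    "overlaps",  # hypothetical table standing in for the query's join tree
    metadata,
    sqlalchemy.Column("visit_detector_region", sqlalchemy.LargeBinary),
    sqlalchemy.Column("patch_region", sqlalchemy.LargeBinary),
)
engine = sqlalchemy.create_engine("sqlite://")
metadata.create_all(engine)

# Pack the two encoded regions into a single "<a>&<b>" value, mirroring the
# cast-to-string concatenation built by project_spatial_join_filtering.
packed = sqlalchemy.cast(
    sqlalchemy.cast(overlaps.c.visit_detector_region, type_=sqlalchemy.String)
    + sqlalchemy.literal("&", type_=sqlalchemy.String)
    + sqlalchemy.cast(overlaps.c.patch_region, type_=sqlalchemy.String),
    type_=sqlalchemy.LargeBinary,
)

with engine.begin() as connection:
    connection.execute(
        overlaps.insert(),
        [{"visit_detector_region": b"UkVHSU9OQQ==", "patch_region": b"UkVHSU9OQg=="}],
    )
    print(connection.execute(sqlalchemy.select(packed)).scalar())
    # b'UkVHSU9OQQ==&UkVHSU9OQg==' -- one packed pair per row, ready to be
    # OR'd together by an aggregate in the projection and decoded later.
```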
+ """ + kept: list[tuple[DimensionElement, DimensionElement]] = [] + for a, b in postprocessing.spatial_join_filtering: + if a.name not in columns.dimensions.elements or b.name not in columns.dimensions.elements: + expr_name = f"_{a}_OVERLAPS_{b}" + postprocessing.spatial_expression_filtering.append(expr_name) + for select_builder in select_builders: + expr = sqlalchemy.cast( + sqlalchemy.cast( + select_builder.joins.fields[a.name]["region"], type_=sqlalchemy.String + ) + + sqlalchemy.literal("&", type_=sqlalchemy.String) + + sqlalchemy.cast( + select_builder.joins.fields[b.name]["region"], type_=sqlalchemy.String + ), + type_=sqlalchemy.LargeBinary, + ) + select_builder.joins.special[expr_name] = expr + else: + kept.append((a, b)) + postprocessing.spatial_join_filtering = kept + def apply_query_projection( self, select_builder: SqlSelectBuilder, @@ -938,8 +984,8 @@ def apply_query_projection( # the data IDs for those regions are not wholly included in the # results (i.e. we need to postprocess on # visit_detector_region.region, but the output rows don't have - # detector, just visit - so we compute the union of the - # visit_detector region over all matched detectors). + # detector, just visit - so we pack the overlap expression into a + # blob via an aggregate function and interpret it later). if postprocessing.check_validity_match_count: if needs_validity_match_count: select_builder.joins.special[postprocessing.VALIDITY_MATCH_COUNT] = ( @@ -960,11 +1006,27 @@ def apply_query_projection( # might be collapsing the dimensions of the postprocessing # regions. When that happens, we want to apply an aggregate # function to them that computes the union of the regions that - # are grouped together. + # are grouped together. Note that this should only happen for + # constraints that involve a "given", external-to-the-database + # region (postprocessing.spatial_where_filtering); join + # constraints that need aggregates should have already been + # transformed in advance. select_builder.joins.fields[element.name]["region"] = ddl.Base64Region.union_aggregate( select_builder.joins.fields[element.name]["region"] ) have_aggregates = True + # Postprocessing spatial join constraints where at least one region's + # dimensions are being projected away will have already been turned + # into the kind of expression that sphgeom.Region.decodeOverlapsBase64 + # processes. We can just apply an aggregate function to these. Note + # that we don't do this to other constraints in order to minimize + # duplicate fetches of the same region blob. + for expr_name in postprocessing.spatial_expression_filtering: + select_builder.joins.special[expr_name] = sqlalchemy.cast( + sqlalchemy.func.aggregate_strings(select_builder.joins.special[expr_name], "|"), + type_=sqlalchemy.LargeBinary, + ) + have_aggregates = True # All dimension record fields are derived fields. 
diff --git a/python/lsst/daf/butler/direct_query_driver/_postprocessing.py b/python/lsst/daf/butler/direct_query_driver/_postprocessing.py
index 244b11be6c..db577ec6a9 100644
--- a/python/lsst/daf/butler/direct_query_driver/_postprocessing.py
+++ b/python/lsst/daf/butler/direct_query_driver/_postprocessing.py
@@ -59,6 +59,7 @@ class Postprocessing:
     def __init__(self) -> None:
         self.spatial_join_filtering = []
         self.spatial_where_filtering = []
+        self.spatial_expression_filtering = []
        self.check_validity_match_count: bool = False
         self._limit: int | None = None
 
@@ -79,6 +80,12 @@ def __init__(self) -> None:
     non-overlap pair will be filtered out.
     """
 
+    spatial_expression_filtering: list[str]
+    """The names of calculated columns that can be parsed by
+    `lsst.sphgeom.Region.decodeOverlapsBase64` into a `bool` or `None` that
+    indicates whether regions definitely overlap.
+    """
+
     check_validity_match_count: bool
     """If `True`, result rows will include a special column that counts the
     number of matching datasets in each collection for each data ID, and
@@ -104,7 +111,9 @@ def limit(self, value: int | None) -> None:
         self._limit = value
 
     def __bool__(self) -> bool:
-        return bool(self.spatial_join_filtering or self.spatial_where_filtering)
+        return bool(
+            self.spatial_join_filtering or self.spatial_where_filtering or self.spatial_expression_filtering
+        )
 
     def gather_columns_required(self, columns: qt.ColumnSet) -> None:
         """Add all columns required to perform postprocessing to the given
@@ -198,8 +207,10 @@ def apply(self, rows: Iterable[sqlalchemy.Row]) -> Iterable[sqlalchemy.Row]:
         for row in rows:
             m = row._mapping
             # Skip rows where at least one couple of regions do not overlap.
-            if any(m[a].overlaps(m[b]) is False for a, b in joins) or any(
-                m[field].overlaps(region) is False for field, region in where
+            if (
+                any(Region.decodeOverlapsBase64(m[c]) is False for c in self.spatial_expression_filtering)
+                or any(m[a].overlaps(m[b]) is False for a, b in joins)
+                or any(m[field].overlaps(region) is False for field, region in where)
             ):
                 continue
             if self.check_validity_match_count and m[self.VALIDITY_MATCH_COUNT] > 1:
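During postprocessing, each of these packed columns is decoded and a row survives unless decoding proves that none of its packed region pairs overlap. A minimal sketch of that filtering logic follows; `lsst.sphgeom.Region.decodeOverlapsBase64` is stubbed with a toy decoder so the example stands alone, and the column name and blobs are made up.

```python
def toy_decode_overlaps(blob: bytes) -> bool | None:
    """Stand-in for Region.decodeOverlapsBase64: pretend two regions overlap
    only when the two sides of an "&" pair are identical; None means the
    answer could not be determined.
    """
    if not blob:
        return None
    return any(a == b for a, b in (pair.split(b"&") for pair in blob.split(b"|")))


# One packed column per projected-away spatial join, as in Postprocessing.
spatial_expression_filtering = ["_visit_detector_region_OVERLAPS_patch"]
rows = [
    {"detector": 1, "_visit_detector_region_OVERLAPS_patch": b"R1&R1|R2&R3"},
    {"detector": 2, "_visit_detector_region_OVERLAPS_patch": b"R4&R5"},
]

# Mirror of the filter in Postprocessing.apply(): drop a row only when the
# packed expression proves that no pair of regions overlaps.
kept = [
    row
    for row in rows
    if not any(toy_decode_overlaps(row[c]) is False for c in spatial_expression_filtering)
]
print([row["detector"] for row in kept])  # [1]; detector 2 had no overlapping pair
```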
diff --git a/python/lsst/daf/butler/direct_query_driver/_query_builder.py b/python/lsst/daf/butler/direct_query_driver/_query_builder.py
index 3a2e8346f9..795c71916b 100644
--- a/python/lsst/daf/butler/direct_query_driver/_query_builder.py
+++ b/python/lsst/daf/butler/direct_query_driver/_query_builder.py
@@ -35,6 +35,7 @@
 )
 
 import dataclasses
+import itertools
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Set
 from typing import TYPE_CHECKING, Literal, TypeVar, overload
@@ -382,6 +383,9 @@ def apply_joins(self, driver: DirectQueryDriver) -> None:
 
     def apply_projection(self, driver: DirectQueryDriver, order_by: Iterable[qt.OrderExpression]) -> None:
         # Docstring inherited.
+        driver.project_spatial_join_filtering(
+            self.projection_columns, self.postprocessing, [self._select_builder]
+        )
         driver.apply_query_projection(
             self._select_builder,
             self.postprocessing,
@@ -635,6 +639,11 @@ def apply_joins(self, driver: DirectQueryDriver) -> None:
 
     def apply_projection(self, driver: DirectQueryDriver, order_by: Iterable[qt.OrderExpression]) -> None:
         # Docstring inherited.
+        driver.project_spatial_join_filtering(
+            self.projection_columns,
+            self.postprocessing,
+            itertools.chain.from_iterable(union_term.select_builders for union_term in self.union_terms),
+        )
         for union_term in self.union_terms:
             for builder in union_term.select_builders:
                 driver.apply_query_projection(
diff --git a/python/lsst/daf/butler/direct_query_driver/_sql_builders.py b/python/lsst/daf/butler/direct_query_driver/_sql_builders.py
index ca62ed47d4..1a537d5d28 100644
--- a/python/lsst/daf/butler/direct_query_driver/_sql_builders.py
+++ b/python/lsst/daf/butler/direct_query_driver/_sql_builders.py
@@ -422,6 +422,8 @@ def extract_columns(
                 self.fields[element.name]["region"] = column_collection[
                     self.db.name_shrinker.shrink(columns.get_qualified_name(element.name, "region"))
                 ]
+        for name in postprocessing.spatial_expression_filtering:
+            self.special[name] = column_collection[name]
         if postprocessing.check_validity_match_count:
             self.special[postprocessing.VALIDITY_MATCH_COUNT] = column_collection[
                 postprocessing.VALIDITY_MATCH_COUNT
@@ -670,6 +672,8 @@ def make_table_spec(
                     db.name_shrinker.shrink(columns.get_qualified_name(element.name, "region"))
                 )
             )
+    for name in postprocessing.spatial_expression_filtering:
+        results.fields.add(ddl.FieldSpec(name, dtype=sqlalchemy.types.LargeBinary, nullable=True))
     if not results.fields:
         results.fields.add(
             ddl.FieldSpec(name=SqlSelectBuilder.EMPTY_COLUMNS_NAME, dtype=SqlSelectBuilder.EMPTY_COLUMNS_TYPE)
diff --git a/python/lsst/daf/butler/tests/butler_queries.py b/python/lsst/daf/butler/tests/butler_queries.py
index 050816848e..0c74b35a4a 100644
--- a/python/lsst/daf/butler/tests/butler_queries.py
+++ b/python/lsst/daf/butler/tests/butler_queries.py
@@ -759,6 +759,18 @@ def test_spatial_overlaps(self) -> None:
             [1, 2, 3],
             has_postprocessing=True,
         )
+        # Same as above, but with a materialization.
+        self.check_detector_records(
+            query.where(
+                _x.visit_detector_region.region.overlaps(_x.patch.region),
+                tract=0,
+                patch=4,
+            )
+            .materialize()
+            .dimension_records("detector"),
+            [1, 2, 3],
+            has_postprocessing=True,
+        )
         # Query for that patch's region and express the previous query as
         # a region-constraint instead of a spatial join.
         (patch_record,) = query.where(tract=0, patch=4).dimension_records("patch")
@@ -777,6 +789,21 @@ def test_spatial_overlaps(self) -> None:
             ),
             ids=[1, 2, 3],
         )
+        # Query for detectors where a patch/visit+detector overlap is
+        # satisfied, in the case where there are no rows with an overlap,
+        # but the union of the patch regions overlaps the union of the
+        # visit+detector regions.
+        self.check_detector_records(
+            query.where(
+                _x.visit_detector_region.region.overlaps(_x.patch.region),
+                _x.any(
+                    _x.all(_x.tract == 1, _x.visit == 1),
+                    _x.all(_x.tract == 0, _x.patch == 0, _x.visit == 2),
+                ),
+            ).dimension_records("detector"),
+            [],
+            has_postprocessing=True,
+        )
         # Combine postprocessing with order_by and limit.
         self.check_detector_records(
             query.where(
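For context, here is a hedged sketch of the public query API that exercises this code path (the repository path, skymap name, and data ID values are hypothetical): the spatial join compares `visit_detector_region` and `patch` regions, but only `detector` records are returned, so the regions' dimensions are projected away and the packed-overlap postprocessing above decides which rows survive.

```python
from lsst.daf.butler import Butler

butler = Butler("/path/to/repo")  # placeholder repository
with butler.query() as query:
    x = query.expression_factory
    records = list(
        query.where(
            x.visit_detector_region.region.overlaps(x.patch.region),
            x.skymap == "my_skymap",
            x.tract == 0,
            x.patch == 4,
        ).dimension_records("detector")
    )
# Before this fix, a projected query like this could also return detectors
# whose regions did not actually overlap the requested patch.
```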