diff --git a/python/lsst/daf/butler/direct_query_driver/_driver.py b/python/lsst/daf/butler/direct_query_driver/_driver.py index b563898a4f..1121327360 100644 --- a/python/lsst/daf/butler/direct_query_driver/_driver.py +++ b/python/lsst/daf/butler/direct_query_driver/_driver.py @@ -45,7 +45,7 @@ from .._collection_type import CollectionType from .._dataset_type import DatasetType from .._exceptions import InvalidQueryError -from ..dimensions import DataCoordinate, DataIdValue, DimensionGroup, DimensionUniverse +from ..dimensions import DataCoordinate, DataIdValue, DimensionElement, DimensionGroup, DimensionUniverse from ..dimensions.record_cache import DimensionRecordCache from ..queries import tree as qt from ..queries.driver import ( @@ -388,6 +388,7 @@ def count( select_builder = builder.finish_nested() # Replace the columns of the query with just COUNT(*). select_builder.columns = qt.ColumnSet(self._universe.empty) + select_builder.joins.special.clear() count_func: sqlalchemy.ColumnElement[int] = sqlalchemy.func.count() select_builder.joins.special["_ROWCOUNT"] = count_func # Render and run the query. @@ -655,6 +656,9 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> QueryTreeAnalysis: # it here. postprocessing.spatial_join_filtering.extend(m_state.postprocessing.spatial_join_filtering) postprocessing.spatial_where_filtering.extend(m_state.postprocessing.spatial_where_filtering) + postprocessing.spatial_expression_filtering.extend( + m_state.postprocessing.spatial_expression_filtering + ) # Add data coordinate uploads. 
joins.data_coordinate_uploads.update(tree.data_coordinate_uploads) # Add dataset_searches and filter out collections that don't have the @@ -849,6 +853,52 @@ def apply_missing_dimension_joins( joins_analysis.predicate.visit(SqlColumnVisitor(select_builder.joins, self)) )
+    def project_spatial_join_filtering(
+        self,
+        columns: qt.ColumnSet,
+        postprocessing: Postprocessing,
+        select_builders: Iterable[SqlSelectBuilder],
+    ) -> None:
+        """Transform spatial join postprocessing into expressions that can be
+        OR'd together via an aggregate function in a GROUP BY.
+
+        This only affects spatial join constraints involving region columns
+        whose dimensions are being projected away.
+
+        Parameters
+        ----------
+        columns : `.queries.tree.ColumnSet`
+            Columns that will be included in the final query.
+        postprocessing : `Postprocessing`
+            Object that describes post-query processing; modified in place.
+        select_builders : `~collections.abc.Iterable` [ `SqlSelectBuilder` ]
+            SQL Builder objects to be modified in place.
+            Single-pass iterators are accepted; they are copied to a list.
+        """
+        # Every transformed constraint must be added to every builder, so a
+        # one-shot iterator (e.g. `itertools.chain`) has to be materialized.
+        select_builders = list(select_builders)
+        kept: list[tuple[DimensionElement, DimensionElement]] = []
+        for a, b in postprocessing.spatial_join_filtering:
+            if a.name not in columns.dimensions.elements or b.name not in columns.dimensions.elements:
+                expr_name = f"_{a}_OVERLAPS_{b}"
+                postprocessing.spatial_expression_filtering.append(expr_name)
+                for select_builder in select_builders:
+                    expr = sqlalchemy.cast(
+                        sqlalchemy.cast(
+                            select_builder.joins.fields[a.name]["region"], type_=sqlalchemy.String
+                        )
+                        + sqlalchemy.literal("&", type_=sqlalchemy.String)
+                        + sqlalchemy.cast(
+                            select_builder.joins.fields[b.name]["region"], type_=sqlalchemy.String
+                        ),
+                        type_=sqlalchemy.LargeBinary,
+                    )
+                    select_builder.joins.special[expr_name] = expr
+            else:
+                kept.append((a, b))
+        postprocessing.spatial_join_filtering = kept
+
 def apply_query_projection( self, select_builder: SqlSelectBuilder, @@ -938,8 +988,8 @@ def apply_query_projection( # the data IDs for those regions are not wholly included in the # results (i.e. we need to postprocess on # visit_detector_region.region, but the output rows don't have - # detector, just visit - so we compute the union of the - # visit_detector region over all matched detectors). + # detector, just visit - so we pack the overlap expression into a + # blob via an aggregate function and interpret it later). if postprocessing.check_validity_match_count: if needs_validity_match_count: select_builder.joins.special[postprocessing.VALIDITY_MATCH_COUNT] = ( @@ -960,11 +1010,27 @@ def apply_query_projection( # might be collapsing the dimensions of the postprocessing # regions. When that happens, we want to apply an aggregate # function to them that computes the union of the regions that - # are grouped together. + # are grouped together. Note that this should only happen for + # constraints that involve a "given", external-to-the-database + # region (postprocessing.spatial_where_filtering); join + # constraints that need aggregates should have already been + # transformed in advance.
select_builder.joins.fields[element.name]["region"] = ddl.Base64Region.union_aggregate( select_builder.joins.fields[element.name]["region"] ) have_aggregates = True + # Postprocessing spatial join constraints where at least one region's + # dimensions are being projected away will have already been turned + # into the kind of expression that sphgeom.Region.decodeOverlapsBase64 + # processes. We can just apply an aggregate function to these. Note + # that we don't do this to other constraints in order to minimize + # duplicate fetches of the same region blob. + for expr_name in postprocessing.spatial_expression_filtering: + select_builder.joins.special[expr_name] = sqlalchemy.cast( + sqlalchemy.func.aggregate_strings(select_builder.joins.special[expr_name], "|"), + type_=sqlalchemy.LargeBinary, + ) + have_aggregates = True # All dimension record fields are derived fields. for element_name, fields_for_element in projection_columns.dimension_fields.items(): diff --git a/python/lsst/daf/butler/direct_query_driver/_postprocessing.py b/python/lsst/daf/butler/direct_query_driver/_postprocessing.py index 244b11be6c..db577ec6a9 100644 --- a/python/lsst/daf/butler/direct_query_driver/_postprocessing.py +++ b/python/lsst/daf/butler/direct_query_driver/_postprocessing.py @@ -59,6 +59,7 @@ class Postprocessing: def __init__(self) -> None: self.spatial_join_filtering = [] self.spatial_where_filtering = [] + self.spatial_expression_filtering = [] self.check_validity_match_count: bool = False self._limit: int | None = None @@ -79,6 +80,12 @@ def __init__(self) -> None: non-overlap pair will be filtered out. """ + spatial_expression_filtering: list[str] + """The names of calculated columns that can be parsed by + `lsst.sphgeom.Region.decodeOverlapsBase64` into a `bool` or `None` that + indicates whether regions definitely overlap. 
+ """ + check_validity_match_count: bool """If `True`, result rows will include a special column that counts the number of matching datasets in each collection for each data ID, and @@ -104,7 +111,9 @@ def limit(self, value: int | None) -> None: self._limit = value def __bool__(self) -> bool: - return bool(self.spatial_join_filtering or self.spatial_where_filtering) + return bool( + self.spatial_join_filtering or self.spatial_where_filtering or self.spatial_expression_filtering + ) def gather_columns_required(self, columns: qt.ColumnSet) -> None: """Add all columns required to perform postprocessing to the given @@ -198,8 +207,10 @@ def apply(self, rows: Iterable[sqlalchemy.Row]) -> Iterable[sqlalchemy.Row]: for row in rows: m = row._mapping # Skip rows where at least one couple of regions do not overlap. - if any(m[a].overlaps(m[b]) is False for a, b in joins) or any( - m[field].overlaps(region) is False for field, region in where + if ( + any(Region.decodeOverlapsBase64(m[c]) is False for c in self.spatial_expression_filtering) + or any(m[a].overlaps(m[b]) is False for a, b in joins) + or any(m[field].overlaps(region) is False for field, region in where) ): continue if self.check_validity_match_count and m[self.VALIDITY_MATCH_COUNT] > 1: diff --git a/python/lsst/daf/butler/direct_query_driver/_query_builder.py b/python/lsst/daf/butler/direct_query_driver/_query_builder.py index 3a2e8346f9..795c71916b 100644 --- a/python/lsst/daf/butler/direct_query_driver/_query_builder.py +++ b/python/lsst/daf/butler/direct_query_driver/_query_builder.py @@ -35,6 +35,7 @@ ) import dataclasses +import itertools from abc import ABC, abstractmethod from collections.abc import Iterable, Set from typing import TYPE_CHECKING, Literal, TypeVar, overload @@ -382,6 +383,9 @@ def apply_joins(self, driver: DirectQueryDriver) -> None: def apply_projection(self, driver: DirectQueryDriver, order_by: Iterable[qt.OrderExpression]) -> None: # Docstring inherited. 
+ driver.project_spatial_join_filtering( + self.projection_columns, self.postprocessing, [self._select_builder] + ) driver.apply_query_projection( self._select_builder, self.postprocessing, @@ -635,6 +639,11 @@ def apply_joins(self, driver: DirectQueryDriver) -> None: def apply_projection(self, driver: DirectQueryDriver, order_by: Iterable[qt.OrderExpression]) -> None: # Docstring inherited. + driver.project_spatial_join_filtering( + self.projection_columns, + self.postprocessing, + itertools.chain.from_iterable(union_term.select_builders for union_term in self.union_terms), + ) for union_term in self.union_terms: for builder in union_term.select_builders: driver.apply_query_projection( diff --git a/python/lsst/daf/butler/direct_query_driver/_sql_builders.py b/python/lsst/daf/butler/direct_query_driver/_sql_builders.py index ca62ed47d4..1a537d5d28 100644 --- a/python/lsst/daf/butler/direct_query_driver/_sql_builders.py +++ b/python/lsst/daf/butler/direct_query_driver/_sql_builders.py @@ -422,6 +422,8 @@ def extract_columns( self.fields[element.name]["region"] = column_collection[ self.db.name_shrinker.shrink(columns.get_qualified_name(element.name, "region")) ] + for name in postprocessing.spatial_expression_filtering: + self.special[name] = column_collection[name] if postprocessing.check_validity_match_count: self.special[postprocessing.VALIDITY_MATCH_COUNT] = column_collection[ postprocessing.VALIDITY_MATCH_COUNT @@ -670,6 +672,8 @@ def make_table_spec( db.name_shrinker.shrink(columns.get_qualified_name(element.name, "region")) ) ) + for name in postprocessing.spatial_expression_filtering: + results.fields.add(ddl.FieldSpec(name, dtype=sqlalchemy.types.LargeBinary, nullable=True)) if not results.fields: results.fields.add( ddl.FieldSpec(name=SqlSelectBuilder.EMPTY_COLUMNS_NAME, dtype=SqlSelectBuilder.EMPTY_COLUMNS_TYPE) diff --git a/python/lsst/daf/butler/tests/butler_queries.py b/python/lsst/daf/butler/tests/butler_queries.py index 050816848e..56eaaca8c5 100644 
--- a/python/lsst/daf/butler/tests/butler_queries.py +++ b/python/lsst/daf/butler/tests/butler_queries.py @@ -759,6 +759,16 @@ def test_spatial_overlaps(self) -> None: [1, 2, 3], has_postprocessing=True, ) + # Same as above, but with a materialization. + self.check_detector_records( + query.where( + _x.visit_detector_region.region.overlaps(_x.patch.region), + tract=0, + patch=4, + ).materialize().dimension_records("detector"), + [1, 2, 3], + has_postprocessing=True, + ) # Query for that patch's region and express the previous query as # a region-constraint instead of a spatial join. (patch_record,) = query.where(tract=0, patch=4).dimension_records("patch") @@ -777,6 +787,21 @@ def test_spatial_overlaps(self) -> None: ), ids=[1, 2, 3], ) + # Query for detectors where a patch/visit+detector overlap is + # satisfied, in the case where there are no rows with an overlap, + # but the union of the patch regions overlaps the union of the + # visit+detector regions. + self.check_detector_records( + query.where( + _x.visit_detector_region.region.overlaps(_x.patch.region), + _x.any( + _x.all(_x.tract == 1, _x.visit == 1), + _x.all(_x.tract == 0, _x.patch == 0, _x.visit == 2), + ), + ).dimension_records("detector"), + [], + has_postprocessing=True, + ) # Combine postprocessing with order_by and limit. self.check_detector_records( query.where(