Skip to content

Commit

Permalink
Fix consistency in queries with aggregated spatial postprocessing.
Browse files Browse the repository at this point in the history
Instead of testing (a UNION b) OVERLAPS (c UNION d) in these queries,
we now correctly test (a OVERLAPS c) OR (b OVERLAPS d).
  • Loading branch information
TallJimbo committed Dec 10, 2024
1 parent fd56d7f commit 4a74141
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 7 deletions.
70 changes: 66 additions & 4 deletions python/lsst/daf/butler/direct_query_driver/_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from .._collection_type import CollectionType
from .._dataset_type import DatasetType
from .._exceptions import InvalidQueryError
from ..dimensions import DataCoordinate, DataIdValue, DimensionGroup, DimensionUniverse
from ..dimensions import DataCoordinate, DataIdValue, DimensionElement, DimensionGroup, DimensionUniverse
from ..dimensions.record_cache import DimensionRecordCache
from ..queries import tree as qt
from ..queries.driver import (
Expand Down Expand Up @@ -388,6 +388,7 @@ def count(
select_builder = builder.finish_nested()
# Replace the columns of the query with just COUNT(*).
select_builder.columns = qt.ColumnSet(self._universe.empty)
select_builder.joins.special.clear()
count_func: sqlalchemy.ColumnElement[int] = sqlalchemy.func.count()
select_builder.joins.special["_ROWCOUNT"] = count_func
# Render and run the query.
Expand Down Expand Up @@ -655,6 +656,9 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> QueryTreeAnalysis:
# it here.
postprocessing.spatial_join_filtering.extend(m_state.postprocessing.spatial_join_filtering)
postprocessing.spatial_where_filtering.extend(m_state.postprocessing.spatial_where_filtering)
postprocessing.spatial_expression_filtering.extend(
m_state.postprocessing.spatial_expression_filtering
)
# Add data coordinate uploads.
joins.data_coordinate_uploads.update(tree.data_coordinate_uploads)
# Add dataset_searches and filter out collections that don't have the
Expand Down Expand Up @@ -849,6 +853,48 @@ def apply_missing_dimension_joins(
joins_analysis.predicate.visit(SqlColumnVisitor(select_builder.joins, self))
)

def project_spatial_join_filtering(
self,
columns: qt.ColumnSet,
postprocessing: Postprocessing,
select_builders: Iterable[SqlSelectBuilder],
) -> None:
"""Transform spatial join postprocessing into expressions that can be
OR'd together via an aggregate function in a GROUP BY.
This only affects spatial join constraints involving region columns
whose dimensions are being projected away.
Parameters
----------
columns : `.queries.tree.ColumnSet`
Columns that will be included in the final query.
postprocessing : `Postprocessing`
Object that describes post-query processing; modified in place.
select_builders : `~collections.abc.Iterable` [ `SqlSelectBuilder` ]
SQL Builder objects to be modified in place.
"""
kept: list[tuple[DimensionElement, DimensionElement]] = []
for a, b in postprocessing.spatial_join_filtering:
if a.name not in columns.dimensions.elements or b.name not in columns.dimensions.elements:
expr_name = f"_{a}_OVERLAPS_{b}"
postprocessing.spatial_expression_filtering.append(expr_name)
for select_builder in select_builders:
expr = sqlalchemy.cast(
sqlalchemy.cast(
select_builder.joins.fields[a.name]["region"], type_=sqlalchemy.String
)
+ sqlalchemy.literal("&", type_=sqlalchemy.String)
+ sqlalchemy.cast(
select_builder.joins.fields[b.name]["region"], type_=sqlalchemy.String
),
type_=sqlalchemy.LargeBinary,
)
select_builder.joins.special[expr_name] = expr
else:
kept.append((a, b))
postprocessing.spatial_join_filtering = kept

def apply_query_projection(
self,
select_builder: SqlSelectBuilder,
Expand Down Expand Up @@ -938,8 +984,8 @@ def apply_query_projection(
# the data IDs for those regions are not wholly included in the
# results (i.e. we need to postprocess on
# visit_detector_region.region, but the output rows don't have
# detector, just visit - so we compute the union of the
# visit_detector region over all matched detectors).
# detector, just visit - so we pack the overlap expression into a
# blob via an aggregate function and interpret it later).
if postprocessing.check_validity_match_count:
if needs_validity_match_count:
select_builder.joins.special[postprocessing.VALIDITY_MATCH_COUNT] = (
Expand All @@ -960,11 +1006,27 @@ def apply_query_projection(
# might be collapsing the dimensions of the postprocessing
# regions. When that happens, we want to apply an aggregate
# function to them that computes the union of the regions that
# are grouped together.
# are grouped together. Note that this should only happen for
# constraints that involve a "given", external-to-the-database
# region (postprocessing.spatial_where_filtering); join
# constraints that need aggregates should have already been
# transformed in advance.
select_builder.joins.fields[element.name]["region"] = ddl.Base64Region.union_aggregate(
select_builder.joins.fields[element.name]["region"]
)
have_aggregates = True
# Postprocessing spatial join constraints where at least one region's
# dimensions are being projected away will have already been turned
# into the kind of expression that sphgeom.Region.decodeOverlapsBase64
# processes. We can just apply an aggregate function to these. Note
# that we don't do this to other constraints in order to minimize
# duplicate fetches of the same region blob.
for expr_name in postprocessing.spatial_expression_filtering:
select_builder.joins.special[expr_name] = sqlalchemy.cast(
sqlalchemy.func.aggregate_strings(select_builder.joins.special[expr_name], "|"),
type_=sqlalchemy.LargeBinary,
)
have_aggregates = True

# All dimension record fields are derived fields.
for element_name, fields_for_element in projection_columns.dimension_fields.items():
Expand Down
17 changes: 14 additions & 3 deletions python/lsst/daf/butler/direct_query_driver/_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class Postprocessing:
def __init__(self) -> None:
self.spatial_join_filtering = []
self.spatial_where_filtering = []
self.spatial_expression_filtering = []
self.check_validity_match_count: bool = False
self._limit: int | None = None

Expand All @@ -79,6 +80,12 @@ def __init__(self) -> None:
non-overlap pair will be filtered out.
"""

spatial_expression_filtering: list[str]
"""The names of calculated columns that can be parsed by
`lsst.sphgeom.Region.decodeOverlapsBase64` into a `bool` or `None` that
indicates whether regions definitely overlap.
"""

check_validity_match_count: bool
"""If `True`, result rows will include a special column that counts the
number of matching datasets in each collection for each data ID, and
Expand All @@ -104,7 +111,9 @@ def limit(self, value: int | None) -> None:
self._limit = value

def __bool__(self) -> bool:
return bool(self.spatial_join_filtering or self.spatial_where_filtering)
return bool(
self.spatial_join_filtering or self.spatial_where_filtering or self.spatial_expression_filtering
)

def gather_columns_required(self, columns: qt.ColumnSet) -> None:
"""Add all columns required to perform postprocessing to the given
Expand Down Expand Up @@ -198,8 +207,10 @@ def apply(self, rows: Iterable[sqlalchemy.Row]) -> Iterable[sqlalchemy.Row]:
for row in rows:
m = row._mapping
# Skip rows where at least one couple of regions do not overlap.
if any(m[a].overlaps(m[b]) is False for a, b in joins) or any(
m[field].overlaps(region) is False for field, region in where
if (
any(Region.decodeOverlapsBase64(m[c]) is False for c in self.spatial_expression_filtering)
or any(m[a].overlaps(m[b]) is False for a, b in joins)
or any(m[field].overlaps(region) is False for field, region in where)
):
continue
if self.check_validity_match_count and m[self.VALIDITY_MATCH_COUNT] > 1:
Expand Down
9 changes: 9 additions & 0 deletions python/lsst/daf/butler/direct_query_driver/_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
)

import dataclasses
import itertools
from abc import ABC, abstractmethod
from collections.abc import Iterable, Set
from typing import TYPE_CHECKING, Literal, TypeVar, overload
Expand Down Expand Up @@ -382,6 +383,9 @@ def apply_joins(self, driver: DirectQueryDriver) -> None:

def apply_projection(self, driver: DirectQueryDriver, order_by: Iterable[qt.OrderExpression]) -> None:
# Docstring inherited.
driver.project_spatial_join_filtering(
self.projection_columns, self.postprocessing, [self._select_builder]
)
driver.apply_query_projection(
self._select_builder,
self.postprocessing,
Expand Down Expand Up @@ -635,6 +639,11 @@ def apply_joins(self, driver: DirectQueryDriver) -> None:

def apply_projection(self, driver: DirectQueryDriver, order_by: Iterable[qt.OrderExpression]) -> None:
# Docstring inherited.
driver.project_spatial_join_filtering(
self.projection_columns,
self.postprocessing,
itertools.chain.from_iterable(union_term.select_builders for union_term in self.union_terms),
)
for union_term in self.union_terms:
for builder in union_term.select_builders:
driver.apply_query_projection(
Expand Down
4 changes: 4 additions & 0 deletions python/lsst/daf/butler/direct_query_driver/_sql_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,8 @@ def extract_columns(
self.fields[element.name]["region"] = column_collection[
self.db.name_shrinker.shrink(columns.get_qualified_name(element.name, "region"))
]
for name in postprocessing.spatial_expression_filtering:
self.special[name] = column_collection[name]
if postprocessing.check_validity_match_count:
self.special[postprocessing.VALIDITY_MATCH_COUNT] = column_collection[
postprocessing.VALIDITY_MATCH_COUNT
Expand Down Expand Up @@ -670,6 +672,8 @@ def make_table_spec(
db.name_shrinker.shrink(columns.get_qualified_name(element.name, "region"))
)
)
for name in postprocessing.spatial_expression_filtering:
results.fields.add(ddl.FieldSpec(name, dtype=sqlalchemy.types.LargeBinary, nullable=True))
if not results.fields:
results.fields.add(
ddl.FieldSpec(name=SqlSelectBuilder.EMPTY_COLUMNS_NAME, dtype=SqlSelectBuilder.EMPTY_COLUMNS_TYPE)
Expand Down
25 changes: 25 additions & 0 deletions python/lsst/daf/butler/tests/butler_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,16 @@ def test_spatial_overlaps(self) -> None:
[1, 2, 3],
has_postprocessing=True,
)
# Same as above, but with a materialization.
self.check_detector_records(
query.where(
_x.visit_detector_region.region.overlaps(_x.patch.region),
tract=0,
patch=4,
).materialize().dimension_records("detector"),
[1, 2, 3],
has_postprocessing=True,
)
# Query for that patch's region and express the previous query as
# a region-constraint instead of a spatial join.
(patch_record,) = query.where(tract=0, patch=4).dimension_records("patch")
Expand All @@ -777,6 +787,21 @@ def test_spatial_overlaps(self) -> None:
),
ids=[1, 2, 3],
)
# Query for detectors where a patch/visit+detector overlap is
# satisfied, in the case where there are no rows with an overlap,
# but the union of the patch regions overlaps the union of the
# visit+detector regions.
self.check_detector_records(
query.where(
_x.visit_detector_region.region.overlaps(_x.patch.region),
_x.any(
_x.all(_x.tract == 1, _x.visit == 1),
_x.all(_x.tract == 0, _x.patch == 0, _x.visit == 2),
),
).dimension_records("detector"),
[],
has_postprocessing=True,
)
# Combine postprocessing with order_by and limit.
self.check_detector_records(
query.where(
Expand Down

0 comments on commit 4a74141

Please sign in to comment.