Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-47947: fix consistency bug in aggregate spatial overlap postprocessing #1131

Merged
merged 4 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/DM-47947.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed a bug in which projections spatial-join queries (particularly those where the dimensions of the actual regions being compared are not in the query result rows) could return additional records where there actually was no overlap.
12 changes: 2 additions & 10 deletions python/lsst/daf/butler/ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
"GUID",
)

import functools
import logging
import uuid
from base64 import b64decode, b64encode
Expand All @@ -60,7 +59,7 @@

import astropy.time
import sqlalchemy
from lsst.sphgeom import Region, UnionRegion
from lsst.sphgeom import Region
from lsst.utils.iteration import ensure_iterable
from sqlalchemy.dialects.postgresql import UUID

Expand Down Expand Up @@ -182,14 +181,7 @@ def process_bind_param(self, value: Region | None, dialect: sqlalchemy.engine.Di
def process_result_value(self, value: str | None, dialect: sqlalchemy.engine.Dialect) -> Region | None:
if value is None:
return None
return functools.reduce(
UnionRegion,
[
# For some reason super() doesn't work here!
Region.decode(Base64Bytes.process_result_value(self, union_member, dialect))
for union_member in value.split(":")
],
)
return Region.decodeBase64(value)

@property
def python_type(self) -> type[Region]:
Expand Down
72 changes: 67 additions & 5 deletions python/lsst/daf/butler/direct_query_driver/_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from .._collection_type import CollectionType
from .._dataset_type import DatasetType
from .._exceptions import InvalidQueryError
from ..dimensions import DataCoordinate, DataIdValue, DimensionGroup, DimensionUniverse
from ..dimensions import DataCoordinate, DataIdValue, DimensionElement, DimensionGroup, DimensionUniverse
from ..dimensions.record_cache import DimensionRecordCache
from ..queries import tree as qt
from ..queries.driver import (
Expand Down Expand Up @@ -388,6 +388,7 @@ def count(
select_builder = builder.finish_nested()
# Replace the columns of the query with just COUNT(*).
select_builder.columns = qt.ColumnSet(self._universe.empty)
select_builder.joins.special.clear()
count_func: sqlalchemy.ColumnElement[int] = sqlalchemy.func.count()
select_builder.joins.special["_ROWCOUNT"] = count_func
# Render and run the query.
Expand Down Expand Up @@ -655,6 +656,9 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> QueryTreeAnalysis:
# it here.
postprocessing.spatial_join_filtering.extend(m_state.postprocessing.spatial_join_filtering)
postprocessing.spatial_where_filtering.extend(m_state.postprocessing.spatial_where_filtering)
postprocessing.spatial_expression_filtering.extend(
m_state.postprocessing.spatial_expression_filtering
)
# Add data coordinate uploads.
joins.data_coordinate_uploads.update(tree.data_coordinate_uploads)
# Add dataset_searches and filter out collections that don't have the
Expand Down Expand Up @@ -715,7 +719,7 @@ def _resolve_union_datasets(
searches : `list` [ `ResolvedDatasetSearch` ]
Resolved dataset searches for all union dataset types with these
dimensions. Each item in the list groups dataset types with the
same colletion search path.
same collection search path.
"""
# Gather the filtered collection search path for each union dataset
# type.
Expand Down Expand Up @@ -849,6 +853,48 @@ def apply_missing_dimension_joins(
joins_analysis.predicate.visit(SqlColumnVisitor(select_builder.joins, self))
)

def project_spatial_join_filtering(
self,
columns: qt.ColumnSet,
postprocessing: Postprocessing,
select_builders: Iterable[SqlSelectBuilder],
) -> None:
"""Transform spatial join postprocessing into expressions that can be
OR'd together via an aggregate function in a GROUP BY.

This only affects spatial join constraints involving region columns
whose dimensions are being projected away.

Parameters
----------
columns : `.queries.tree.ColumnSet`
Columns that will be included in the final query.
postprocessing : `Postprocessing`
Object that describes post-query processing; modified in place.
select_builders : `~collections.abc.Iterable` [ `SqlSelectBuilder` ]
SQL Builder objects to be modified in place.
"""
kept: list[tuple[DimensionElement, DimensionElement]] = []
for a, b in postprocessing.spatial_join_filtering:
if a.name not in columns.dimensions.elements or b.name not in columns.dimensions.elements:
expr_name = f"_{a}_OVERLAPS_{b}"
postprocessing.spatial_expression_filtering.append(expr_name)
for select_builder in select_builders:
expr = sqlalchemy.cast(
sqlalchemy.cast(
select_builder.joins.fields[a.name]["region"], type_=sqlalchemy.String
)
+ sqlalchemy.literal("&", type_=sqlalchemy.String)
+ sqlalchemy.cast(
select_builder.joins.fields[b.name]["region"], type_=sqlalchemy.String
),
type_=sqlalchemy.LargeBinary,
)
select_builder.joins.special[expr_name] = expr
else:
kept.append((a, b))
postprocessing.spatial_join_filtering = kept

def apply_query_projection(
self,
select_builder: SqlSelectBuilder,
Expand Down Expand Up @@ -938,8 +984,8 @@ def apply_query_projection(
# the data IDs for those regions are not wholly included in the
# results (i.e. we need to postprocess on
# visit_detector_region.region, but the output rows don't have
# detector, just visit - so we compute the union of the
# visit_detector region over all matched detectors).
# detector, just visit - so we pack the overlap expression into a
# blob via an aggregate function and interpret it later).
if postprocessing.check_validity_match_count:
if needs_validity_match_count:
select_builder.joins.special[postprocessing.VALIDITY_MATCH_COUNT] = (
Expand All @@ -960,11 +1006,27 @@ def apply_query_projection(
# might be collapsing the dimensions of the postprocessing
# regions. When that happens, we want to apply an aggregate
# function to them that computes the union of the regions that
# are grouped together.
# are grouped together. Note that this should only happen for
# constraints that involve a "given", external-to-the-database
# region (postprocessing.spatial_where_filtering); join
# constraints that need aggregates should have already been
# transformed in advance.
select_builder.joins.fields[element.name]["region"] = ddl.Base64Region.union_aggregate(
select_builder.joins.fields[element.name]["region"]
)
have_aggregates = True
# Postprocessing spatial join constraints where at least one region's
# dimensions are being projected away will have already been turned
# into the kind of expression that sphgeom.Region.decodeOverlapsBase64
# processes. We can just apply an aggregate function to these. Note
# that we don't do this to other constraints in order to minimize
# duplicate fetches of the same region blob.
for expr_name in postprocessing.spatial_expression_filtering:
select_builder.joins.special[expr_name] = sqlalchemy.cast(
sqlalchemy.func.aggregate_strings(select_builder.joins.special[expr_name], "|"),
type_=sqlalchemy.LargeBinary,
)
have_aggregates = True

# All dimension record fields are derived fields.
for element_name, fields_for_element in projection_columns.dimension_fields.items():
Expand Down
20 changes: 16 additions & 4 deletions python/lsst/daf/butler/direct_query_driver/_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from typing import TYPE_CHECKING, ClassVar

import sqlalchemy
from lsst.sphgeom import DISJOINT, Region
from lsst.sphgeom import Region

from .._exceptions import CalibrationLookupError
from ..queries import tree as qt
Expand All @@ -59,6 +59,7 @@ class Postprocessing:
def __init__(self) -> None:
self.spatial_join_filtering = []
self.spatial_where_filtering = []
self.spatial_expression_filtering = []
self.check_validity_match_count: bool = False
self._limit: int | None = None

Expand All @@ -79,6 +80,12 @@ def __init__(self) -> None:
non-overlap pair will be filtered out.
"""

spatial_expression_filtering: list[str]
"""The names of calculated columns that can be parsed by
`lsst.sphgeom.Region.decodeOverlapsBase64` into a `bool` or `None` that
indicates whether regions definitely overlap.
"""

check_validity_match_count: bool
"""If `True`, result rows will include a special column that counts the
number of matching datasets in each collection for each data ID, and
Expand All @@ -104,7 +111,9 @@ def limit(self, value: int | None) -> None:
self._limit = value

def __bool__(self) -> bool:
return bool(self.spatial_join_filtering or self.spatial_where_filtering)
return bool(
self.spatial_join_filtering or self.spatial_where_filtering or self.spatial_expression_filtering
)

def gather_columns_required(self, columns: qt.ColumnSet) -> None:
"""Add all columns required to perform postprocessing to the given
Expand Down Expand Up @@ -197,8 +206,11 @@ def apply(self, rows: Iterable[sqlalchemy.Row]) -> Iterable[sqlalchemy.Row]:

for row in rows:
m = row._mapping
if any(m[a].relate(m[b]) & DISJOINT for a, b in joins) or any(
m[field].relate(region) & DISJOINT for field, region in where
# Skip rows where at least one couple of regions do not overlap.
if (
any(Region.decodeOverlapsBase64(m[c]) is False for c in self.spatial_expression_filtering)
or any(m[a].overlaps(m[b]) is False for a, b in joins)
or any(m[field].overlaps(region) is False for field, region in where)
):
continue
if self.check_validity_match_count and m[self.VALIDITY_MATCH_COUNT] > 1:
Expand Down
9 changes: 9 additions & 0 deletions python/lsst/daf/butler/direct_query_driver/_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
)

import dataclasses
import itertools
from abc import ABC, abstractmethod
from collections.abc import Iterable, Set
from typing import TYPE_CHECKING, Literal, TypeVar, overload
Expand Down Expand Up @@ -382,6 +383,9 @@

def apply_projection(self, driver: DirectQueryDriver, order_by: Iterable[qt.OrderExpression]) -> None:
# Docstring inherited.
driver.project_spatial_join_filtering(
self.projection_columns, self.postprocessing, [self._select_builder]
)
driver.apply_query_projection(
self._select_builder,
self.postprocessing,
Expand Down Expand Up @@ -635,6 +639,11 @@

def apply_projection(self, driver: DirectQueryDriver, order_by: Iterable[qt.OrderExpression]) -> None:
# Docstring inherited.
driver.project_spatial_join_filtering(

Check warning on line 642 in python/lsst/daf/butler/direct_query_driver/_query_builder.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/daf/butler/direct_query_driver/_query_builder.py#L642

Added line #L642 was not covered by tests
self.projection_columns,
self.postprocessing,
itertools.chain.from_iterable(union_term.select_builders for union_term in self.union_terms),
)
for union_term in self.union_terms:
for builder in union_term.select_builders:
driver.apply_query_projection(
Expand Down
8 changes: 6 additions & 2 deletions python/lsst/daf/butler/direct_query_driver/_sql_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def join(self, other: SqlJoinsBuilder) -> SqlSelectBuilder:
self.joins.join(other)
return self

def into_from_builder(
def into_joins_builder(
self, cte: bool = False, force: bool = False, *, postprocessing: Postprocessing | None
) -> SqlJoinsBuilder:
"""Convert this builder into a `SqlJoinsBuilder`, nesting it in a
Expand Down Expand Up @@ -265,7 +265,7 @@ def nested(
object.
"""
return SqlSelectBuilder(
self.into_from_builder(cte=cte, force=force, postprocessing=postprocessing), columns=self.columns
self.into_joins_builder(cte=cte, force=force, postprocessing=postprocessing), columns=self.columns
)

def union_subquery(
Expand Down Expand Up @@ -422,6 +422,8 @@ def extract_columns(
self.fields[element.name]["region"] = column_collection[
self.db.name_shrinker.shrink(columns.get_qualified_name(element.name, "region"))
]
for name in postprocessing.spatial_expression_filtering:
self.special[name] = column_collection[name]
if postprocessing.check_validity_match_count:
self.special[postprocessing.VALIDITY_MATCH_COUNT] = column_collection[
postprocessing.VALIDITY_MATCH_COUNT
Expand Down Expand Up @@ -670,6 +672,8 @@ def make_table_spec(
db.name_shrinker.shrink(columns.get_qualified_name(element.name, "region"))
)
)
for name in postprocessing.spatial_expression_filtering:
results.fields.add(ddl.FieldSpec(name, dtype=sqlalchemy.types.LargeBinary, nullable=True))
if not results.fields:
results.fields.add(
ddl.FieldSpec(name=SqlSelectBuilder.EMPTY_COLUMNS_NAME, dtype=SqlSelectBuilder.EMPTY_COLUMNS_TYPE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1453,9 +1453,9 @@ def make_joins_builder(
# Need a UNION subquery.
return tags_builder.union_subquery([calibs_builder])
else:
return tags_builder.into_from_builder(postprocessing=None)
return tags_builder.into_joins_builder(postprocessing=None)
elif calibs_builder is not None:
return calibs_builder.into_from_builder(postprocessing=None)
return calibs_builder.into_joins_builder(postprocessing=None)
else:
raise AssertionError("Branch should be unreachable.")

Expand Down
6 changes: 3 additions & 3 deletions python/lsst/daf/butler/registry/dimensions/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def make_joins_builder(self, element: DimensionElement, fields: Set[str]) -> Sql
self.make_joins_builder(element.implied_union_target, fields),
columns=qt.ColumnSet(element.minimal_group).drop_implied_dimension_keys(),
distinct=True,
).into_from_builder(postprocessing=None)
).into_joins_builder(postprocessing=None)
if not element.has_own_table:
raise NotImplementedError(f"Cannot join dimension element {element} with no table.")
table = self._tables[element.name]
Expand Down Expand Up @@ -1082,7 +1082,7 @@ def visit_spatial_constraint(
self.builder.join(
joins_builder.to_select_builder(
qt.ColumnSet(element.minimal_group).drop_implied_dimension_keys(), distinct=True
).into_from_builder(postprocessing=None)
).into_joins_builder(postprocessing=None)
)
# Short circuit here since the SQL WHERE clause has already
# been embedded in the subquery.
Expand Down Expand Up @@ -1147,7 +1147,7 @@ def visit_spatial_join(
qt.ColumnSet(a.minimal_group | b.minimal_group).drop_implied_dimension_keys(),
distinct=True,
)
.into_from_builder(postprocessing=None)
.into_joins_builder(postprocessing=None)
)
# In both cases we add postprocessing to check that the regions
# really do overlap, since overlapping the same common skypix
Expand Down
27 changes: 27 additions & 0 deletions python/lsst/daf/butler/tests/butler_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,18 @@ def test_spatial_overlaps(self) -> None:
[1, 2, 3],
has_postprocessing=True,
)
# Same as above, but with a materialization.
self.check_detector_records(
query.where(
_x.visit_detector_region.region.overlaps(_x.patch.region),
tract=0,
patch=4,
)
.materialize(dimensions=["detector"])
.dimension_records("detector"),
[1, 2, 3],
has_postprocessing=True,
)
# Query for that patch's region and express the previous query as
# a region-constraint instead of a spatial join.
(patch_record,) = query.where(tract=0, patch=4).dimension_records("patch")
Expand All @@ -777,6 +789,21 @@ def test_spatial_overlaps(self) -> None:
),
ids=[1, 2, 3],
)
# Query for detectors where a patch/visit+detector overlap is
# satisfied, in the case where there are no rows with an overlap,
# but the union of the patch regions overlaps the union of the
# visit+detector regions.
self.check_detector_records(
query.where(
_x.visit_detector_region.region.overlaps(_x.patch.region),
_x.any(
_x.all(_x.tract == 1, _x.visit == 1),
_x.all(_x.tract == 0, _x.patch == 0, _x.visit == 2),
),
).dimension_records("detector"),
[],
has_postprocessing=True,
)
# Combine postprocessing with order_by and limit.
self.check_detector_records(
query.where(
Expand Down
Loading