Skip to content

Commit

Permalink
Add Query option to suppress DISTINCT in skypix overlaps.
Browse files Browse the repository at this point in the history
This is a non-public API for now, solely for graph builder use.
  • Loading branch information
andy-slac committed Jan 28, 2025
1 parent 1eab6b0 commit 5903c60
Show file tree
Hide file tree
Showing 11 changed files with 80 additions and 10 deletions.
20 changes: 16 additions & 4 deletions python/lsst/daf/butler/direct_query_driver/_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def execute(self, result_spec: ResultSpec, tree: qt.QueryTree) -> Iterator[Resul
final_columns=result_spec.get_result_columns(),
order_by=result_spec.order_by,
find_first_dataset=result_spec.find_first_dataset,
allow_duplicate_overlaps=result_spec.allow_duplicate_overlaps,
)
sql_select, sql_columns = builder.finish_select()
if result_spec.order_by:
Expand Down Expand Up @@ -290,12 +291,15 @@ def materialize(
tree: qt.QueryTree,
dimensions: DimensionGroup,
datasets: frozenset[str],
allow_duplicate_overlaps: bool = False,
key: qt.MaterializationKey | None = None,
) -> qt.MaterializationKey:
# Docstring inherited.
if self._exit_stack is None:
raise RuntimeError("QueryDriver context must be entered before 'materialize' is called.")
plan = self.build_query(tree, qt.ColumnSet(dimensions))
plan = self.build_query(
tree, qt.ColumnSet(dimensions), allow_duplicate_overlaps=allow_duplicate_overlaps
)
# Current implementation ignores 'datasets' aside from remembering
# them, because figuring out what to put in the temporary table for
# them is tricky, especially if calibration collections are involved.
Expand Down Expand Up @@ -403,7 +407,7 @@ def count(

def any(self, tree: qt.QueryTree, *, execute: bool, exact: bool) -> bool:
# Docstring inherited.
builder = self.build_query(tree, qt.ColumnSet(tree.dimensions))
builder = self.build_query(tree, qt.ColumnSet(tree.dimensions), allow_duplicate_overlaps=True)
if not all(d.collection_records for d in builder.joins_analysis.datasets.values()):
return False
if not execute:
Expand Down Expand Up @@ -449,6 +453,7 @@ def build_query(
order_by: Iterable[qt.OrderExpression] = (),
find_first_dataset: str | qt.AnyDatasetType | None = None,
analyze_only: bool = False,
allow_duplicate_overlaps: bool = False,
) -> QueryBuilder:
"""Convert a query description into a nearly-complete builder object
for the SQL version of that query.
Expand All @@ -472,6 +477,9 @@ def build_query(
builder, but do not call methods that build its SQL form. This can
be useful for obtaining diagnostic information about the query that
would be generated.
allow_duplicate_overlaps : `bool`, optional
    If set to `True` then the query will be allowed to generate
    non-distinct rows for spatial overlaps.
Returns
-------
Expand Down Expand Up @@ -544,7 +552,7 @@ def build_query(
# SqlSelectBuilder and Postprocessing with spatial/temporal constraints
# potentially transformed by the dimensions manager (but none of the
# rest of the analysis reflected in that SqlSelectBuilder).
query_tree_analysis = self._analyze_query_tree(tree)
query_tree_analysis = self._analyze_query_tree(tree, allow_duplicate_overlaps)
# The "projection" columns differ from the final columns by not
# omitting any dimension keys (this keeps queries for different result
# types more similar during construction), including any columns needed
Expand Down Expand Up @@ -591,7 +599,7 @@ def build_query(
builder.apply_find_first(self)
return builder

def _analyze_query_tree(self, tree: qt.QueryTree) -> QueryTreeAnalysis:
def _analyze_query_tree(self, tree: qt.QueryTree, allow_duplicate_overlaps: bool) -> QueryTreeAnalysis:
"""Analyze a `.queries.tree.QueryTree` as the first step in building
a SQL query.
Expand All @@ -605,6 +613,9 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> QueryTreeAnalysis:
tree_analysis : `QueryTreeAnalysis`
Struct containing additional information need to build the joins
stage of a query.
allow_duplicate_overlaps : `bool`
    If set to `True` then the query will be allowed to generate
    non-distinct rows for spatial overlaps.
Notes
-----
Expand Down Expand Up @@ -634,6 +645,7 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> QueryTreeAnalysis:
tree.predicate,
tree.get_joined_dimension_groups(),
collection_analysis.calibration_dataset_types,
allow_duplicate_overlaps,
)
# Extract the data ID implied by the predicate; we can use the governor
# dimensions in that to constrain the collections we search for
Expand Down
22 changes: 19 additions & 3 deletions python/lsst/daf/butler/queries/_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ def __init__(self, driver: QueryDriver, tree: QueryTree | None = None):
tree = make_identity_query_tree(driver.universe)
super().__init__(driver, tree)

# If ``_allow_duplicate_overlaps`` is set to `True` then query will be
# allowed to generate non-distinct rows for spatial overlaps. This is
# not a part of public API for now, to be used by graph builder as
# optimization.
self._allow_duplicate_overlaps: bool = False

@property
def constraint_dataset_types(self) -> Set[str]:
"""The names of all dataset types joined into the query.
Expand Down Expand Up @@ -218,7 +224,11 @@ def data_ids(
dimensions = self._driver.universe.conform(dimensions)
if not dimensions <= self._tree.dimensions:
tree = tree.join_dimensions(dimensions)
result_spec = DataCoordinateResultSpec(dimensions=dimensions, include_dimension_records=False)
result_spec = DataCoordinateResultSpec(
dimensions=dimensions,
include_dimension_records=False,
allow_duplicate_overlaps=self._allow_duplicate_overlaps,
)
return DataCoordinateQueryResults(self._driver, tree, result_spec)

def datasets(
Expand Down Expand Up @@ -284,6 +294,7 @@ def datasets(
storage_class_name=storage_class_name,
include_dimension_records=False,
find_first=find_first,
allow_duplicate_overlaps=self._allow_duplicate_overlaps,
)
return DatasetRefQueryResults(self._driver, tree=query._tree, spec=spec)

Expand All @@ -308,7 +319,9 @@ def dimension_records(self, element: str) -> DimensionRecordQueryResults:
tree = self._tree
if element not in tree.dimensions.elements:
tree = tree.join_dimensions(self._driver.universe[element].minimal_group)
result_spec = DimensionRecordResultSpec(element=self._driver.universe[element])
result_spec = DimensionRecordResultSpec(
element=self._driver.universe[element], allow_duplicate_overlaps=self._allow_duplicate_overlaps
)
return DimensionRecordQueryResults(self._driver, tree, result_spec)

def general(
Expand Down Expand Up @@ -445,6 +458,7 @@ def general(
dimension_fields=dimension_fields_dict,
dataset_fields=dataset_fields_dict,
find_first=find_first,
allow_duplicate_overlaps=self._allow_duplicate_overlaps,
)
return GeneralQueryResults(self._driver, tree=tree, spec=result_spec)

Expand Down Expand Up @@ -495,7 +509,9 @@ def materialize(
dimensions = self._tree.dimensions
else:
dimensions = self._driver.universe.conform(dimensions)
key = self._driver.materialize(self._tree, dimensions, datasets)
key = self._driver.materialize(
self._tree, dimensions, datasets, allow_duplicate_overlaps=self._allow_duplicate_overlaps
)
tree = make_identity_query_tree(self._driver.universe).join_materialization(
key, dimensions=dimensions
)
Expand Down
4 changes: 4 additions & 0 deletions python/lsst/daf/butler/queries/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def materialize(
tree: QueryTree,
dimensions: DimensionGroup,
datasets: frozenset[str],
allow_duplicate_overlaps: bool = False,
) -> MaterializationKey:
"""Execute a query tree, saving results to temporary storage for use
in later queries.
Expand All @@ -222,6 +223,9 @@ def materialize(
datasets : `frozenset` [ `str` ]
Names of dataset types whose ID columns may be materialized. It
is implementation-defined whether they actually are.
allow_duplicate_overlaps : `bool`, optional
    If set to `True` then the query will be allowed to generate
    non-distinct rows for spatial overlaps.
Returns
-------
Expand Down
5 changes: 5 additions & 0 deletions python/lsst/daf/butler/queries/result_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ class ResultSpecBase(pydantic.BaseModel, ABC):
limit: int | None = None
"""Maximum number of rows to return, or `None` for no bound."""

allow_duplicate_overlaps: bool = False
"""If set to True the queries are allowed to returnd duplicate rows for
spatial overlaps.
"""

def validate_tree(self, tree: QueryTree) -> None:
"""Check that this result object is consistent with a query tree.
Expand Down
10 changes: 7 additions & 3 deletions python/lsst/daf/butler/registry/dimensions/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,9 +487,10 @@ def process_query_overlaps(
predicate: qt.Predicate,
join_operands: Iterable[DimensionGroup],
calibration_dataset_types: Set[str | qt.AnyDatasetType],
allow_duplicates: bool = False,
) -> tuple[qt.Predicate, SqlSelectBuilder, Postprocessing]:
overlaps_visitor = _CommonSkyPixMediatedOverlapsVisitor(
self._db, dimensions, calibration_dataset_types, self._overlap_tables
self._db, dimensions, calibration_dataset_types, self._overlap_tables, allow_duplicates
)
new_predicate = overlaps_visitor.run(predicate, join_operands)
return new_predicate, overlaps_visitor.builder, overlaps_visitor.postprocessing
Expand Down Expand Up @@ -1025,13 +1026,15 @@ def __init__(
dimensions: DimensionGroup,
calibration_dataset_types: Set[str | qt.AnyDatasetType],
overlap_tables: Mapping[str, tuple[sqlalchemy.Table, sqlalchemy.Table]],
allow_duplicates: bool,
):
super().__init__(dimensions, calibration_dataset_types)
self.builder: SqlSelectBuilder = SqlJoinsBuilder(db=db).to_select_builder(qt.ColumnSet(dimensions))
self.postprocessing = Postprocessing()
self.common_skypix = dimensions.universe.commonSkyPix
self.overlap_tables: Mapping[str, tuple[sqlalchemy.Table, sqlalchemy.Table]] = overlap_tables
self.common_skypix_overlaps_done: set[DatabaseDimensionElement] = set()
self.allow_duplicates = allow_duplicates

def visit_spatial_constraint(
self,
Expand Down Expand Up @@ -1081,7 +1084,8 @@ def visit_spatial_constraint(
joins_builder.where(sqlalchemy.or_(*sql_where_or))
self.builder.join(
joins_builder.to_select_builder(
qt.ColumnSet(element.minimal_group).drop_implied_dimension_keys(), distinct=True
qt.ColumnSet(element.minimal_group).drop_implied_dimension_keys(),
distinct=not self.allow_duplicates,
).into_joins_builder(postprocessing=None)
)
# Short circuit here since the SQL WHERE clause has already
Expand Down Expand Up @@ -1145,7 +1149,7 @@ def visit_spatial_join(
.join(self._make_common_skypix_overlap_joins_builder(b))
.to_select_builder(
qt.ColumnSet(a.minimal_group | b.minimal_group).drop_implied_dimension_keys(),
distinct=True,
distinct=not self.allow_duplicates,
)
.into_joins_builder(postprocessing=None)
)
Expand Down
4 changes: 4 additions & 0 deletions python/lsst/daf/butler/registry/interfaces/_dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ def process_query_overlaps(
predicate: Predicate,
join_operands: Iterable[DimensionGroup],
calibration_dataset_types: Set[str | AnyDatasetType],
allow_duplicates: bool = False,
) -> tuple[Predicate, SqlSelectBuilder, Postprocessing]:
"""Process a query's WHERE predicate and dimensions to handle spatial
and temporal overlaps.
Expand All @@ -424,6 +425,9 @@ def process_query_overlaps(
`..queries.tree.AnyDatasetType` ]
The names of dataset types that have been joined into the query via
a search that includes at least one calibration collection.
allow_duplicates : `bool`, optional
    If set to `True` then the query will be allowed to return
    non-distinct rows.
Returns
-------
Expand Down
2 changes: 2 additions & 0 deletions python/lsst/daf/butler/remote_butler/_query_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def materialize(
tree: QueryTree,
dimensions: DimensionGroup,
datasets: frozenset[str],
allow_duplicate_overlaps: bool = False,
) -> MaterializationKey:
key = uuid4()
self._stored_query_inputs.append(
Expand All @@ -171,6 +172,7 @@ def materialize(
tree=SerializedQueryTree(tree.model_copy(deep=True)),
dimensions=dimensions.to_simple(),
datasets=datasets,
allow_duplicate_overlaps=allow_duplicate_overlaps,
),
)
return key
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def _get_query_context(factory: Factory, query: QueryInputs) -> Iterator[_QueryC
DimensionGroup.from_simple(input.dimensions, butler.dimensions),
frozenset(input.datasets),
key=input.key,
allow_duplicate_overlaps=input.allow_duplicate_overlaps,
)
elif input.type == "upload":
driver.upload_data_coordinates(
Expand Down
1 change: 1 addition & 0 deletions python/lsst/daf/butler/remote_butler/server_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ class MaterializedQuery(pydantic.BaseModel):
tree: SerializedQueryTree
dimensions: SerializedDimensionGroup
datasets: list[str]
allow_duplicate_overlaps: bool = False


class DataCoordinateUpload(pydantic.BaseModel):
Expand Down
20 changes: 20 additions & 0 deletions python/lsst/daf/butler/tests/butler_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,26 @@ def test_dataset_constrained_record_query(self) -> None:
doomed=True,
)

def test_duplicate_overlaps(self) -> None:
"""Test for query option that enables duplicate rows in queries that
use skypix overalps.
"""
butler = self.make_butler("base.yaml", "spatial.yaml")
butler.registry.defaults = RegistryDefaults(instrument="Cam1", skymap="SkyMap1")
with butler.query() as query:

data_ids = list(query.data_ids(["visit", "detector", "patch"]).where(visit=1, detector=1))
self.assertCountEqual(
[(data_id["tract"], data_id["patch"]) for data_id in data_ids], [(0, 0), (0, 2)]
)

query._allow_duplicate_overlaps = True
data_ids = list(query.data_ids(["visit", "detector", "patch"]).where(visit=1, detector=1))
self.assertCountEqual(
[(data_id["tract"], data_id["patch"]) for data_id in data_ids],
[(0, 0), (0, 0), (0, 2), (0, 2)],
)

def test_spatial_overlaps(self) -> None:
"""Test queries for dimension records with spatial overlaps.
Expand Down
1 change: 1 addition & 0 deletions tests/test_query_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ def materialize(
tree: qt.QueryTree,
dimensions: DimensionGroup,
datasets: frozenset[str],
allow_duplicate_overlaps: bool = False,
) -> qd.MaterializationKey:
key = uuid.uuid4()
self.materializations[key] = (tree, dimensions, datasets)
Expand Down

0 comments on commit 5903c60

Please sign in to comment.