diff --git a/doc/changes/DM-46401.bugfix.md b/doc/changes/DM-46401.bugfix.md new file mode 100644 index 0000000000..0c92351d4b --- /dev/null +++ b/doc/changes/DM-46401.bugfix.md @@ -0,0 +1 @@ +Fix support for multiple-instrument (and multiple-skymap) `where` expressions in the new query system. diff --git a/python/lsst/daf/butler/direct_query_driver/_driver.py b/python/lsst/daf/butler/direct_query_driver/_driver.py index 4ccd769fad..bfc85ac143 100644 --- a/python/lsst/daf/butler/direct_query_driver/_driver.py +++ b/python/lsst/daf/butler/direct_query_driver/_driver.py @@ -937,7 +937,7 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> tuple[QueryJoinsPlan, Query where_governors: set[str] = set() result.predicate.gather_governors(where_governors) for governor in where_governors: - if governor not in result.constraint_data_id: + if governor not in result.constraint_data_id and governor not in result.governors_referenced: if governor in self._default_data_id.dimensions: result.constraint_data_id[governor] = self._default_data_id[governor] else: diff --git a/python/lsst/daf/butler/direct_query_driver/_query_plan.py b/python/lsst/daf/butler/direct_query_driver/_query_plan.py index 60f8a8c05c..ced6c1131c 100644 --- a/python/lsst/daf/butler/direct_query_driver/_query_plan.py +++ b/python/lsst/daf/butler/direct_query_driver/_query_plan.py @@ -115,12 +115,19 @@ class QueryJoinsPlan: rows. """ + governors_referenced: set[str] = dataclasses.field(default_factory=set) + """Governor dimensions referenced directly in the predicate, but not + necessarily constrained to the same value in all logic branches. + """ + def __post_init__(self) -> None: self.predicate.gather_required_columns(self.columns) # Extract the data ID implied by the predicate; we can use the governor # dimensions in that to constrain the collections we search for # datasets later. - self.predicate.visit(_DataIdExtractionVisitor(self.constraint_data_id, self.messages)) + self.predicate.visit( + _DataIdExtractionVisitor(self.constraint_data_id, self.messages, self.governors_referenced) + ) def iter_mandatory(self) -> Iterator[DimensionElement]: """Return an iterator over the dimension elements that must be joined @@ -304,11 +311,17 @@ class _DataIdExtractionVisitor( Dictionary to populate in place. messages : `list` [ `str` ] List of diagnostic messages to populate in place. + governor_references : `set` [ `str` ] + Set of the names of governor dimension names that were referenced + directly. This includes dimensions that were constrained to different + values in different logic branches, and hence not included in + ``data_id``. """ - def __init__(self, data_id: dict[str, DataIdValue], messages: list[str]): + def __init__(self, data_id: dict[str, DataIdValue], messages: list[str], governor_references: set[str]): self.data_id = data_id self.messages = messages + self.governor_references = governor_references def visit_comparison( self, @@ -317,6 +330,8 @@ def visit_comparison( b: qt.ColumnExpression, flags: PredicateVisitFlags, ) -> None: + k_a, v_a = a.visit(self) + k_b, v_b = b.visit(self) if flags & PredicateVisitFlags.HAS_OR_SIBLINGS: return None if flags & PredicateVisitFlags.INVERTED: @@ -326,8 +341,6 @@ def visit_comparison( return None if operator != "==": return None - k_a, v_a = a.visit(self) - k_b, v_b = b.visit(self) if k_a is not None and v_b is not None: key = k_a value = v_b @@ -341,18 +354,28 @@ def visit_comparison( return None def visit_binary_expression(self, expression: qt.BinaryExpression) -> tuple[None, None]: + expression.a.visit(self) + expression.b.visit(self) return None, None def visit_unary_expression(self, expression: qt.UnaryExpression) -> tuple[None, None]: + expression.operand.visit(self) return None, None def visit_literal(self, expression: qt.ColumnLiteral) -> tuple[None, Any]: return None, expression.get_literal_value() def visit_dimension_key_reference(self, expression: qt.DimensionKeyReference) -> tuple[str, None]: + if expression.dimension.governor is expression.dimension: + self.governor_references.add(expression.dimension.name) return expression.dimension.name, None def visit_dimension_field_reference(self, expression: qt.DimensionFieldReference) -> tuple[None, None]: + if ( + expression.element.governor is expression.element + and expression.field in expression.element.alternate_keys.names + ): + self.governor_references.add(expression.element.name) return None, None def visit_dataset_field_reference(self, expression: qt.DatasetFieldReference) -> tuple[None, None]: diff --git a/python/lsst/daf/butler/queries/tree/_column_reference.py b/python/lsst/daf/butler/queries/tree/_column_reference.py index 6f67ef66eb..f5103d0af4 100644 --- a/python/lsst/daf/butler/queries/tree/_column_reference.py +++ b/python/lsst/daf/butler/queries/tree/_column_reference.py @@ -62,7 +62,7 @@ def gather_required_columns(self, columns: ColumnSet) -> None: columns.update_dimensions(self.dimension.minimal_group) def gather_governors(self, governors: set[str]) -> None: - if self.dimension.governor is not None: + if self.dimension.governor is not None and self.dimension.governor is not self.dimension: governors.add(self.dimension.governor.name) @property diff --git a/python/lsst/daf/butler/queries/visitors.py b/python/lsst/daf/butler/queries/visitors.py index 5e77161d15..5ee2979b4d 100644 --- a/python/lsst/daf/butler/queries/visitors.py +++ b/python/lsst/daf/butler/queries/visitors.py @@ -192,7 +192,7 @@ class PredicateVisitor(Generic[_A, _O, _L]): ----- The concrete `PredicateLeaf` types are only semi-public (they appear in the serialized form of a `Predicate`, but their types should not generally - be referenced directly outside of the module in which they are defined. + be referenced directly outside of the module in which they are defined). As a result, visiting these objects unpacks their attributes into the visit method arguments. """ diff --git a/python/lsst/daf/butler/tests/butler_queries.py b/python/lsst/daf/butler/tests/butler_queries.py index 2d3d83f0eb..970b42a770 100644 --- a/python/lsst/daf/butler/tests/butler_queries.py +++ b/python/lsst/daf/butler/tests/butler_queries.py @@ -1859,6 +1859,39 @@ def test_dataset_queries(self) -> None: self.assertEqual(rows[0]["visit"], 1) self.assertEqual(rows[0]["dt.collection"], "run1") + def test_multiple_instrument_queries(self) -> None: + """Test that multiple-instrument queries are not rejected as having + governor dimension ambiguities. + """ + butler = self.make_butler("base.yaml") + butler.registry.insertDimensionData("instrument", {"name": "Cam2"}) + self.assertCountEqual( + butler.query_data_ids(["detector"], where="instrument='Cam1' OR instrument='Cam2'"), + [ + DataCoordinate.standardize(instrument="Cam1", detector=n, universe=butler.dimensions) + for n in range(1, 5) + ], + ) + self.assertCountEqual( + butler.query_data_ids( + ["detector"], + where="(instrument='Cam1' OR instrument='Cam2') AND visit.region OVERLAPS region", + bind={"region": Region.from_ivoa_pos("CIRCLE 320. -0.25 10.")}, + explain=False, + ), + # No visits in this test dataset means no result, but the point of + # the test is just that the query can be constructed at all. + [], + ) + self.assertCountEqual( + butler.query_data_ids( + ["instrument"], + where="(instrument='Cam1' AND detector=2) OR (instrument='Cam2' AND detector=500)", + explain=False, + ), + [DataCoordinate.standardize(instrument="Cam1", universe=butler.dimensions)], + ) + def _get_exposure_ids_from_dimension_records(dimension_records: Iterable[DimensionRecord]) -> list[int]: output = []