DM-47980: Add an option to include dimension records into general query result #1135

Merged
merged 2 commits on Dec 19, 2024
Changes from 1 commit
Add an option to include dimension records into general query result (DM-47980)

The `GeneralQueryResults.iter_tuples` method returned data IDs without
dimension records. In some cases (e.g. obscore export) it is very useful
to include the records in the same result to avoid querying them
separately. A new `with_dimension_records` method is added to the class
to trigger adding fields from all dimension records to the returned page.
This produces many duplicate values for some dimensions (e.g. `instrument`),
but it keeps the page structure simple.

This adds one attribute to the `GeneralResultSpec` class and will need some
care with Butler server compatibility.
andy-slac committed Dec 19, 2024
commit a75ae5b0c2a28073b28123da27a9694d954fad70
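For context, a minimal usage sketch of the new option (the repository path is a placeholder; the `detector` record fields and the query calls follow the test added in `butler_queries.py` below):

from lsst.daf.butler import Butler

# Placeholder repository path; any repo whose universe defines "detector" works.
butler = Butler("/path/to/repo")

dimensions = butler.dimensions["detector"].minimal_group
with butler.query() as query:
    query = query.join_dimensions(dimensions)
    result = query.general(dimensions).order_by("detector")
    # Opt in to dimension-record fields in the returned rows.
    result = result.with_dimension_records()
    for row in result:
        # Each row is a dict keyed by qualified column names; record fields
        # such as "detector.full_name" now appear alongside the dimension keys.
        print(row["instrument"], row["detector"], row["detector.full_name"])

Without `with_dimension_records()` only the dimension key columns (`instrument`, `detector`) would be present in each row.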
@@ -326,17 +326,35 @@ class GeneralResultPageConverter(ResultPageConverter): # numpydoc ignore=PR01

def __init__(self, spec: GeneralResultSpec, ctx: ResultPageConverterContext) -> None:
self.spec = spec

result_columns = spec.get_result_columns()
# If `spec.include_dimension_records` is True then, in addition to the
# columns returned by the query, we have to add columns from dimension
# records that are not returned by the query. These columns belong to
# either cached or skypix dimensions.
query_result_columns = set(spec.get_result_columns())
output_columns = spec.get_all_result_columns()
universe = spec.dimensions.universe
self.converters: list[_GeneralColumnConverter] = []
for column in result_columns:
for column in output_columns:
column_name = qt.ColumnSet.get_qualified_name(column.logical_table, column.field)
if column.field == TimespanDatabaseRepresentation.NAME:
self.converters.append(_TimespanGeneralColumnConverter(column_name, ctx.db))
converter: _GeneralColumnConverter
if column not in query_result_columns and column.field is not None:
# This must be a field from a cached dimension record or
# skypix record.
assert isinstance(column.logical_table, str), "Do not expect AnyDatasetType here"
element = universe[column.logical_table]
if isinstance(element, SkyPixDimension):
converter = _SkypixRecordGeneralColumnConverter(element, column.field)
else:
converter = _CachedRecordGeneralColumnConverter(
element, column.field, ctx.dimension_record_cache
)
elif column.field == TimespanDatabaseRepresentation.NAME:
converter = _TimespanGeneralColumnConverter(column_name, ctx.db)
elif column.field == "ingest_date":
self.converters.append(_TimestampGeneralColumnConverter(column_name))
converter = _TimestampGeneralColumnConverter(column_name)
else:
self.converters.append(_DefaultGeneralColumnConverter(column_name))
converter = _DefaultGeneralColumnConverter(column_name)
self.converters.append(converter)

def convert(self, raw_rows: Iterable[sqlalchemy.Row]) -> GeneralResultPage:
rows = [tuple(cvt.convert(row) for cvt in self.converters) for row in raw_rows]
@@ -422,3 +440,47 @@ def __init__(self, name: str, db: Database):
def convert(self, row: sqlalchemy.Row) -> Any:
timespan = self.timespan_class.extract(row._mapping, self.name)
return timespan


class _CachedRecordGeneralColumnConverter(_GeneralColumnConverter):
"""Helper for converting result row into a field value for cached
dimension records.

Parameters
----------
element : `DimensionElement`
Dimension element, must be of cached type.
field : `str`
Name of the field to extract from the dimension record.
cache : `DimensionRecordCache`
Cache for dimension records.
"""

def __init__(self, element: DimensionElement, field: str, cache: DimensionRecordCache) -> None:
self._record_converter = _CachedDimensionRecordRowConverter(element, cache)
self._field = field

def convert(self, row: sqlalchemy.Row) -> Any:
record = self._record_converter.convert(row)
return getattr(record, self._field)


class _SkypixRecordGeneralColumnConverter(_GeneralColumnConverter):
"""Helper for converting result row into a field value for skypix
dimension records.

Parameters
----------
element : `SkyPixDimension`
Dimension element.
field : `str`
Name of the field to extract from the dimension record.
"""

def __init__(self, element: SkyPixDimension, field: str) -> None:
self._record_converter = _SkypixDimensionRecordRowConverter(element)
self._field = field

def convert(self, row: sqlalchemy.Row) -> Any:
record = self._record_converter.convert(row)
return getattr(record, self._field)
53 changes: 43 additions & 10 deletions python/lsst/daf/butler/queries/_general_query_results.py
@@ -35,11 +35,11 @@

from .._dataset_ref import DatasetRef
from .._dataset_type import DatasetType
from ..dimensions import DataCoordinate, DimensionGroup
from ..dimensions import DataCoordinate, DimensionElement, DimensionGroup, DimensionRecord
from ._base import QueryResultsBase
from .driver import QueryDriver
from .result_specs import GeneralResultSpec
from .tree import QueryTree
from .tree import QueryTree, ResultColumn


class GeneralResultTuple(NamedTuple):
@@ -99,9 +99,9 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
fields (separated from dataset type name by dot).
"""
for page in self._driver.execute(self._spec, self._tree):
columns = tuple(str(column) for column in page.spec.get_result_columns())
columns = tuple(str(column) for column in page.spec.get_all_result_columns())
for row in page.rows:
yield dict(zip(columns, row))
yield dict(zip(columns, row, strict=True))

def iter_tuples(self, *dataset_types: DatasetType) -> Iterator[GeneralResultTuple]:
"""Iterate over result rows and return data coordinate, and dataset
@@ -125,14 +125,10 @@ def iter_tuples(self, *dataset_types: DatasetType) -> Iterator[GeneralResultTupl
run_key = f"{dataset_type.name}.run"
dataset_keys.append((dataset_type, dimensions, id_key, run_key))
for row in self:
values = tuple(
row[key] for key in itertools.chain(all_dimensions.required, all_dimensions.implied)
)
data_coordinate = DataCoordinate.from_full_values(all_dimensions, values)
data_coordinate = self._make_data_id(row, all_dimensions)
refs = []
for dataset_type, dimensions, id_key, run_key in dataset_keys:
values = tuple(row[key] for key in itertools.chain(dimensions.required, dimensions.implied))
data_id = DataCoordinate.from_full_values(dimensions, values)
data_id = data_coordinate.subset(dimensions)
refs.append(DatasetRef(dataset_type, data_id, row[run_key], id=row[id_key]))
yield GeneralResultTuple(data_id=data_coordinate, refs=refs, raw_row=row)

@@ -141,6 +137,19 @@ def dimensions(self) -> DimensionGroup:
# Docstring inherited
return self._spec.dimensions

@property
def has_dimension_records(self) -> bool:
"""Whether all data IDs in this iterable contain dimension records."""
return self._spec.include_dimension_records

def with_dimension_records(self) -> GeneralQueryResults:
"""Return a results object for which `has_dimension_records` is
`True`.
"""
if self.has_dimension_records:
return self
return self._copy(tree=self._tree, include_dimension_records=True)

def count(self, *, exact: bool = True, discard: bool = False) -> int:
# Docstring inherited.
return self._driver.count(self._tree, self._spec, exact=exact, discard=discard)
@@ -152,3 +161,27 @@ def _copy(self, tree: QueryTree, **kwargs: Any) -> GeneralQueryResults:
def _get_datasets(self) -> frozenset[str]:
# Docstring inherited.
return frozenset(self._spec.dataset_fields)

def _make_data_id(self, row: dict[str, Any], dimensions: DimensionGroup) -> DataCoordinate:
values = tuple(row[key] for key in itertools.chain(dimensions.required, dimensions.implied))
data_coordinate = DataCoordinate.from_full_values(dimensions, values)
if self.has_dimension_records:
records = {
name: self._make_dimension_record(row, dimensions.universe[name])
for name in dimensions.elements
}
data_coordinate = data_coordinate.expanded(records)
return data_coordinate

def _make_dimension_record(self, row: dict[str, Any], element: DimensionElement) -> DimensionRecord:
column_map = list(
zip(
element.schema.dimensions.names,
element.dimensions.names,
)
)
for field in element.schema.remainder.names:
column_map.append((field, str(ResultColumn(element.name, field))))
d = {k: row[v] for k, v in column_map}
record_cls = element.RecordClass
return record_cls(**d)
2 changes: 1 addition & 1 deletion python/lsst/daf/butler/queries/driver.py
@@ -117,7 +117,7 @@ class GeneralResultPage:
spec: GeneralResultSpec

# Raw tabular data, with columns in the same order as
# spec.get_result_columns().
# spec.get_all_result_columns().
rows: list[tuple[Any, ...]]


32 changes: 32 additions & 0 deletions python/lsst/daf/butler/queries/result_specs.py
@@ -213,6 +213,11 @@ class GeneralResultSpec(ResultSpecBase):
dataset_fields: Mapping[str, set[DatasetFieldName]]
"""Dataset fields included in this query."""

include_dimension_records: bool = False
"""Whether to include fields for all dimension records, in addition to
those explicitly specified in `dimension_fields`.
"""

find_first: bool
"""Whether this query requires find-first resolution for a dataset.

@@ -241,6 +246,33 @@ def get_result_columns(self) -> ColumnSet:
result.dimension_fields[element_name].update(fields_for_element)
for dataset_type, fields_for_dataset in self.dataset_fields.items():
result.dataset_fields[dataset_type].update(fields_for_dataset)
if self.include_dimension_records:
# This only adds record fields for non-cached and non-skypix
# elements, which is what we want when generating the query. We could
# potentially add the others too, but that might make queries slower, so
# instead we query cached dimension records separately and add them
# to the result page in the page converter.
_add_dimension_records_to_column_set(self.dimensions, result)
return result

def get_all_result_columns(self) -> ColumnSet:
"""Return all columns that have to appear in the result. This includes
columns for all dimension records for all dimensions if
``include_dimension_records`` is `True`.

Returns
-------
columns : `ColumnSet`
Full column set.
"""
dimensions = self.dimensions
result = self.get_result_columns()
if self.include_dimension_records:
for element_name in dimensions.elements:
element = dimensions.universe[element_name]
# Non-cached dimensions are already there, but it does no harm
# to add them again.
result.dimension_fields[element_name].update(element.schema.remainder.names)
return result

@pydantic.model_validator(mode="after")
17 changes: 15 additions & 2 deletions python/lsst/daf/butler/remote_butler/_query_driver.py
@@ -257,12 +257,25 @@ def _convert_query_result_page(

def _convert_general_result(spec: GeneralResultSpec, model: GeneralResultModel) -> GeneralResultPage:
"""Convert GeneralResultModel to a general result page."""
columns = spec.get_result_columns()
columns = spec.get_all_result_columns()
# Verify that the column list received from the server matches local
# expectations (a mismatch could result from different versions). An older
# server may not know about `model.columns`, in which case it will be empty.
# If `model.columns` is empty then `zip(strict=True)` below will fail if the
# column count differs (column names are not checked in that case).
if model.columns:
expected_column_names = [str(column) for column in columns]
if expected_column_names != model.columns:
raise ValueError(
"Inconsistent columns in general result -- "
f"server columns: {model.columns}, expected: {expected_column_names}"
)

serializers = [
columns.get_column_spec(column.logical_table, column.field).serializer() for column in columns
]
rows = [
tuple(serializer.deserialize(value) for value, serializer in zip(row, serializers))
tuple(serializer.deserialize(value) for value, serializer in zip(row, serializers, strict=True))
for row in model.rows
]
return GeneralResultPage(spec=spec, rows=rows)
@@ -79,11 +79,12 @@ def convert_query_page(spec: ResultSpec, page: ResultPage) -> QueryExecuteResult

def _convert_general_result(page: GeneralResultPage) -> GeneralResultModel:
"""Convert GeneralResultPage to a serializable model."""
columns = page.spec.get_result_columns()
columns = page.spec.get_all_result_columns()
serializers = [
columns.get_column_spec(column.logical_table, column.field).serializer() for column in columns
]
rows = [
tuple(serializer.serialize(value) for value, serializer in zip(row, serializers)) for row in page.rows
tuple(serializer.serialize(value) for value, serializer in zip(row, serializers, strict=True))
for row in page.rows
]
return GeneralResultModel(rows=rows)
return GeneralResultModel(rows=rows, columns=[str(column) for column in columns])
3 changes: 3 additions & 0 deletions python/lsst/daf/butler/remote_butler/server_models.py
@@ -313,6 +313,9 @@ class GeneralResultModel(pydantic.BaseModel):

type: Literal["general"] = "general"
rows: list[tuple[Any, ...]]
# List of column names; the default is used for compatibility with older
# servers that do not set this field.
columns: list[str] = pydantic.Field(default_factory=list)


class QueryErrorResultModel(pydantic.BaseModel):
42 changes: 42 additions & 0 deletions python/lsst/daf/butler/tests/butler_queries.py
@@ -436,7 +436,9 @@ def test_general_query(self) -> None:
self.assertEqual(len(row_tuple.refs), 1)
self.assertEqual(row_tuple.refs[0].datasetType, flat)
self.assertTrue(row_tuple.refs[0].dataId.hasFull())
self.assertFalse(row_tuple.refs[0].dataId.hasRecords())
self.assertTrue(row_tuple.data_id.hasFull())
self.assertFalse(row_tuple.data_id.hasRecords())
self.assertEqual(row_tuple.data_id.dimensions, dimensions)
self.assertEqual(row_tuple.raw_row["flat.run"], "imported_g")

@@ -511,6 +513,46 @@ def test_general_query(self) -> None:
{Timespan(t1, t2), Timespan(t2, t3), Timespan(t3, None), Timespan.makeEmpty(), None},
)

dimensions = butler.dimensions["detector"].minimal_group

# Include dimension records into query.
with butler.query() as query:
query = query.join_dimensions(dimensions)
result = query.general(dimensions).order_by("detector")
rows = list(result.with_dimension_records())
self.assertEqual(
rows[0],
{
"instrument": "Cam1",
"detector": 1,
"instrument.visit_max": 1024,
"instrument.visit_system": 1,
"instrument.exposure_max": 512,
"instrument.detector_max": 4,
"instrument.class_name": "lsst.pipe.base.Instrument",
"detector.full_name": "Aa",
"detector.name_in_raft": "a",
"detector.raft": "A",
"detector.purpose": "SCIENCE",
},
)

dimensions = butler.dimensions.conform(["detector", "physical_filter"])

# DataIds should come with records.
with butler.query() as query:
query = query.join_dataset_search("flat", "imported_g")
result = query.general(dimensions, dataset_fields={"flat": ...}, find_first=True).order_by(
"detector"
)
result = result.with_dimension_records()
row_tuples = list(result.iter_tuples(flat))
self.assertEqual(len(row_tuples), 3)
for row_tuple in row_tuples:
self.assertTrue(row_tuple.data_id.hasRecords())
self.assertEqual(len(row_tuple.refs), 1)
self.assertTrue(row_tuple.refs[0].dataId.hasRecords())

def test_query_ingest_date(self) -> None:
"""Test general query returning ingest_date field."""
before_ingest = astropy.time.Time.now()