diff --git a/python/lsst/daf/butler/_butler.py b/python/lsst/daf/butler/_butler.py
index 4611e6e556..54c25d84b9 100644
--- a/python/lsst/daf/butler/_butler.py
+++ b/python/lsst/daf/butler/_butler.py
@@ -800,13 +800,16 @@ def get_dataset_type(self, name: str) -> DatasetType:
         raise NotImplementedError()
 
     @abstractmethod
-    def get_dataset(self, id: DatasetId) -> DatasetRef | None:
+    def get_dataset(self, id: DatasetId, storage_class: str | StorageClass | None) -> DatasetRef | None:
        """Retrieve a Dataset entry.
 
         Parameters
         ----------
         id : `DatasetId`
             The unique identifier for the dataset.
+        storage_class : `str` or `StorageClass` or `None`
+            A storage class to use when creating the returned entry. If given
+            it must be compatible with the default storage class.
 
         Returns
         -------
@@ -824,6 +827,7 @@ def find_dataset(
         *,
         collections: str | Sequence[str] | None = None,
         timespan: Timespan | None = None,
+        storage_class: str | StorageClass | None = None,
         datastore_records: bool = False,
         **kwargs: Any,
     ) -> DatasetRef | None:
@@ -851,6 +855,9 @@ def find_dataset(
             A timespan that the validity range of the dataset must overlap.
             If not provided, any `~CollectionType.CALIBRATION` collections
             matched by the ``collections`` argument will not be searched.
+        storage_class : `str` or `StorageClass` or `None`
+            A storage class to use when creating the returned entry. If given
+            it must be compatible with the default storage class.
         **kwargs
             Additional keyword arguments passed to
             `DataCoordinate.standardize` to convert ``dataId`` to a true
diff --git a/python/lsst/daf/butler/direct_butler.py b/python/lsst/daf/butler/direct_butler.py
index 25817658a6..bd85516672 100644
--- a/python/lsst/daf/butler/direct_butler.py
+++ b/python/lsst/daf/butler/direct_butler.py
@@ -1321,8 +1321,13 @@ def getURI(
     def get_dataset_type(self, name: str) -> DatasetType:
         return self._registry.getDatasetType(name)
 
-    def get_dataset(self, id: DatasetId) -> DatasetRef | None:
-        return self._registry.getDataset(id)
+    def get_dataset(
+        self, id: DatasetId, storage_class: str | StorageClass | None = None
+    ) -> DatasetRef | None:
+        ref = self._registry.getDataset(id)
+        if ref is not None and storage_class:
+            ref = ref.overrideStorageClass(storage_class)
+        return ref
 
     def find_dataset(
         self,
@@ -1331,6 +1336,7 @@ def find_dataset(
         *,
         collections: str | Sequence[str] | None = None,
         timespan: Timespan | None = None,
+        storage_class: str | StorageClass | None = None,
         datastore_records: bool = False,
         **kwargs: Any,
     ) -> DatasetRef | None:
@@ -1342,7 +1348,7 @@ def find_dataset(
             actual_type = dataset_type
         data_id, kwargs = self._rewrite_data_id(data_id, actual_type, **kwargs)
 
-        return self._registry.findDataset(
+        ref = self._registry.findDataset(
             dataset_type,
             data_id,
             collections=collections,
@@ -1350,6 +1356,9 @@ def find_dataset(
             timespan=timespan,
             datastore_records=datastore_records,
             **kwargs,
         )
+        if ref is not None and storage_class is not None:
+            ref = ref.overrideStorageClass(storage_class)
+        return ref
 
     def retrieveArtifacts(
         self,
diff --git a/python/lsst/daf/butler/remote_butler/_remote_butler.py b/python/lsst/daf/butler/remote_butler/_remote_butler.py
index 6363591f83..8ba367bddf 100644
--- a/python/lsst/daf/butler/remote_butler/_remote_butler.py
+++ b/python/lsst/daf/butler/remote_butler/_remote_butler.py
@@ -219,9 +219,18 @@ def get_dataset_type(self, name: str) -> DatasetType:
         response.raise_for_status()
         return DatasetType.from_simple(SerializedDatasetType(**response.json()), universe=self.dimensions)
 
-    def get_dataset(self, id: DatasetId) -> DatasetRef | None:
+    def get_dataset(
+        self, id: DatasetId, storage_class: str | StorageClass | None = None
+    ) -> DatasetRef | None:
         path = f"dataset/{id}"
-        response = self._client.get(self._get_url(path))
+        if isinstance(storage_class, StorageClass):
+            storage_class_name = storage_class.name
+        elif storage_class:
+            storage_class_name = storage_class
+        params: dict[str, str] = {}
+        if storage_class:
+            params["storage_class"] = storage_class_name
+        response = self._client.get(self._get_url(path), params=params)
         response.raise_for_status()
         if response.json() is None:
             return None
@@ -234,6 +243,7 @@ def find_dataset(
         *,
         collections: str | Sequence[str] | None = None,
         timespan: Timespan | None = None,
+        storage_class: str | StorageClass | None = None,
         datastore_records: bool = False,
         **kwargs: Any,
     ) -> DatasetRef | None:
@@ -251,13 +261,18 @@ def find_dataset(
         if isinstance(dataset_type, DatasetType):
             dataset_type = dataset_type.name
 
+        if isinstance(storage_class, StorageClass):
+            storage_class = storage_class.name
+
         query = FindDatasetModel(
-            data_id=self._simplify_dataId(data_id, **kwargs), collections=wildcards.strings
+            data_id=self._simplify_dataId(data_id, **kwargs),
+            collections=wildcards.strings,
+            storage_class=storage_class,
         )
 
         path = f"find_dataset/{dataset_type}"
         response = self._client.post(
-            self._get_url(path), json=query.model_dump(mode="json", exclude_unset=True)
+            self._get_url(path), json=query.model_dump(mode="json", exclude_unset=True, exclude_defaults=True)
         )
         response.raise_for_status()
diff --git a/python/lsst/daf/butler/remote_butler/server/_server.py b/python/lsst/daf/butler/remote_butler/server/_server.py
index 0e036ab67a..92798b2628 100644
--- a/python/lsst/daf/butler/remote_butler/server/_server.py
+++ b/python/lsst/daf/butler/remote_butler/server/_server.py
@@ -116,10 +116,12 @@ def get_dataset_type(
     response_model_exclude_defaults=True,
     response_model_exclude_none=True,
 )
-def get_dataset(id: uuid.UUID, factory: Factory = Depends(factory_dependency)) -> SerializedDatasetRef | None:
+def get_dataset(
+    id: uuid.UUID, storage_class: str | None = None, factory: Factory = Depends(factory_dependency)
+) -> SerializedDatasetRef | None:
     """Return a single dataset reference."""
     butler = factory.create_butler()
-    ref = butler.get_dataset(id)
+    ref = butler.get_dataset(id, storage_class=storage_class)
     if ref is not None:
         return ref.to_simple()
     # This could raise a 404 since id is not found. The standard implementation
@@ -150,5 +152,7 @@ def find_dataset(
 
     data_id = query.data_id.dataId
     butler = factory.create_butler()
-    ref = butler.find_dataset(dataset_type, None, collections=collection_query, **data_id)
+    ref = butler.find_dataset(
+        dataset_type, None, collections=collection_query, storage_class=query.storage_class, **data_id
+    )
     return ref.to_simple() if ref else None
diff --git a/python/lsst/daf/butler/remote_butler/server/_server_models.py b/python/lsst/daf/butler/remote_butler/server/_server_models.py
index 24a20829e6..d9200976b1 100644
--- a/python/lsst/daf/butler/remote_butler/server/_server_models.py
+++ b/python/lsst/daf/butler/remote_butler/server/_server_models.py
@@ -37,3 +37,4 @@
 class FindDatasetModel(_BaseModelCompat):
     data_id: SerializedDataCoordinate
     collections: list[str]
+    storage_class: str | None
diff --git a/tests/test_server.py b/tests/test_server.py
index 05daac7e06..4622a5a203 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -38,7 +38,7 @@
     TestClient = None
     app = None
 
-from lsst.daf.butler import Butler, DataCoordinate, DatasetRef
+from lsst.daf.butler import Butler, DataCoordinate, DatasetRef, StorageClassFactory
 from lsst.daf.butler.tests import DatastoreMock
 from lsst.daf.butler.tests.utils import MetricTestRepo, makeTestTempDir, removeTestTempDir
 
@@ -64,6 +64,8 @@ class ButlerClientServerTestCase(unittest.TestCase):
 
     @classmethod
     def setUpClass(cls):
+        cls.storageClassFactory = StorageClassFactory()
+
         # First create a butler and populate it.
         cls.root = makeTestTempDir(TESTDIR)
         cls.repo = MetricTestRepo(root=cls.root, configFile=os.path.join(TESTDIR, "config/basic/butler.yaml"))
@@ -106,6 +108,8 @@ def test_get_dataset_type(self):
         self.assertEqual(bias_type.name, "bias")
 
     def test_find_dataset(self):
+        storage_class = self.storageClassFactory.getStorageClass("Exposure")
+
         ref = self.butler.find_dataset("bias", collections="imported_g", detector=1, instrument="Cam1")
         self.assertIsInstance(ref, DatasetRef)
         self.assertEqual(ref.id, uuid.UUID("e15ab039-bc8b-4135-87c5-90902a7c0b22"))
@@ -123,6 +127,7 @@ def test_find_dataset(self):
             ref.datasetType,
             DataCoordinate.standardize(detector=1, instrument="Cam1", universe=self.butler.dimensions),
             collections="imported_g",
+            storage_class=storage_class,
         )
         self.assertEqual(ref_new, ref)
 
@@ -138,8 +143,15 @@ def test_find_dataset(self):
         )
         self.assertEqual(ref2, ref3)
 
+        # The test datasets are all Exposure so storage class conversion
+        # can not be tested until we fix that. For now at least test the
+        # code paths.
+        bias = self.butler.get_dataset(ref.id, storage_class=storage_class)
+        self.assertEqual(bias.datasetType.storageClass, storage_class)
+
         # Unknown dataset should not fail.
         self.assertIsNone(self.butler.get_dataset(uuid.uuid4()))
+        self.assertIsNone(self.butler.get_dataset(uuid.uuid4(), storage_class="NumpyArray"))
 
 
 if __name__ == "__main__":
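For reference, the new ``storage_class`` argument can be exercised from client code roughly as follows. This is a minimal sketch, not part of the change itself: the repository path is hypothetical, while the dataset type, data ID values and the ``Exposure`` storage class mirror the test data in ``tests/test_server.py``.

    from lsst.daf.butler import Butler

    # Hypothetical repository; any Butler repo containing the test data would do.
    butler = Butler("/path/to/repo", collections=["imported_g"])

    # Ask for the returned DatasetRef to carry a compatible storage class
    # instead of the one registered for the dataset type.
    ref = butler.find_dataset(
        "bias", detector=1, instrument="Cam1", storage_class="Exposure"
    )

    if ref is not None:
        # The same override is available when resolving by dataset ID.
        ref2 = butler.get_dataset(ref.id, storage_class="Exposure")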