diff --git a/python/lsst/daf/butler/registry/collections/_base.py b/python/lsst/daf/butler/registry/collections/_base.py
index 0e78a76794..1c3587d1ab 100644
--- a/python/lsst/daf/butler/registry/collections/_base.py
+++ b/python/lsst/daf/butler/registry/collections/_base.py
@@ -405,24 +405,6 @@ def _fetch_by_key(self, collection_ids: Iterable[K] | None) -> list[CollectionRe
         """
         raise NotImplementedError()
 
-    def getParentChains(self, key: K) -> Iterator[ChainedCollectionRecord[K]]:
-        # Docstring inherited from CollectionManager.
-        table = self._tables.collection_chain
-        sql = (
-            sqlalchemy.sql.select(table.columns["parent"])
-            .select_from(table)
-            .where(table.columns["child"] == key)
-        )
-        with self._db.query(sql) as sql_result:
-            parent_keys = sql_result.scalars().all()
-        # TODO: It would be more efficient to write a single query that both
-        # finds parents and all their children, but for now we do not care
-        # much about efficiency. Also the only client of this method does not
-        # need full records, only parent collection names, maybe we should
-        # change this method to return names instead.
-        for record in self._fetch_by_key(parent_keys):
-            yield cast(ChainedCollectionRecord[K], record)
-
     def update_chain(
         self, chain: ChainedCollectionRecord[K], children: Iterable[str], flatten: bool = False
     ) -> ChainedCollectionRecord[K]:
diff --git a/python/lsst/daf/butler/registry/collections/nameKey.py b/python/lsst/daf/butler/registry/collections/nameKey.py
index 7336558cc9..d8c50cce2f 100644
--- a/python/lsst/daf/butler/registry/collections/nameKey.py
+++ b/python/lsst/daf/butler/registry/collections/nameKey.py
@@ -153,6 +153,18 @@ def addRunForeignKey(
         )
         return copy
 
+    def getParentChains(self, key: str) -> set[str]:
+        # Docstring inherited from CollectionManager.
+        table = self._tables.collection_chain
+        sql = (
+            sqlalchemy.sql.select(table.columns["parent"])
+            .select_from(table)
+            .where(table.columns["child"] == key)
+        )
+        with self._db.query(sql) as sql_result:
+            parent_names = set(sql_result.scalars().all())
+        return parent_names
+
     def _get_cached_name(self, name: str) -> CollectionRecord[str] | None:
         # Docstring inherited from base class.
         return self._records.get(name)
diff --git a/python/lsst/daf/butler/registry/collections/synthIntKey.py b/python/lsst/daf/butler/registry/collections/synthIntKey.py
index 2e1bf5f758..d2edcaae88 100644
--- a/python/lsst/daf/butler/registry/collections/synthIntKey.py
+++ b/python/lsst/daf/butler/registry/collections/synthIntKey.py
@@ -185,6 +185,20 @@ def addRunForeignKey(
         )
         return copy
 
+    def getParentChains(self, key: int) -> set[str]:
+        # Docstring inherited from CollectionManager.
+        chain = self._tables.collection_chain
+        collection = self._tables.collection
+        sql = (
+            sqlalchemy.sql.select(collection.columns["name"])
+            .select_from(collection)
+            .join(chain, onclause=collection.columns[self._collectionIdName] == chain.columns["parent"])
+            .where(chain.columns["child"] == key)
+        )
+        with self._db.query(sql) as sql_result:
+            parent_names = set(sql_result.scalars().all())
+        return parent_names
+
     def _setRecordCache(self, records: Iterable[CollectionRecord[int]]) -> None:
         """Set internal record cache to contain given records,
         old cached records will be removed.
diff --git a/python/lsst/daf/butler/registry/interfaces/_collections.py b/python/lsst/daf/butler/registry/interfaces/_collections.py
index 3cce276b6b..c07b894adc 100644
--- a/python/lsst/daf/butler/registry/interfaces/_collections.py
+++ b/python/lsst/daf/butler/registry/interfaces/_collections.py
@@ -36,7 +36,7 @@
 ]
 
 from abc import abstractmethod
-from collections.abc import Iterable, Iterator, Set
+from collections.abc import Iterable, Set
 from typing import TYPE_CHECKING, Any, Generic, TypeVar
 
 from ..._timespan import Timespan
@@ -563,14 +563,19 @@ def setDocumentation(self, key: _Key, doc: str | None) -> None:
         raise NotImplementedError()
 
     @abstractmethod
-    def getParentChains(self, key: _Key) -> Iterator[ChainedCollectionRecord[_Key]]:
-        """Find all CHAINED collections that directly contain the given
+    def getParentChains(self, key: _Key) -> set[str]:
+        """Find all CHAINED collection names that directly contain the given
         collection.
 
         Parameters
         ----------
         key
             Internal primary key value for the collection.
+
+        Returns
+        -------
+        names : `set` [`str`]
+            Parent collection names.
         """
         raise NotImplementedError()
 
diff --git a/python/lsst/daf/butler/registry/sql_registry.py b/python/lsst/daf/butler/registry/sql_registry.py
index 538ffbc265..5e03938b78 100644
--- a/python/lsst/daf/butler/registry/sql_registry.py
+++ b/python/lsst/daf/butler/registry/sql_registry.py
@@ -618,12 +618,7 @@ def getCollectionParentChains(self, collection: str) -> set[str]:
         chains : `set` of `str`
             Set of `~CollectionType.CHAINED` collection names.
         """
-        return {
-            record.name
-            for record in self._managers.collections.getParentChains(
-                self._managers.collections.find(collection).key
-            )
-        }
+        return self._managers.collections.getParentChains(self._managers.collections.find(collection).key)
 
     def getCollectionDocumentation(self, collection: str) -> str | None:
         """Retrieve the documentation string for a collection.