Skip to content

Commit

Permalink
Make CollectionManager.getParentChains return names instead of records
Browse files Browse the repository at this point in the history
  • Loading branch information
andy-slac committed Nov 8, 2023
1 parent 8481829 commit b1d4cb8
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 27 deletions.
18 changes: 0 additions & 18 deletions python/lsst/daf/butler/registry/collections/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,24 +405,6 @@ def _fetch_by_key(self, collection_ids: Iterable[K] | None) -> list[CollectionRe
"""
raise NotImplementedError()

def getParentChains(self, key: K) -> Iterator[ChainedCollectionRecord[K]]:
# Docstring inherited from CollectionManager.
table = self._tables.collection_chain
sql = (
sqlalchemy.sql.select(table.columns["parent"])
.select_from(table)
.where(table.columns["child"] == key)
)
with self._db.query(sql) as sql_result:
parent_keys = sql_result.scalars().all()
# TODO: It would be more efficient to write a single query that both
# finds parents and all their children, but for now we do not care
# much about efficiency. Also the only client of this method does not
# need full records, only parent collection names, maybe we should
# change this method to return names instead.
for record in self._fetch_by_key(parent_keys):
yield cast(ChainedCollectionRecord[K], record)

def update_chain(
self, chain: ChainedCollectionRecord[K], children: Iterable[str], flatten: bool = False
) -> ChainedCollectionRecord[K]:
Expand Down
12 changes: 12 additions & 0 deletions python/lsst/daf/butler/registry/collections/nameKey.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,18 @@ def addRunForeignKey(
)
return copy

def getParentChains(self, key: str) -> set[str]:
# Docstring inherited from CollectionManager.
table = self._tables.collection_chain
sql = (
sqlalchemy.sql.select(table.columns["parent"])
.select_from(table)
.where(table.columns["child"] == key)
)
with self._db.query(sql) as sql_result:
parent_names = set(sql_result.scalars().all())
return parent_names

def _get_cached_name(self, name: str) -> CollectionRecord[str] | None:
# Docstring inherited from base class.
return self._records.get(name)
Expand Down
14 changes: 14 additions & 0 deletions python/lsst/daf/butler/registry/collections/synthIntKey.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,20 @@ def addRunForeignKey(
)
return copy

def getParentChains(self, key: int) -> set[str]:
# Docstring inherited from CollectionManager.
chain = self._tables.collection_chain
collection = self._tables.collection
sql = (
sqlalchemy.sql.select(collection.columns["name"])
.select_from(collection)
.join(chain, onclause=collection.columns[self._collectionIdName] == chain.columns["parent"])
.where(chain.columns["child"] == key)
)
with self._db.query(sql) as sql_result:
parent_names = set(sql_result.scalars().all())
return parent_names

def _setRecordCache(self, records: Iterable[CollectionRecord[int]]) -> None:
"""Set internal record cache to contain given records,
old cached records will be removed.
Expand Down
11 changes: 8 additions & 3 deletions python/lsst/daf/butler/registry/interfaces/_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
]

from abc import abstractmethod
from collections.abc import Iterable, Iterator, Set
from collections.abc import Iterable, Set
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from ..._timespan import Timespan
Expand Down Expand Up @@ -563,14 +563,19 @@ def setDocumentation(self, key: _Key, doc: str | None) -> None:
raise NotImplementedError()

@abstractmethod
def getParentChains(self, key: _Key) -> Iterator[ChainedCollectionRecord[_Key]]:
"""Find all CHAINED collections that directly contain the given
def getParentChains(self, key: _Key) -> set[str]:
"""Find all CHAINED collection names that directly contain the given
collection.
Parameters
----------
key
Internal primary key value for the collection.
Returns
-------
names : `set` [`str`]
Parent collection names.
"""
raise NotImplementedError()

Expand Down
7 changes: 1 addition & 6 deletions python/lsst/daf/butler/registry/sql_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,12 +618,7 @@ def getCollectionParentChains(self, collection: str) -> set[str]:
chains : `set` of `str`
Set of `~CollectionType.CHAINED` collection names.
"""
return {
record.name
for record in self._managers.collections.getParentChains(
self._managers.collections.find(collection).key
)
}
return self._managers.collections.getParentChains(self._managers.collections.find(collection).key)

def getCollectionDocumentation(self, collection: str) -> str | None:
"""Retrieve the documentation string for a collection.
Expand Down

0 comments on commit b1d4cb8

Please sign in to comment.