diff --git a/python/lsst/daf/butler/registry/collections/_base.py b/python/lsst/daf/butler/registry/collections/_base.py index 95a6bfa950..0e78a76794 100644 --- a/python/lsst/daf/butler/registry/collections/_base.py +++ b/python/lsst/daf/butler/registry/collections/_base.py @@ -30,10 +30,9 @@ __all__ = () -import contextlib import itertools from abc import abstractmethod -from collections import defaultdict, namedtuple +from collections import namedtuple from collections.abc import Iterable, Iterator, Set from typing import TYPE_CHECKING, Any, TypeVar, cast @@ -199,72 +198,24 @@ def __init__( self._collectionIdName = collectionIdName self._records: dict[K, CollectionRecord[K]] = {} # indexed by record ID self._dimensions = dimensions + self._full_fetch = False # True if cache contains everything. def refresh(self) -> None: # Docstring inherited from CollectionManager. - sql = sqlalchemy.sql.select( - *(list(self._tables.collection.columns) + list(self._tables.run.columns)) - ).select_from(self._tables.collection.join(self._tables.run, isouter=True)) - # Extract _all_ chain mappings as well - chain_sql = sqlalchemy.sql.select( - self._tables.collection_chain.columns["parent"], - self._tables.collection_chain.columns["position"], - self._tables.collection_chain.columns["child"], - ).select_from(self._tables.collection_chain) + # We just reset the cache here but do not retrieve any records. + self._full_fetch = False + self._setRecordCache([]) - with self._db.transaction(): - with self._db.query(sql) as sql_result: - sql_rows = sql_result.mappings().fetchall() - with self._db.query(chain_sql) as sql_result: - chain_rows = sql_result.mappings().fetchall() - - # Build all chain definitions. - chains_defs: dict[K, list[tuple[int, K]]] = defaultdict(list) - for row in chain_rows: - chains_defs[row["parent"]].append((row["position"], row["child"])) - - # Put found records into a temporary instead of updating self._records - # in place, for exception safety. - records: list[CollectionRecord] = [] - TimespanReprClass = self._db.getTimespanRepresentation() - id_to_name: dict[K, str] = {} - chained_ids: list[K] = [] - for row in sql_rows: - collection_id = row[self._tables.collection.columns[self._collectionIdName]] - name = row[self._tables.collection.columns.name] - id_to_name[collection_id] = name - type = CollectionType(row["type"]) - record: CollectionRecord - if type is CollectionType.RUN: - record = RunRecord( - key=collection_id, - name=name, - host=row[self._tables.run.columns.host], - timespan=TimespanReprClass.extract(row), - ) - records.append(record) - elif type is CollectionType.CHAINED: - # Need to delay chained collection construction until all names - # are known. - chained_ids.append(collection_id) - else: - record = CollectionRecord(key=collection_id, name=name, type=type) - records.append(record) - - for chained_id in chained_ids: - children_names = [id_to_name[child_id] for _, child_id in sorted(chains_defs[chained_id])] - record = ChainedCollectionRecord( - key=chained_id, - name=id_to_name[chained_id], - children=children_names, - ) - records.append(record) - - self._setRecordCache(records) + def _fetch_all(self) -> None: + """Retrieve all records into cache if not done so yet.""" + if not self._full_fetch: + records = self._fetch_by_key(None) + self._setRecordCache(records) + self._full_fetch = True def register( self, name: str, type: CollectionType, doc: str | None = None - ) -> tuple[CollectionRecord, bool]: + ) -> tuple[CollectionRecord[K], bool]: # Docstring inherited from CollectionManager. registered = False record = self._getByName(name) @@ -323,12 +274,31 @@ def find(self, name: str) -> CollectionRecord[K]: raise MissingCollectionError(f"No collection with name '{name}' found.") return result + def _find_many(self, names: Iterable[str]) -> list[CollectionRecord[K]]: + """Return multiple records given their names.""" + names = list(names) + # To protect against potential races in cache updates. + records = {} + for name in names: + records[name] = self._get_cached_name(name) + fetch_names = [name for name, record in records.items() if record is None] + for record in self._fetch_by_name(fetch_names): + records[record.name] = record + missing_names = [name for name, record in records.items() if record is None] + if len(missing_names) == 1: + raise MissingCollectionError(f"No collection with name '{missing_names[0]}' found.") + elif len(missing_names) > 1: + raise MissingCollectionError(f"No collections with names '{' '.join(missing_names)}' found.") + return [cast(CollectionRecord[K], records[name]) for name in names] + def __getitem__(self, key: Any) -> CollectionRecord[K]: # Docstring inherited from CollectionManager. - try: - return self._records[key] - except KeyError as err: - raise MissingCollectionError(f"Collection with key '{key}' not found.") from err + if (record := self._records.get(key)) is not None: + return record + if records := self._fetch_by_key([key]): + return records[0] + else: + raise MissingCollectionError(f"Collection with key '{key}' not found.") def resolve_wildcard( self, @@ -353,20 +323,25 @@ def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[Collect yield record if flatten_chains and record.type is CollectionType.CHAINED: done.add(record.name) - for name in cast(ChainedCollectionRecord[K], record).children: + for child in self._find_many(cast(ChainedCollectionRecord[K], record).children): # flake8 can't tell that we only delete this closure when # we're totally done with it. - yield from resolve_nested(self.find(name), done) # noqa: F821 + yield from resolve_nested(child, done) # noqa: F821 result: list[CollectionRecord[K]] = [] + # If we have wildcard or ellipsis we need to read everything in memory. + if wildcard.patterns: + self._fetch_all() + if wildcard.patterns is ...: for record in self._records.values(): result.extend(resolve_nested(record, done)) del resolve_nested return result - for name in wildcard.strings: - result.extend(resolve_nested(self.find(name), done)) + if wildcard.strings: + for record in self._find_many(wildcard.strings): + result.extend(resolve_nested(record, done)) if wildcard.patterns: for record in self._records.values(): if any(p.fullmatch(record.name) for p in wildcard.patterns): @@ -374,7 +349,7 @@ def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[Collect del resolve_nested return result - def getDocumentation(self, key: Any) -> str | None: + def getDocumentation(self, key: K) -> str | None: # Docstring inherited from CollectionManager. sql = ( sqlalchemy.sql.select(self._tables.collection.columns.doc) @@ -384,7 +359,7 @@ def getDocumentation(self, key: Any) -> str | None: with self._db.query(sql) as sql_result: return sql_result.scalar() - def setDocumentation(self, key: Any, doc: str | None) -> None: + def setDocumentation(self, key: K, doc: str | None) -> None: # Docstring inherited from CollectionManager. self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc}) @@ -404,12 +379,33 @@ def _removeCachedRecord(self, record: CollectionRecord[K]) -> None: """Remove single record from cache.""" del self._records[record.key] - @abstractmethod def _getByName(self, name: str) -> CollectionRecord[K] | None: """Find collection record given collection name.""" + if (record := self._get_cached_name(name)) is not None: + return record + records = self._fetch_by_name([name]) + for record in records: + self._addCachedRecord(record) + return records[0] if records else None + + @abstractmethod + def _get_cached_name(self, name: str) -> CollectionRecord[K] | None: + """Find cached collection record given its name.""" + raise NotImplementedError() + + @abstractmethod + def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[K]]: + """Fetch collection record from database given its name.""" + raise NotImplementedError() + + @abstractmethod + def _fetch_by_key(self, collection_ids: Iterable[K] | None) -> list[CollectionRecord[K]]: + """Fetch collection record from database given its key, or fetch all + collctions if argument is None. + """ raise NotImplementedError() - def getParentChains(self, key: Any) -> Iterator[ChainedCollectionRecord[K]]: + def getParentChains(self, key: K) -> Iterator[ChainedCollectionRecord[K]]: # Docstring inherited from CollectionManager. table = self._tables.collection_chain sql = ( @@ -419,11 +415,13 @@ def getParentChains(self, key: Any) -> Iterator[ChainedCollectionRecord[K]]: ) with self._db.query(sql) as sql_result: parent_keys = sql_result.scalars().all() - for key in parent_keys: - # TODO: Just in case cached records miss new parent collections. - # This is temporary, will replace with non-cached records soon. - with contextlib.suppress(KeyError): - yield cast(ChainedCollectionRecord[K], self._records[key]) + # TODO: It would be more efficient to write a single query that both + # finds parents and all their children, but for now we do not care + # much about efficiency. Also the only client of this method does not + # need full records, only parent collection names, maybe we should + # change this method to return names instead. + for record in self._fetch_by_key(parent_keys): + yield cast(ChainedCollectionRecord[K], record) def update_chain( self, chain: ChainedCollectionRecord[K], children: Iterable[str], flatten: bool = False diff --git a/python/lsst/daf/butler/registry/collections/nameKey.py b/python/lsst/daf/butler/registry/collections/nameKey.py index e5e635e61c..7336558cc9 100644 --- a/python/lsst/daf/butler/registry/collections/nameKey.py +++ b/python/lsst/daf/butler/registry/collections/nameKey.py @@ -26,16 +26,17 @@ # along with this program. If not, see . from __future__ import annotations -from ... import ddl - __all__ = ["NameKeyCollectionManager"] +from collections.abc import Iterable, Mapping from typing import TYPE_CHECKING, Any import sqlalchemy +from ... import ddl from ..._timespan import TimespanDatabaseRepresentation -from ..interfaces import VersionTuple +from .._collection_type import CollectionType +from ..interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord, VersionTuple from ._base import ( CollectionTablesTuple, DefaultCollectionManager, @@ -44,7 +45,7 @@ ) if TYPE_CHECKING: - from ..interfaces import CollectionRecord, Database, DimensionRecordStorageManager, StaticTablesContext + from ..interfaces import Database, DimensionRecordStorageManager, StaticTablesContext _KEY_FIELD_SPEC = ddl.FieldSpec("name", dtype=sqlalchemy.String, length=64, primaryKey=True) @@ -68,7 +69,7 @@ def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) -> ) -class NameKeyCollectionManager(DefaultCollectionManager): +class NameKeyCollectionManager(DefaultCollectionManager[str]): """A `CollectionManager` implementation that uses collection names for primary/foreign keys and aggressively loads all collection/run records in the database into memory. @@ -152,10 +153,110 @@ def addRunForeignKey( ) return copy - def _getByName(self, name: str) -> CollectionRecord | None: - # Docstring inherited from DefaultCollectionManager. + def _get_cached_name(self, name: str) -> CollectionRecord[str] | None: + # Docstring inherited from base class. return self._records.get(name) + def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[str]]: + # Docstring inherited from base class. + return self._fetch_by_key(names) + + def _fetch_by_key(self, collection_ids: Iterable[str] | None) -> list[CollectionRecord[str]]: + # Docstring inherited from base class. + sql = sqlalchemy.sql.select(*self._tables.collection.columns, *self._tables.run.columns).select_from( + self._tables.collection.join(self._tables.run, isouter=True) + ) + + chain_sql = sqlalchemy.sql.select( + self._tables.collection_chain.columns["parent"], + self._tables.collection_chain.columns["position"], + self._tables.collection_chain.columns["child"], + ) + + records: list[CollectionRecord[str]] = [] + # We want to keep transactions as short as possible. When we fetch + # everything we want to quickly fetch things into memory and finish + # transaction. When we fetch just few records we need to process result + # of the first query before we can run the second one. + if collection_ids is not None: + sql = sql.where(self._tables.collection.columns[self._collectionIdName].in_(collection_ids)) + with self._db.transaction(): + with self._db.query(sql) as sql_result: + sql_rows = sql_result.mappings().fetchall() + + records, chained_ids = self._rows_to_records(sql_rows) + + if chained_ids: + # Retrieve chained collection compositions + chain_sql = chain_sql.where( + self._tables.collection_chain.columns["parent"].in_(chained_ids) + ) + with self._db.query(chain_sql) as sql_result: + chain_rows = sql_result.mappings().fetchall() + + records += self._rows_to_chains(chain_rows, chained_ids) + + else: + with self._db.transaction(): + with self._db.query(sql) as sql_result: + sql_rows = sql_result.mappings().fetchall() + with self._db.query(chain_sql) as sql_result: + chain_rows = sql_result.mappings().fetchall() + + records, chained_ids = self._rows_to_records(sql_rows) + records += self._rows_to_chains(chain_rows, chained_ids) + + return records + + def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[str]], list[str]]: + """Convert rows returned from collection query to a list of records + and a list chained collection names. + """ + records: list[CollectionRecord[str]] = [] + TimespanReprClass = self._db.getTimespanRepresentation() + chained_ids: list[str] = [] + for row in rows: + name = row[self._tables.collection.columns.name] + type = CollectionType(row["type"]) + record: CollectionRecord[str] + if type is CollectionType.RUN: + record = RunRecord[str]( + key=name, + name=name, + host=row[self._tables.run.columns.host], + timespan=TimespanReprClass.extract(row), + ) + records.append(record) + elif type is CollectionType.CHAINED: + # Need to delay chained collection construction until to + # fetch their children names. + chained_ids.append(name) + else: + record = CollectionRecord[str](key=name, name=name, type=type) + records.append(record) + + return records, chained_ids + + def _rows_to_chains(self, rows: Iterable[Mapping], chained_ids: list[str]) -> list[CollectionRecord[str]]: + """Convert rows returned from collection chain query to a list of + records. + """ + chains_defs: dict[str, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids} + for row in rows: + chains_defs[row["parent"]].append((row["position"], row["child"])) + + records: list[CollectionRecord[str]] = [] + for name, children in chains_defs.items(): + children_names = [child for _, child in sorted(children)] + record = ChainedCollectionRecord[str]( + key=name, + name=name, + children=children_names, + ) + records.append(record) + + return records + @classmethod def currentVersions(cls) -> list[VersionTuple]: # Docstring inherited from VersionedExtension. diff --git a/python/lsst/daf/butler/registry/collections/synthIntKey.py b/python/lsst/daf/butler/registry/collections/synthIntKey.py index 8e49140c8d..2e1bf5f758 100644 --- a/python/lsst/daf/butler/registry/collections/synthIntKey.py +++ b/python/lsst/daf/butler/registry/collections/synthIntKey.py @@ -30,13 +30,14 @@ __all__ = ["SynthIntKeyCollectionManager"] -from collections.abc import Iterable +from collections.abc import Iterable, Mapping from typing import TYPE_CHECKING, Any import sqlalchemy from ..._timespan import TimespanDatabaseRepresentation -from ..interfaces import CollectionRecord, VersionTuple +from .._collection_type import CollectionType +from ..interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord, VersionTuple from ._base import ( CollectionTablesTuple, DefaultCollectionManager, @@ -73,7 +74,7 @@ def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) -> ) -class SynthIntKeyCollectionManager(DefaultCollectionManager): +class SynthIntKeyCollectionManager(DefaultCollectionManager[int]): """A `CollectionManager` implementation that uses synthetic primary key (auto-incremented integer) for collections table. @@ -184,7 +185,7 @@ def addRunForeignKey( ) return copy - def _setRecordCache(self, records: Iterable[CollectionRecord]) -> None: + def _setRecordCache(self, records: Iterable[CollectionRecord[int]]) -> None: """Set internal record cache to contain given records, old cached records will be removed. """ @@ -194,20 +195,135 @@ def _setRecordCache(self, records: Iterable[CollectionRecord]) -> None: self._records[record.key] = record self._nameCache[record.name] = record - def _addCachedRecord(self, record: CollectionRecord) -> None: + def _addCachedRecord(self, record: CollectionRecord[int]) -> None: """Add single record to cache.""" self._records[record.key] = record self._nameCache[record.name] = record - def _removeCachedRecord(self, record: CollectionRecord) -> None: + def _removeCachedRecord(self, record: CollectionRecord[int]) -> None: """Remove single record from cache.""" del self._records[record.key] del self._nameCache[record.name] - def _getByName(self, name: str) -> CollectionRecord | None: - # Docstring inherited from DefaultCollectionManager. + def _get_cached_name(self, name: str) -> CollectionRecord[int] | None: + # Docstring inherited from base class. return self._nameCache.get(name) + def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[int]]: + # Docstring inherited from base class. + return self._fetch("name", names) + + def _fetch_by_key(self, collection_ids: Iterable[int] | None) -> list[CollectionRecord[int]]: + # Docstring inherited from base class. + return self._fetch(self._collectionIdName, collection_ids) + + def _fetch( + self, column_name: str, collections: Iterable[int | str] | None + ) -> list[CollectionRecord[int]]: + collection_chain = self._tables.collection_chain + collection = self._tables.collection + sql = sqlalchemy.sql.select(*collection.columns, *self._tables.run.columns).select_from( + collection.join(self._tables.run, isouter=True) + ) + + chain_sql = ( + sqlalchemy.sql.select( + collection_chain.columns["parent"], + collection_chain.columns["position"], + collection.columns["name"].label("child_name"), + ) + .select_from(collection_chain) + .join( + collection, + onclause=collection_chain.columns["child"] == collection.columns[self._collectionIdName], + ) + ) + + records: list[CollectionRecord[int]] = [] + # We want to keep transactions as short as possible. When we fetch + # everything we want to quickly fetch things into memory and finish + # transaction. When we fetch just few records we need to process first + # query before wi can run second one, + if collections is not None: + sql = sql.where(collection.columns[column_name].in_(collections)) + with self._db.transaction(): + with self._db.query(sql) as sql_result: + sql_rows = sql_result.mappings().fetchall() + + records, chained_ids = self._rows_to_records(sql_rows) + + if chained_ids: + chain_sql = chain_sql.where(collection_chain.columns["parent"].in_(list(chained_ids))) + + with self._db.query(chain_sql) as sql_result: + chain_rows = sql_result.mappings().fetchall() + + records += self._rows_to_chains(chain_rows, chained_ids) + + else: + with self._db.transaction(): + with self._db.query(sql) as sql_result: + sql_rows = sql_result.mappings().fetchall() + with self._db.query(chain_sql) as sql_result: + chain_rows = sql_result.mappings().fetchall() + + records, chained_ids = self._rows_to_records(sql_rows) + records += self._rows_to_chains(chain_rows, chained_ids) + + return records + + def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[int]], dict[int, str]]: + """Convert rows returned from collection query to a list of records + and a dict chained collection names. + """ + records: list[CollectionRecord[int]] = [] + chained_ids: dict[int, str] = {} + TimespanReprClass = self._db.getTimespanRepresentation() + for row in rows: + key: int = row[self._collectionIdName] + name: str = row[self._tables.collection.columns.name] + type = CollectionType(row["type"]) + record: CollectionRecord[int] + if type is CollectionType.RUN: + record = RunRecord[int]( + key=key, + name=name, + host=row[self._tables.run.columns.host], + timespan=TimespanReprClass.extract(row), + ) + records.append(record) + elif type is CollectionType.CHAINED: + # Need to delay chained collection construction until to + # fetch their children names. + chained_ids[key] = name + else: + record = CollectionRecord[int](key=key, name=name, type=type) + records.append(record) + return records, chained_ids + + def _rows_to_chains( + self, rows: Iterable[Mapping], chained_ids: dict[int, str] + ) -> list[CollectionRecord[int]]: + """Convert rows returned from collection chain query to a list of + records. + """ + chains_defs: dict[int, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids} + for row in rows: + chains_defs[row["parent"]].append((row["position"], row["child_name"])) + + records: list[CollectionRecord[int]] = [] + for key, children in chains_defs.items(): + name = chained_ids[key] + children_names = [child for _, child in sorted(children)] + record = ChainedCollectionRecord[int]( + key=key, + name=name, + children=children_names, + ) + records.append(record) + + return records + @classmethod def currentVersions(cls) -> list[VersionTuple]: # Docstring inherited from VersionedExtension.