Skip to content

Commit

Permalink
Implement delayed population of collection cache.
Browse files Browse the repository at this point in the history
  • Loading branch information
andy-slac committed Nov 8, 2023
1 parent 0d87929 commit 8481829
Show file tree
Hide file tree
Showing 3 changed files with 308 additions and 93 deletions.
154 changes: 76 additions & 78 deletions python/lsst/daf/butler/registry/collections/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,9 @@

__all__ = ()

import contextlib
import itertools
from abc import abstractmethod
from collections import defaultdict, namedtuple
from collections import namedtuple
from collections.abc import Iterable, Iterator, Set
from typing import TYPE_CHECKING, Any, TypeVar, cast

Expand Down Expand Up @@ -199,72 +198,24 @@ def __init__(
self._collectionIdName = collectionIdName
self._records: dict[K, CollectionRecord[K]] = {} # indexed by record ID
self._dimensions = dimensions
self._full_fetch = False # True if cache contains everything.

def refresh(self) -> None:
    # Docstring inherited from CollectionManager.
    #
    # Refreshing only invalidates the cache; it does not talk to the
    # database.  Records are re-populated later, either one at a time as
    # they are looked up or all at once by `_fetch_all`.
    #
    # Mark the cache as incomplete first so that a later `_fetch_all` call
    # knows it has to reload everything.
    self._full_fetch = False
    # We just reset the cache here but do not retrieve any records.
    self._setRecordCache([])
def _fetch_all(self) -> None:
    """Load every collection record into the cache, at most once.

    The ``_full_fetch`` flag remembers that the cache is already complete,
    so repeated calls are no-ops until the flag is reset.
    """
    if self._full_fetch:
        # Cache already holds everything; nothing to do.
        return
    # Passing `None` asks `_fetch_by_key` for all collections.
    self._setRecordCache(self._fetch_by_key(None))
    self._full_fetch = True

def register(
self, name: str, type: CollectionType, doc: str | None = None
) -> tuple[CollectionRecord, bool]:
) -> tuple[CollectionRecord[K], bool]:
# Docstring inherited from CollectionManager.
registered = False
record = self._getByName(name)
Expand Down Expand Up @@ -323,12 +274,31 @@ def find(self, name: str) -> CollectionRecord[K]:
raise MissingCollectionError(f"No collection with name '{name}' found.")
return result

def _find_many(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
    """Return multiple records given their names.

    Parameters
    ----------
    names : `~collections.abc.Iterable` [ `str` ]
        Names of the collections to look up.  Duplicates are allowed.

    Returns
    -------
    records : `list` [ `CollectionRecord` ]
        One record per input name, in the same order as ``names``.

    Raises
    ------
    MissingCollectionError
        Raised if one or more names are found neither in the cache nor in
        the database.
    """
    names = list(names)
    # Lookup results are collected in a local mapping instead of re-reading
    # the cache later, to protect against potential races in cache updates.
    records: dict[str, CollectionRecord[K] | None] = {
        name: self._get_cached_name(name) for name in names
    }
    fetch_names = [name for name, record in records.items() if record is None]
    # Only go to the database when something is actually missing from the
    # cache; an empty name list would otherwise still cost a query.
    if fetch_names:
        for record in self._fetch_by_name(fetch_names):
            records[record.name] = record
    missing_names = [name for name, record in records.items() if record is None]
    if len(missing_names) == 1:
        raise MissingCollectionError(f"No collection with name '{missing_names[0]}' found.")
    elif len(missing_names) > 1:
        raise MissingCollectionError(f"No collections with names '{' '.join(missing_names)}' found.")
    # All values are known to be non-None at this point; the cast only
    # informs the type checker.
    return [cast("CollectionRecord[K]", records[name]) for name in names]

def __getitem__(self, key: Any) -> CollectionRecord[K]:
    # Docstring inherited from CollectionManager.
    #
    # Fast path: the record is already in the in-memory cache.
    if (record := self._records.get(key)) is not None:
        return record
    # Cache miss: the cache may simply not be populated yet, so ask the
    # database for this single key before declaring the collection missing.
    if records := self._fetch_by_key([key]):
        return records[0]
    raise MissingCollectionError(f"Collection with key '{key}' not found.")

def resolve_wildcard(
self,
Expand All @@ -353,28 +323,33 @@ def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[Collect
yield record
if flatten_chains and record.type is CollectionType.CHAINED:
done.add(record.name)
for name in cast(ChainedCollectionRecord[K], record).children:
for child in self._find_many(cast(ChainedCollectionRecord[K], record).children):
# flake8 can't tell that we only delete this closure when
# we're totally done with it.
yield from resolve_nested(self.find(name), done) # noqa: F821
yield from resolve_nested(child, done) # noqa: F821

result: list[CollectionRecord[K]] = []

# If we have wildcard or ellipsis we need to read everything in memory.
if wildcard.patterns:
self._fetch_all()

if wildcard.patterns is ...:
for record in self._records.values():
result.extend(resolve_nested(record, done))
del resolve_nested
return result
for name in wildcard.strings:
result.extend(resolve_nested(self.find(name), done))
if wildcard.strings:
for record in self._find_many(wildcard.strings):
result.extend(resolve_nested(record, done))
if wildcard.patterns:
for record in self._records.values():
if any(p.fullmatch(record.name) for p in wildcard.patterns):
result.extend(resolve_nested(record, done))
del resolve_nested
return result

def getDocumentation(self, key: Any) -> str | None:
def getDocumentation(self, key: K) -> str | None:
# Docstring inherited from CollectionManager.
sql = (
sqlalchemy.sql.select(self._tables.collection.columns.doc)
Expand All @@ -384,7 +359,7 @@ def getDocumentation(self, key: Any) -> str | None:
with self._db.query(sql) as sql_result:
return sql_result.scalar()

def setDocumentation(self, key: K, doc: str | None) -> None:
    # Docstring inherited from CollectionManager.
    #
    # The second argument maps the collection ID column to the "key" entry
    # of the row dictionary passed as the third argument; the "doc" entry
    # carries the new documentation value.
    self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc})

Expand All @@ -404,12 +379,33 @@ def _removeCachedRecord(self, record: CollectionRecord[K]) -> None:
"""Remove single record from cache."""
del self._records[record.key]

def _getByName(self, name: str) -> CollectionRecord[K] | None:
    """Find collection record given collection name.

    Parameters
    ----------
    name : `str`
        Name of the collection.

    Returns
    -------
    record : `CollectionRecord` or `None`
        Matching record from the cache or, on a cache miss, from the
        database; `None` if no record with this name was fetched.
    """
    if (record := self._get_cached_name(name)) is not None:
        return record
    # Cache miss: query the database and remember whatever comes back so
    # that the next lookup is served from the cache.
    records = self._fetch_by_name([name])
    for record in records:
        self._addCachedRecord(record)
    return records[0] if records else None

@abstractmethod
def _get_cached_name(self, name: str) -> CollectionRecord[K] | None:
    """Look up a collection record by name in the in-memory cache.

    Parameters
    ----------
    name : `str`
        Name of the collection.

    Returns
    -------
    record : `CollectionRecord` or `None`
        The cached record, or `None` if the cache has no record with this
        name.
    """
    raise NotImplementedError()

@abstractmethod
def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
    """Fetch collection records from the database given their names.

    Parameters
    ----------
    names : `~collections.abc.Iterable` [ `str` ]
        Names of the collections to fetch.

    Returns
    -------
    records : `list` [ `CollectionRecord` ]
        Records fetched from the database.
    """
    raise NotImplementedError()

@abstractmethod
def _fetch_by_key(self, collection_ids: Iterable[K] | None) -> list[CollectionRecord[K]]:
    """Fetch collection records from the database given their keys.

    Parameters
    ----------
    collection_ids : `~collections.abc.Iterable` or `None`
        Keys of the collections to fetch; `None` means fetch all
        collections.

    Returns
    -------
    records : `list` [ `CollectionRecord` ]
        Records fetched from the database.
    """
    raise NotImplementedError()

def getParentChains(self, key: Any) -> Iterator[ChainedCollectionRecord[K]]:
def getParentChains(self, key: K) -> Iterator[ChainedCollectionRecord[K]]:
# Docstring inherited from CollectionManager.
table = self._tables.collection_chain
sql = (
Expand All @@ -419,11 +415,13 @@ def getParentChains(self, key: Any) -> Iterator[ChainedCollectionRecord[K]]:
)
with self._db.query(sql) as sql_result:
parent_keys = sql_result.scalars().all()
for key in parent_keys:
# TODO: Just in case cached records miss new parent collections.
# This is temporary, will replace with non-cached records soon.
with contextlib.suppress(KeyError):
yield cast(ChainedCollectionRecord[K], self._records[key])
# TODO: It would be more efficient to write a single query that both
# finds parents and all their children, but for now we do not care
# much about efficiency. Also the only client of this method does not
# need full records, only parent collection names, maybe we should
# change this method to return names instead.
for record in self._fetch_by_key(parent_keys):
yield cast(ChainedCollectionRecord[K], record)

def update_chain(
self, chain: ChainedCollectionRecord[K], children: Iterable[str], flatten: bool = False
Expand Down
115 changes: 108 additions & 7 deletions python/lsst/daf/butler/registry/collections/nameKey.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,17 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

from ... import ddl

__all__ = ["NameKeyCollectionManager"]

from collections.abc import Iterable, Mapping
from typing import TYPE_CHECKING, Any

import sqlalchemy

from ... import ddl
from ..._timespan import TimespanDatabaseRepresentation
from ..interfaces import VersionTuple
from .._collection_type import CollectionType
from ..interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord, VersionTuple
from ._base import (
CollectionTablesTuple,
DefaultCollectionManager,
Expand All @@ -44,7 +45,7 @@
)

if TYPE_CHECKING:
from ..interfaces import CollectionRecord, Database, DimensionRecordStorageManager, StaticTablesContext
from ..interfaces import Database, DimensionRecordStorageManager, StaticTablesContext


_KEY_FIELD_SPEC = ddl.FieldSpec("name", dtype=sqlalchemy.String, length=64, primaryKey=True)
Expand All @@ -68,7 +69,7 @@ def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) ->
)


class NameKeyCollectionManager(DefaultCollectionManager):
class NameKeyCollectionManager(DefaultCollectionManager[str]):
"""A `CollectionManager` implementation that uses collection names for
primary/foreign keys and aggressively loads all collection/run records in
the database into memory.
Expand Down Expand Up @@ -152,10 +153,110 @@ def addRunForeignKey(
)
return copy

def _get_cached_name(self, name: str) -> CollectionRecord[str] | None:
    # Docstring inherited from base class.
    #
    # In this manager the collection name is also the record key, so the
    # key-indexed cache can be queried by name directly.
    return self._records.get(name)

def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[str]]:
    # Docstring inherited from base class.
    #
    # Names double as keys in this manager, so a lookup by name is simply
    # a lookup by key.
    found = self._fetch_by_key(names)
    return found

def _fetch_by_key(self, collection_ids: Iterable[str] | None) -> list[CollectionRecord[str]]:
    # Docstring inherited from base class.
    if collection_ids is not None:
        # Materialize the keys once: the argument may be a one-shot
        # iterator, and the IN clause below needs a concrete sequence.
        collection_ids = list(collection_ids)
        if not collection_ids:
            # Nothing was asked for, so there is no need to query at all.
            return []

    sql = sqlalchemy.sql.select(*self._tables.collection.columns, *self._tables.run.columns).select_from(
        self._tables.collection.join(self._tables.run, isouter=True)
    )

    chain_sql = sqlalchemy.sql.select(
        self._tables.collection_chain.columns["parent"],
        self._tables.collection_chain.columns["position"],
        self._tables.collection_chain.columns["child"],
    )

    records: list[CollectionRecord[str]]
    # We want to keep transactions as short as possible. When we fetch
    # everything we want to quickly fetch things into memory and finish
    # transaction. When we fetch just few records we need to process result
    # of the first query before we can run the second one.
    if collection_ids is not None:
        sql = sql.where(self._tables.collection.columns[self._collectionIdName].in_(collection_ids))
        with self._db.transaction():
            with self._db.query(sql) as sql_result:
                sql_rows = sql_result.mappings().fetchall()

            records, chained_ids = self._rows_to_records(sql_rows)

            if chained_ids:
                # Retrieve chained collection compositions, restricted to
                # the chains that were actually selected above.
                chain_sql = chain_sql.where(
                    self._tables.collection_chain.columns["parent"].in_(chained_ids)
                )
                with self._db.query(chain_sql) as sql_result:
                    chain_rows = sql_result.mappings().fetchall()

                records += self._rows_to_chains(chain_rows, chained_ids)

    else:
        with self._db.transaction():
            with self._db.query(sql) as sql_result:
                sql_rows = sql_result.mappings().fetchall()
            with self._db.query(chain_sql) as sql_result:
                chain_rows = sql_result.mappings().fetchall()

        # Both result sets are in memory now; convert them outside of the
        # transaction.
        records, chained_ids = self._rows_to_records(sql_rows)
        records += self._rows_to_chains(chain_rows, chained_ids)

    return records

def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[str]], list[str]]:
    """Turn collection query rows into records plus pending chain names.

    Parameters
    ----------
    rows : `~collections.abc.Iterable` [ `~collections.abc.Mapping` ]
        Rows returned by the collection/run join query.

    Returns
    -------
    records : `list` [ `CollectionRecord` ]
        Records for all rows that are not CHAINED collections.
    chained_ids : `list` [ `str` ]
        Names of the CHAINED collections found in ``rows``; no records are
        built for them here.
    """
    timespan_cls = self._db.getTimespanRepresentation()
    collection_columns = self._tables.collection.columns
    run_columns = self._tables.run.columns

    built: list[CollectionRecord[str]] = []
    chain_names: list[str] = []
    for row in rows:
        collection_name = row[collection_columns.name]
        collection_type = CollectionType(row["type"])
        if collection_type is CollectionType.CHAINED:
            # Chained collections are constructed later, once their
            # children names have been fetched.
            chain_names.append(collection_name)
            continue
        if collection_type is CollectionType.RUN:
            built.append(
                RunRecord[str](
                    key=collection_name,
                    name=collection_name,
                    host=row[run_columns.host],
                    timespan=timespan_cls.extract(row),
                )
            )
        else:
            built.append(
                CollectionRecord[str](key=collection_name, name=collection_name, type=collection_type)
            )

    return built, chain_names

def _rows_to_chains(self, rows: Iterable[Mapping], chained_ids: list[str]) -> list[CollectionRecord[str]]:
    """Build chained collection records from collection chain query rows.

    Parameters
    ----------
    rows : `~collections.abc.Iterable` [ `~collections.abc.Mapping` ]
        Rows with "parent", "position" and "child" entries.
    chained_ids : `list` [ `str` ]
        Names of the chained collections to build; a chain with no rows
        gets an empty list of children.

    Returns
    -------
    records : `list` [ `CollectionRecord` ]
        One `ChainedCollectionRecord` per entry in ``chained_ids``.
    """
    # Seed every requested chain so that empty chains still get a record.
    positioned_children: dict[str, list[tuple[int, str]]] = {}
    for parent in chained_ids:
        positioned_children[parent] = []
    for row in rows:
        positioned_children[row["parent"]].append((row["position"], row["child"]))

    # Sorting the (position, child) pairs orders children by position.
    return [
        ChainedCollectionRecord[str](
            key=parent,
            name=parent,
            children=[child for _, child in sorted(pairs)],
        )
        for parent, pairs in positioned_children.items()
    ]

@classmethod
def currentVersions(cls) -> list[VersionTuple]:
# Docstring inherited from VersionedExtension.
Expand Down
Loading

0 comments on commit 8481829

Please sign in to comment.