diff --git a/python/lsst/daf/butler/registry/databases/postgresql.py b/python/lsst/daf/butler/registry/databases/postgresql.py index 30702bfee9..cc60bac5b6 100644 --- a/python/lsst/daf/butler/registry/databases/postgresql.py +++ b/python/lsst/daf/butler/registry/databases/postgresql.py @@ -45,7 +45,7 @@ from ..._named import NamedValueAbstractSet from ..._timespan import Timespan from ...timespan_database_representation import TimespanDatabaseRepresentation -from ..interfaces import Database +from ..interfaces import Database, DatabaseMetadata class PostgresqlDatabase(Database): @@ -124,7 +124,7 @@ def _init( namespace: str | None = None, writeable: bool = True, dbname: str, - metadata: sqlalchemy.schema.MetaData | None, + metadata: DatabaseMetadata | None, pg_version: tuple[int, int], ) -> None: # Initialization logic shared between ``__init__`` and ``clone``. diff --git a/python/lsst/daf/butler/registry/databases/sqlite.py b/python/lsst/daf/butler/registry/databases/sqlite.py index 97450d05c8..fd3f2a4e0e 100644 --- a/python/lsst/daf/butler/registry/databases/sqlite.py +++ b/python/lsst/daf/butler/registry/databases/sqlite.py @@ -42,7 +42,7 @@ from ... import ddl from ..._named import NamedValueAbstractSet -from ..interfaces import Database, StaticTablesContext +from ..interfaces import Database, DatabaseMetadata, StaticTablesContext def _onSqlite3Connect( @@ -109,7 +109,7 @@ def _init( namespace: str | None = None, writeable: bool = True, filename: str | None, - metadata: sqlalchemy.schema.MetaData | None, + metadata: DatabaseMetadata | None, ) -> None: # Initialization logic shared between ``__init__`` and ``clone``. super().__init__(origin=origin, engine=engine, namespace=namespace, metadata=metadata) diff --git a/python/lsst/daf/butler/registry/interfaces/_database.py b/python/lsst/daf/butler/registry/interfaces/_database.py index 0e69a14cbb..059ad693d0 100644 --- a/python/lsst/daf/butler/registry/interfaces/_database.py +++ b/python/lsst/daf/butler/registry/interfaces/_database.py @@ -30,6 +30,7 @@ __all__ = [ "Database", + "DatabaseMetadata", "ReadOnlyDatabaseError", "DatabaseConflictError", "DatabaseInsertMode", @@ -46,6 +47,7 @@ from collections import defaultdict from collections.abc import Callable, Iterable, Iterator, Sequence from contextlib import contextmanager +from threading import Lock from typing import Any, cast, final import astropy.time @@ -136,7 +138,6 @@ class StaticTablesContext: def __init__(self, db: Database, connection: sqlalchemy.engine.Connection): self._db = db - self._foreignKeys: list[tuple[sqlalchemy.schema.Table, sqlalchemy.schema.ForeignKeyConstraint]] = [] self._inspector = sqlalchemy.inspect(connection) self._tableNames = frozenset(self._inspector.get_table_names(schema=self._db.namespace)) self._initializers: list[Callable[[Database], None]] = [] @@ -164,13 +165,9 @@ def addTable(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema.Table: to be declared in any order even in the presence of foreign key relationships. """ - name = self._db._mangleTableName(name) metadata = self._db._metadata assert metadata is not None, "Guaranteed by context manager that returns this object." - table = self._db._convertTableSpec(name, spec, metadata) - for foreignKeySpec in spec.foreignKeys: - self._foreignKeys.append((table, self._db._convertForeignKeySpec(name, foreignKeySpec, metadata))) - return table + return metadata.add_table(self._db, name, spec) def addTableTuple(self, specs: tuple[ddl.TableSpec, ...]) -> tuple[sqlalchemy.schema.Table, ...]: """Add a named tuple of tables to the schema, returning their @@ -273,15 +270,15 @@ def __init__( origin: int, engine: sqlalchemy.engine.Engine, namespace: str | None = None, - metadata: sqlalchemy.schema.MetaData | None = None, + metadata: DatabaseMetadata | None = None, ): self.origin = origin self.name_shrinker = NameShrinker(engine.dialect.max_identifier_length) self.namespace = namespace self._engine = engine self._session_connection: sqlalchemy.engine.Connection | None = None - self._metadata = metadata self._temp_tables: set[str] = set() + self._metadata = metadata def __repr__(self) -> str: # Rather than try to reproduce all the parameters used to create @@ -540,6 +537,7 @@ def temporary_table( otherwise, but in that case they probably need to be modified to support the full range of expected read-only butler behavior. """ + assert self._metadata is not None, "Static tables must be created before temporary tables" with self._session() as connection: table = self._make_temporary_table(connection, spec=spec, name=name) self._temp_tables.add(table.key) @@ -549,6 +547,7 @@ def temporary_table( with self._transaction(): table.drop(connection) self._temp_tables.remove(table.key) + self._metadata.remove_table(table.name) @contextmanager def _session(self) -> Iterator[sqlalchemy.engine.Connection]: @@ -760,39 +759,30 @@ def declareStaticTables(self, *, create: bool) -> Iterator[StaticTablesContext]: """ if create and not self.isWriteable(): raise ReadOnlyDatabaseError(f"Cannot create tables in read-only database {self}.") - self._metadata = sqlalchemy.MetaData(schema=self.namespace) - try: - with self._transaction() as (_, connection): - context = StaticTablesContext(self, connection) - if create and context._tableNames: - # Looks like database is already initalized, to avoid - # danger of modifying/destroying valid schema we refuse to - # do anything in this case - raise SchemaAlreadyDefinedError(f"Cannot create tables in non-empty database {self}.") - yield context - for table, foreignKey in context._foreignKeys: - table.append_constraint(foreignKey) - if create: - if ( - self.namespace is not None - and self.namespace not in context._inspector.get_schema_names() - ): - connection.execute(sqlalchemy.schema.CreateSchema(self.namespace)) - # In our tables we have columns that make use of sqlalchemy - # Sequence objects. There is currently a bug in sqlalchemy - # that causes a deprecation warning to be thrown on a - # property of the Sequence object when the repr for the - # sequence is created. Here a filter is used to catch these - # deprecation warnings when tables are created. - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=sqlalchemy.exc.SADeprecationWarning) - self._metadata.create_all(connection) - # call all initializer methods sequentially - for init in context._initializers: - init(self) - except BaseException: - self._metadata = None - raise + self._metadata = DatabaseMetadata(self.namespace) + with self._transaction() as (_, connection): + context = StaticTablesContext(self, connection) + if create and context._tableNames: + # Looks like database is already initalized, to avoid + # danger of modifying/destroying valid schema we refuse to + # do anything in this case + raise SchemaAlreadyDefinedError(f"Cannot create tables in non-empty database {self}.") + yield context + if create: + if self.namespace is not None and self.namespace not in context._inspector.get_schema_names(): + connection.execute(sqlalchemy.schema.CreateSchema(self.namespace)) + # In our tables we have columns that make use of sqlalchemy + # Sequence objects. There is currently a bug in sqlalchemy + # that causes a deprecation warning to be thrown on a + # property of the Sequence object when the repr for the + # sequence is created. Here a filter is used to catch these + # deprecation warnings when tables are created. + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=sqlalchemy.exc.SADeprecationWarning) + self._metadata.create_all(connection) + # call all initializer methods sequentially + for init in context._initializers: + init(self) @abstractmethod def isWriteable(self) -> bool: @@ -1141,9 +1131,8 @@ def ensureTableExists(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema raise ReadOnlyDatabaseError( f"Table {name} does not exist, and cannot be created because database {self} is read-only." ) - table = self._convertTableSpec(name, spec, self._metadata) - for foreignKeySpec in spec.foreignKeys: - table.append_constraint(self._convertForeignKeySpec(name, foreignKeySpec, self._metadata)) + + table = self._metadata.add_table(self, name, spec) try: with self._transaction() as (_, connection): table.create(connection) @@ -1192,7 +1181,7 @@ def getExistingTable(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema. """ assert self._metadata is not None, "Static tables must be declared before dynamic tables." name = self._mangleTableName(name) - table = self._metadata.tables.get(name if self.namespace is None else f"{self.namespace}.{name}") + table = self._metadata.get_table(name) if table is not None: if spec.fields.names != set(table.columns.keys()): raise DatabaseConflictError( @@ -1206,10 +1195,7 @@ def getExistingTable(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema. ) if name in inspector.get_table_names(schema=self.namespace): _checkExistingTableDefinition(name, spec, inspector.get_columns(name, schema=self.namespace)) - table = self._convertTableSpec(name, spec, self._metadata) - for foreignKeySpec in spec.foreignKeys: - table.append_constraint(self._convertForeignKeySpec(name, foreignKeySpec, self._metadata)) - return table + return self._metadata.add_table(self, name, spec) return table def _make_temporary_table( @@ -1244,19 +1230,16 @@ def _make_temporary_table( """ if name is None: name = f"tmp_{uuid.uuid4().hex}" - metadata = self._metadata - if metadata is None: + if self._metadata is None: raise RuntimeError("Cannot create temporary table before static schema is defined.") - table = self._convertTableSpec( - name, spec, metadata, prefixes=["TEMPORARY"], schema=sqlalchemy.schema.BLANK_SCHEMA, **kwargs + table = self._metadata.add_table( + self, name, spec, prefixes=["TEMPORARY"], schema=sqlalchemy.schema.BLANK_SCHEMA, **kwargs ) if table.key in self._temp_tables and table.key != name: raise ValueError( f"A temporary table with name {name} (transformed to {table.key} by " "Database) already exists." ) - for foreignKeySpec in spec.foreignKeys: - table.append_constraint(self._convertForeignKeySpec(name, foreignKeySpec, metadata)) with self._transaction(): table.create(connection) return table @@ -2010,3 +1993,103 @@ def apply_any_aggregate(self, column: sqlalchemy.ColumnElement[Any]) -> sqlalche """An object that can be used to shrink field names to fit within the identifier limit of the database engine (`NameShrinker`). """ + + +class DatabaseMetadata: + """Wrapper around SqlAlchemy MetaData object to ensure threadsafety. + + Parameters + ---------- + namespace : `str` or `None` + Name of the schema or namespace this instance is associated with. + + Notes + ----- + `sqlalchemy.MetaData` is documented to be threadsafe for reads, but not + with concurrent modifications. We add tables dynamically at runtime, + and the MetaData object is shared by all Database instances sharing + the same connection pool. + """ + + def __init__(self, namespace: str | None) -> None: + self._lock = Lock() + self._metadata = sqlalchemy.MetaData(schema=namespace) + self._tables: dict[str, sqlalchemy.Table] = {} + + def add_table( + self, db: Database, name: str, spec: ddl.TableSpec, **kwargs: Any + ) -> sqlalchemy.schema.Table: + """Add a new table to the MetaData object, returning its sqlalchemy + representation. This does not physically create the table in the + database -- it only sets up its definition. + + Parameters + ---------- + db : `Database` + Database connection associated with the table definition. + name : `str` + The name of the table. + spec : `ddl.TableSpec` + The specification of the table. + **kwargs + Additional keyword arguments to forward to the + `sqlalchemy.schema.Table` constructor. + + Returns + ------- + table : `sqlalchemy.schema.Table` + The created table. + """ + with self._lock: + if (table := self._tables.get(name)) is not None: + return table + + table = db._convertTableSpec(name, spec, self._metadata, **kwargs) + for foreignKeySpec in spec.foreignKeys: + table.append_constraint(db._convertForeignKeySpec(name, foreignKeySpec, self._metadata)) + + self._tables[name] = table + return table + + def get_table(self, name: str) -> sqlalchemy.schema.Table | None: + """Return the definition of a table that was previously added to this + MetaData object. + + Parameters + ---------- + name : `str` + Name of the table. + + Returns + ------- + table : `sqlalchemy.schema.Table` or `None` + The table definition, or `None` if the table is not known to this + MetaData instance. + """ + with self._lock: + return self._tables.get(name) + + def remove_table(self, name: str) -> None: + """Remove a table that was previously added to this MetaData object. + + Parameters + ---------- + name : `str` + Name of the table. + """ + with self._lock: + table = self._tables.pop(name, None) + if table is not None: + self._metadata.remove(table) + + def create_all(self, connection: sqlalchemy.engine.Connection) -> None: + """Create all tables known to this MetaData object in the database. + Same as `sqlalchemy.MetaData.create_all`. + + Parameters + ---------- + connection : `sqlalchemy.engine.connection` + Database connection that will be used to create tables. + """ + with self._lock: + self._metadata.create_all(connection) diff --git a/tests/test_sqlite.py b/tests/test_sqlite.py index 42b44eb397..2a6385f83f 100644 --- a/tests/test_sqlite.py +++ b/tests/test_sqlite.py @@ -67,7 +67,7 @@ def isEmptyDatabaseActuallyWriteable(database: SqliteDatabase) -> bool: "a", ddl.TableSpec(fields=[ddl.FieldSpec("b", dtype=sqlalchemy.Integer, primaryKey=True)]) ) # Drop created table so that schema remains empty. - database._metadata.drop_all(database._engine, tables=[table]) + database._metadata._metadata.drop_all(database._engine, tables=[table]) return True except Exception: return False