diff --git a/python/lsst/daf/butler/registry/databases/postgresql.py b/python/lsst/daf/butler/registry/databases/postgresql.py index 30702bfee9..fcd542af30 100644 --- a/python/lsst/daf/butler/registry/databases/postgresql.py +++ b/python/lsst/daf/butler/registry/databases/postgresql.py @@ -44,6 +44,7 @@ from ..._named import NamedValueAbstractSet from ..._timespan import Timespan +from ..._utilities.locked_object import LockedObject from ...timespan_database_representation import TimespanDatabaseRepresentation from ..interfaces import Database @@ -124,7 +125,7 @@ def _init( namespace: str | None = None, writeable: bool = True, dbname: str, - metadata: sqlalchemy.schema.MetaData | None, + metadata: LockedObject[sqlalchemy.schema.MetaData] | None, pg_version: tuple[int, int], ) -> None: # Initialization logic shared between ``__init__`` and ``clone``. diff --git a/python/lsst/daf/butler/registry/databases/sqlite.py b/python/lsst/daf/butler/registry/databases/sqlite.py index 97450d05c8..82e268632d 100644 --- a/python/lsst/daf/butler/registry/databases/sqlite.py +++ b/python/lsst/daf/butler/registry/databases/sqlite.py @@ -42,6 +42,7 @@ from ... import ddl from ..._named import NamedValueAbstractSet +from ..._utilities.locked_object import LockedObject from ..interfaces import Database, StaticTablesContext @@ -109,7 +110,7 @@ def _init( namespace: str | None = None, writeable: bool = True, filename: str | None, - metadata: sqlalchemy.schema.MetaData | None, + metadata: LockedObject[sqlalchemy.schema.MetaData] | None, ) -> None: # Initialization logic shared between ``__init__`` and ``clone``. super().__init__(origin=origin, engine=engine, namespace=namespace, metadata=metadata) diff --git a/python/lsst/daf/butler/registry/interfaces/_database.py b/python/lsst/daf/butler/registry/interfaces/_database.py index 0e69a14cbb..d5ed7b4f95 100644 --- a/python/lsst/daf/butler/registry/interfaces/_database.py +++ b/python/lsst/daf/butler/registry/interfaces/_database.py @@ -52,6 +52,7 @@ import sqlalchemy from ..._named import NamedValueAbstractSet +from ..._utilities.locked_object import LockedObject from ...name_shrinker import NameShrinker from ...timespan_database_representation import TimespanDatabaseRepresentation from .._exceptions import ConflictingDefinitionError @@ -165,12 +166,15 @@ def addTable(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema.Table: relationships. """ name = self._db._mangleTableName(name) - metadata = self._db._metadata - assert metadata is not None, "Guaranteed by context manager that returns this object." - table = self._db._convertTableSpec(name, spec, metadata) - for foreignKeySpec in spec.foreignKeys: - self._foreignKeys.append((table, self._db._convertForeignKeySpec(name, foreignKeySpec, metadata))) - return table + metadata_wrapper = self._db._metadata + assert metadata_wrapper is not None, "Guaranteed by context manager that returns this object." + with metadata_wrapper.access() as metadata: + table = self._db._convertTableSpec(name, spec, metadata) + for foreignKeySpec in spec.foreignKeys: + self._foreignKeys.append( + (table, self._db._convertForeignKeySpec(name, foreignKeySpec, metadata)) + ) + return table def addTableTuple(self, specs: tuple[ddl.TableSpec, ...]) -> tuple[sqlalchemy.schema.Table, ...]: """Add a named tuple of tables to the schema, returning their @@ -273,7 +277,7 @@ def __init__( origin: int, engine: sqlalchemy.engine.Engine, namespace: str | None = None, - metadata: sqlalchemy.schema.MetaData | None = None, + metadata: LockedObject[sqlalchemy.schema.MetaData] | None = None, ): self.origin = origin self.name_shrinker = NameShrinker(engine.dialect.max_identifier_length) @@ -760,7 +764,12 @@ def declareStaticTables(self, *, create: bool) -> Iterator[StaticTablesContext]: """ if create and not self.isWriteable(): raise ReadOnlyDatabaseError(f"Cannot create tables in read-only database {self}.") - self._metadata = sqlalchemy.MetaData(schema=self.namespace) + # sqlalchemy.MetaData is documented to be threadsafe for reads, but not + # with concurrent modifications. We add tables dynamically at runtime, + # and the MetaData object is shared by all Database instances sharing + # the same connection pool. So wrap all accesses to this object with a + # lock to avoid concurrency issues. + self._metadata = LockedObject(sqlalchemy.MetaData(schema=self.namespace)) try: with self._transaction() as (_, connection): context = StaticTablesContext(self, connection) @@ -786,7 +795,8 @@ def declareStaticTables(self, *, create: bool) -> Iterator[StaticTablesContext]: # deprecation warnings when tables are created. with warnings.catch_warnings(): warnings.simplefilter("ignore", category=sqlalchemy.exc.SADeprecationWarning) - self._metadata.create_all(connection) + with self._metadata.access() as metadata: + metadata.create_all(connection) # call all initializer methods sequentially for init in context._initializers: init(self) @@ -1141,9 +1151,10 @@ def ensureTableExists(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema raise ReadOnlyDatabaseError( f"Table {name} does not exist, and cannot be created because database {self} is read-only." ) - table = self._convertTableSpec(name, spec, self._metadata) - for foreignKeySpec in spec.foreignKeys: - table.append_constraint(self._convertForeignKeySpec(name, foreignKeySpec, self._metadata)) + with self._metadata.access() as metadata: + table = self._convertTableSpec(name, spec, metadata) + for foreignKeySpec in spec.foreignKeys: + table.append_constraint(self._convertForeignKeySpec(name, foreignKeySpec, metadata)) try: with self._transaction() as (_, connection): table.create(connection) @@ -1192,7 +1203,8 @@ def getExistingTable(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema. """ assert self._metadata is not None, "Static tables must be declared before dynamic tables." name = self._mangleTableName(name) - table = self._metadata.tables.get(name if self.namespace is None else f"{self.namespace}.{name}") + with self._metadata.access() as metadata: + table = metadata.tables.get(name if self.namespace is None else f"{self.namespace}.{name}") if table is not None: if spec.fields.names != set(table.columns.keys()): raise DatabaseConflictError( @@ -1206,9 +1218,10 @@ def getExistingTable(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema. ) if name in inspector.get_table_names(schema=self.namespace): _checkExistingTableDefinition(name, spec, inspector.get_columns(name, schema=self.namespace)) - table = self._convertTableSpec(name, spec, self._metadata) - for foreignKeySpec in spec.foreignKeys: - table.append_constraint(self._convertForeignKeySpec(name, foreignKeySpec, self._metadata)) + with self._metadata.access() as metadata: + table = self._convertTableSpec(name, spec, metadata) + for foreignKeySpec in spec.foreignKeys: + table.append_constraint(self._convertForeignKeySpec(name, foreignKeySpec, metadata)) return table return table @@ -1244,19 +1257,19 @@ def _make_temporary_table( """ if name is None: name = f"tmp_{uuid.uuid4().hex}" - metadata = self._metadata - if metadata is None: + if self._metadata is None: raise RuntimeError("Cannot create temporary table before static schema is defined.") - table = self._convertTableSpec( - name, spec, metadata, prefixes=["TEMPORARY"], schema=sqlalchemy.schema.BLANK_SCHEMA, **kwargs - ) - if table.key in self._temp_tables and table.key != name: - raise ValueError( - f"A temporary table with name {name} (transformed to {table.key} by " - "Database) already exists." + with self._metadata.access() as metadata: + table = self._convertTableSpec( + name, spec, metadata, prefixes=["TEMPORARY"], schema=sqlalchemy.schema.BLANK_SCHEMA, **kwargs ) - for foreignKeySpec in spec.foreignKeys: - table.append_constraint(self._convertForeignKeySpec(name, foreignKeySpec, metadata)) + if table.key in self._temp_tables and table.key != name: + raise ValueError( + f"A temporary table with name {name} (transformed to {table.key} by " + "Database) already exists." + ) + for foreignKeySpec in spec.foreignKeys: + table.append_constraint(self._convertForeignKeySpec(name, foreignKeySpec, metadata)) with self._transaction(): table.create(connection) return table