From 28f3e9e277809c2ddc8c95b6813bb1a68a0df8fb Mon Sep 17 00:00:00 2001 From: "David H. Irving" Date: Tue, 10 Dec 2024 15:29:52 -0700 Subject: [PATCH] Fix threadsafety of sqlalchemy MetaData access Wrap all accesses to sqlalchemy.MetaData with a lock to avoid concurrency issues. sqlalchemy.MetaData is documented to be threadsafe for reads, but not with concurrent modifications. We add tables dynamically at runtime, and the MetaData object is shared by all Database instances sharing the same connection pool. Prior to adding the lock, Butler server database calls that added table definitions dynamically were sometimes failing with InvalidRequestError exceptions complaining about inconsistency of table definitions. --- .../butler/registry/databases/postgresql.py | 4 +- .../daf/butler/registry/databases/sqlite.py | 4 +- .../butler/registry/interfaces/_database.py | 193 +++++++++++++----- tests/test_sqlite.py | 2 +- 4 files changed, 143 insertions(+), 60 deletions(-) diff --git a/python/lsst/daf/butler/registry/databases/postgresql.py b/python/lsst/daf/butler/registry/databases/postgresql.py index 30702bfee9..cc60bac5b6 100644 --- a/python/lsst/daf/butler/registry/databases/postgresql.py +++ b/python/lsst/daf/butler/registry/databases/postgresql.py @@ -45,7 +45,7 @@ from ..._named import NamedValueAbstractSet from ..._timespan import Timespan from ...timespan_database_representation import TimespanDatabaseRepresentation -from ..interfaces import Database +from ..interfaces import Database, DatabaseMetadata class PostgresqlDatabase(Database): @@ -124,7 +124,7 @@ def _init( namespace: str | None = None, writeable: bool = True, dbname: str, - metadata: sqlalchemy.schema.MetaData | None, + metadata: DatabaseMetadata | None, pg_version: tuple[int, int], ) -> None: # Initialization logic shared between ``__init__`` and ``clone``. diff --git a/python/lsst/daf/butler/registry/databases/sqlite.py b/python/lsst/daf/butler/registry/databases/sqlite.py index 97450d05c8..fd3f2a4e0e 100644 --- a/python/lsst/daf/butler/registry/databases/sqlite.py +++ b/python/lsst/daf/butler/registry/databases/sqlite.py @@ -42,7 +42,7 @@ from ... import ddl from ..._named import NamedValueAbstractSet -from ..interfaces import Database, StaticTablesContext +from ..interfaces import Database, DatabaseMetadata, StaticTablesContext def _onSqlite3Connect( @@ -109,7 +109,7 @@ def _init( namespace: str | None = None, writeable: bool = True, filename: str | None, - metadata: sqlalchemy.schema.MetaData | None, + metadata: DatabaseMetadata | None, ) -> None: # Initialization logic shared between ``__init__`` and ``clone``. super().__init__(origin=origin, engine=engine, namespace=namespace, metadata=metadata) diff --git a/python/lsst/daf/butler/registry/interfaces/_database.py b/python/lsst/daf/butler/registry/interfaces/_database.py index 0e69a14cbb..059ad693d0 100644 --- a/python/lsst/daf/butler/registry/interfaces/_database.py +++ b/python/lsst/daf/butler/registry/interfaces/_database.py @@ -30,6 +30,7 @@ __all__ = [ "Database", + "DatabaseMetadata", "ReadOnlyDatabaseError", "DatabaseConflictError", "DatabaseInsertMode", @@ -46,6 +47,7 @@ from collections import defaultdict from collections.abc import Callable, Iterable, Iterator, Sequence from contextlib import contextmanager +from threading import Lock from typing import Any, cast, final import astropy.time @@ -136,7 +138,6 @@ class StaticTablesContext: def __init__(self, db: Database, connection: sqlalchemy.engine.Connection): self._db = db - self._foreignKeys: list[tuple[sqlalchemy.schema.Table, sqlalchemy.schema.ForeignKeyConstraint]] = [] self._inspector = sqlalchemy.inspect(connection) self._tableNames = frozenset(self._inspector.get_table_names(schema=self._db.namespace)) self._initializers: list[Callable[[Database], None]] = [] @@ -164,13 +165,9 @@ def addTable(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema.Table: to be declared in any order even in the presence of foreign key relationships. """ - name = self._db._mangleTableName(name) metadata = self._db._metadata assert metadata is not None, "Guaranteed by context manager that returns this object." - table = self._db._convertTableSpec(name, spec, metadata) - for foreignKeySpec in spec.foreignKeys: - self._foreignKeys.append((table, self._db._convertForeignKeySpec(name, foreignKeySpec, metadata))) - return table + return metadata.add_table(self._db, name, spec) def addTableTuple(self, specs: tuple[ddl.TableSpec, ...]) -> tuple[sqlalchemy.schema.Table, ...]: """Add a named tuple of tables to the schema, returning their @@ -273,15 +270,15 @@ def __init__( origin: int, engine: sqlalchemy.engine.Engine, namespace: str | None = None, - metadata: sqlalchemy.schema.MetaData | None = None, + metadata: DatabaseMetadata | None = None, ): self.origin = origin self.name_shrinker = NameShrinker(engine.dialect.max_identifier_length) self.namespace = namespace self._engine = engine self._session_connection: sqlalchemy.engine.Connection | None = None - self._metadata = metadata self._temp_tables: set[str] = set() + self._metadata = metadata def __repr__(self) -> str: # Rather than try to reproduce all the parameters used to create @@ -540,6 +537,7 @@ def temporary_table( otherwise, but in that case they probably need to be modified to support the full range of expected read-only butler behavior. """ + assert self._metadata is not None, "Static tables must be created before temporary tables" with self._session() as connection: table = self._make_temporary_table(connection, spec=spec, name=name) self._temp_tables.add(table.key) @@ -549,6 +547,7 @@ def temporary_table( with self._transaction(): table.drop(connection) self._temp_tables.remove(table.key) + self._metadata.remove_table(table.name) @contextmanager def _session(self) -> Iterator[sqlalchemy.engine.Connection]: @@ -760,39 +759,30 @@ def declareStaticTables(self, *, create: bool) -> Iterator[StaticTablesContext]: """ if create and not self.isWriteable(): raise ReadOnlyDatabaseError(f"Cannot create tables in read-only database {self}.") - self._metadata = sqlalchemy.MetaData(schema=self.namespace) - try: - with self._transaction() as (_, connection): - context = StaticTablesContext(self, connection) - if create and context._tableNames: - # Looks like database is already initalized, to avoid - # danger of modifying/destroying valid schema we refuse to - # do anything in this case - raise SchemaAlreadyDefinedError(f"Cannot create tables in non-empty database {self}.") - yield context - for table, foreignKey in context._foreignKeys: - table.append_constraint(foreignKey) - if create: - if ( - self.namespace is not None - and self.namespace not in context._inspector.get_schema_names() - ): - connection.execute(sqlalchemy.schema.CreateSchema(self.namespace)) - # In our tables we have columns that make use of sqlalchemy - # Sequence objects. There is currently a bug in sqlalchemy - # that causes a deprecation warning to be thrown on a - # property of the Sequence object when the repr for the - # sequence is created. Here a filter is used to catch these - # deprecation warnings when tables are created. - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=sqlalchemy.exc.SADeprecationWarning) - self._metadata.create_all(connection) - # call all initializer methods sequentially - for init in context._initializers: - init(self) - except BaseException: - self._metadata = None - raise + self._metadata = DatabaseMetadata(self.namespace) + with self._transaction() as (_, connection): + context = StaticTablesContext(self, connection) + if create and context._tableNames: + # Looks like database is already initalized, to avoid + # danger of modifying/destroying valid schema we refuse to + # do anything in this case + raise SchemaAlreadyDefinedError(f"Cannot create tables in non-empty database {self}.") + yield context + if create: + if self.namespace is not None and self.namespace not in context._inspector.get_schema_names(): + connection.execute(sqlalchemy.schema.CreateSchema(self.namespace)) + # In our tables we have columns that make use of sqlalchemy + # Sequence objects. There is currently a bug in sqlalchemy + # that causes a deprecation warning to be thrown on a + # property of the Sequence object when the repr for the + # sequence is created. Here a filter is used to catch these + # deprecation warnings when tables are created. + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=sqlalchemy.exc.SADeprecationWarning) + self._metadata.create_all(connection) + # call all initializer methods sequentially + for init in context._initializers: + init(self) @abstractmethod def isWriteable(self) -> bool: @@ -1141,9 +1131,8 @@ def ensureTableExists(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema raise ReadOnlyDatabaseError( f"Table {name} does not exist, and cannot be created because database {self} is read-only." ) - table = self._convertTableSpec(name, spec, self._metadata) - for foreignKeySpec in spec.foreignKeys: - table.append_constraint(self._convertForeignKeySpec(name, foreignKeySpec, self._metadata)) + + table = self._metadata.add_table(self, name, spec) try: with self._transaction() as (_, connection): table.create(connection) @@ -1192,7 +1181,7 @@ def getExistingTable(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema. """ assert self._metadata is not None, "Static tables must be declared before dynamic tables." name = self._mangleTableName(name) - table = self._metadata.tables.get(name if self.namespace is None else f"{self.namespace}.{name}") + table = self._metadata.get_table(name) if table is not None: if spec.fields.names != set(table.columns.keys()): raise DatabaseConflictError( @@ -1206,10 +1195,7 @@ def getExistingTable(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema. ) if name in inspector.get_table_names(schema=self.namespace): _checkExistingTableDefinition(name, spec, inspector.get_columns(name, schema=self.namespace)) - table = self._convertTableSpec(name, spec, self._metadata) - for foreignKeySpec in spec.foreignKeys: - table.append_constraint(self._convertForeignKeySpec(name, foreignKeySpec, self._metadata)) - return table + return self._metadata.add_table(self, name, spec) return table def _make_temporary_table( @@ -1244,19 +1230,16 @@ def _make_temporary_table( """ if name is None: name = f"tmp_{uuid.uuid4().hex}" - metadata = self._metadata - if metadata is None: + if self._metadata is None: raise RuntimeError("Cannot create temporary table before static schema is defined.") - table = self._convertTableSpec( - name, spec, metadata, prefixes=["TEMPORARY"], schema=sqlalchemy.schema.BLANK_SCHEMA, **kwargs + table = self._metadata.add_table( + self, name, spec, prefixes=["TEMPORARY"], schema=sqlalchemy.schema.BLANK_SCHEMA, **kwargs ) if table.key in self._temp_tables and table.key != name: raise ValueError( f"A temporary table with name {name} (transformed to {table.key} by " "Database) already exists." ) - for foreignKeySpec in spec.foreignKeys: - table.append_constraint(self._convertForeignKeySpec(name, foreignKeySpec, metadata)) with self._transaction(): table.create(connection) return table @@ -2010,3 +1993,103 @@ def apply_any_aggregate(self, column: sqlalchemy.ColumnElement[Any]) -> sqlalche """An object that can be used to shrink field names to fit within the identifier limit of the database engine (`NameShrinker`). """ + + +class DatabaseMetadata: + """Wrapper around SqlAlchemy MetaData object to ensure threadsafety. + + Parameters + ---------- + namespace : `str` or `None` + Name of the schema or namespace this instance is associated with. + + Notes + ----- + `sqlalchemy.MetaData` is documented to be threadsafe for reads, but not + with concurrent modifications. We add tables dynamically at runtime, + and the MetaData object is shared by all Database instances sharing + the same connection pool. + """ + + def __init__(self, namespace: str | None) -> None: + self._lock = Lock() + self._metadata = sqlalchemy.MetaData(schema=namespace) + self._tables: dict[str, sqlalchemy.Table] = {} + + def add_table( + self, db: Database, name: str, spec: ddl.TableSpec, **kwargs: Any + ) -> sqlalchemy.schema.Table: + """Add a new table to the MetaData object, returning its sqlalchemy + representation. This does not physically create the table in the + database -- it only sets up its definition. + + Parameters + ---------- + db : `Database` + Database connection associated with the table definition. + name : `str` + The name of the table. + spec : `ddl.TableSpec` + The specification of the table. + **kwargs + Additional keyword arguments to forward to the + `sqlalchemy.schema.Table` constructor. + + Returns + ------- + table : `sqlalchemy.schema.Table` + The created table. + """ + with self._lock: + if (table := self._tables.get(name)) is not None: + return table + + table = db._convertTableSpec(name, spec, self._metadata, **kwargs) + for foreignKeySpec in spec.foreignKeys: + table.append_constraint(db._convertForeignKeySpec(name, foreignKeySpec, self._metadata)) + + self._tables[name] = table + return table + + def get_table(self, name: str) -> sqlalchemy.schema.Table | None: + """Return the definition of a table that was previously added to this + MetaData object. + + Parameters + ---------- + name : `str` + Name of the table. + + Returns + ------- + table : `sqlalchemy.schema.Table` or `None` + The table definition, or `None` if the table is not known to this + MetaData instance. + """ + with self._lock: + return self._tables.get(name) + + def remove_table(self, name: str) -> None: + """Remove a table that was previously added to this MetaData object. + + Parameters + ---------- + name : `str` + Name of the table. + """ + with self._lock: + table = self._tables.pop(name, None) + if table is not None: + self._metadata.remove(table) + + def create_all(self, connection: sqlalchemy.engine.Connection) -> None: + """Create all tables known to this MetaData object in the database. + Same as `sqlalchemy.MetaData.create_all`. + + Parameters + ---------- + connection : `sqlalchemy.engine.connection` + Database connection that will be used to create tables. + """ + with self._lock: + self._metadata.create_all(connection) diff --git a/tests/test_sqlite.py b/tests/test_sqlite.py index 42b44eb397..2a6385f83f 100644 --- a/tests/test_sqlite.py +++ b/tests/test_sqlite.py @@ -67,7 +67,7 @@ def isEmptyDatabaseActuallyWriteable(database: SqliteDatabase) -> bool: "a", ddl.TableSpec(fields=[ddl.FieldSpec("b", dtype=sqlalchemy.Integer, primaryKey=True)]) ) # Drop created table so that schema remains empty. - database._metadata.drop_all(database._engine, tables=[table]) + database._metadata._metadata.drop_all(database._engine, tables=[table]) return True except Exception: return False