Skip to content

Commit

Permalink
Fix threadsafety of sqlalchemy MetaData access
Browse files Browse the repository at this point in the history
Wrap all accesses to sqlalchemy.MetaData with a lock to avoid concurrency issues.

sqlalchemy.MetaData is documented to be threadsafe for reads, but not with concurrent modifications.  We add tables dynamically at runtime, and the MetaData object is shared by all Database instances sharing the same connection pool.

Prior to adding the lock, Butler server database calls that added table definitions dynamically were sometimes failing with InvalidRequestError exceptions complaining about inconsistency of table definitions.
  • Loading branch information
dhirving committed Jan 2, 2025
1 parent b3f240b commit 28f3e9e
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 60 deletions.
4 changes: 2 additions & 2 deletions python/lsst/daf/butler/registry/databases/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from ..._named import NamedValueAbstractSet
from ..._timespan import Timespan
from ...timespan_database_representation import TimespanDatabaseRepresentation
from ..interfaces import Database
from ..interfaces import Database, DatabaseMetadata


class PostgresqlDatabase(Database):
Expand Down Expand Up @@ -124,7 +124,7 @@ def _init(
namespace: str | None = None,
writeable: bool = True,
dbname: str,
metadata: sqlalchemy.schema.MetaData | None,
metadata: DatabaseMetadata | None,
pg_version: tuple[int, int],
) -> None:
# Initialization logic shared between ``__init__`` and ``clone``.
Expand Down
4 changes: 2 additions & 2 deletions python/lsst/daf/butler/registry/databases/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@

from ... import ddl
from ..._named import NamedValueAbstractSet
from ..interfaces import Database, StaticTablesContext
from ..interfaces import Database, DatabaseMetadata, StaticTablesContext


def _onSqlite3Connect(
Expand Down Expand Up @@ -109,7 +109,7 @@ def _init(
namespace: str | None = None,
writeable: bool = True,
filename: str | None,
metadata: sqlalchemy.schema.MetaData | None,
metadata: DatabaseMetadata | None,
) -> None:
# Initialization logic shared between ``__init__`` and ``clone``.
super().__init__(origin=origin, engine=engine, namespace=namespace, metadata=metadata)
Expand Down
193 changes: 138 additions & 55 deletions python/lsst/daf/butler/registry/interfaces/_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

__all__ = [
"Database",
"DatabaseMetadata",
"ReadOnlyDatabaseError",
"DatabaseConflictError",
"DatabaseInsertMode",
Expand All @@ -46,6 +47,7 @@
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, Sequence
from contextlib import contextmanager
from threading import Lock
from typing import Any, cast, final

import astropy.time
Expand Down Expand Up @@ -136,7 +138,6 @@ class StaticTablesContext:

def __init__(self, db: Database, connection: sqlalchemy.engine.Connection):
self._db = db
self._foreignKeys: list[tuple[sqlalchemy.schema.Table, sqlalchemy.schema.ForeignKeyConstraint]] = []
self._inspector = sqlalchemy.inspect(connection)
self._tableNames = frozenset(self._inspector.get_table_names(schema=self._db.namespace))
self._initializers: list[Callable[[Database], None]] = []
Expand Down Expand Up @@ -164,13 +165,9 @@ def addTable(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema.Table:
to be declared in any order even in the presence of foreign key
relationships.
"""
name = self._db._mangleTableName(name)
metadata = self._db._metadata
assert metadata is not None, "Guaranteed by context manager that returns this object."
table = self._db._convertTableSpec(name, spec, metadata)
for foreignKeySpec in spec.foreignKeys:
self._foreignKeys.append((table, self._db._convertForeignKeySpec(name, foreignKeySpec, metadata)))
return table
return metadata.add_table(self._db, name, spec)

def addTableTuple(self, specs: tuple[ddl.TableSpec, ...]) -> tuple[sqlalchemy.schema.Table, ...]:
"""Add a named tuple of tables to the schema, returning their
Expand Down Expand Up @@ -273,15 +270,15 @@ def __init__(
origin: int,
engine: sqlalchemy.engine.Engine,
namespace: str | None = None,
metadata: sqlalchemy.schema.MetaData | None = None,
metadata: DatabaseMetadata | None = None,
):
self.origin = origin
self.name_shrinker = NameShrinker(engine.dialect.max_identifier_length)
self.namespace = namespace
self._engine = engine
self._session_connection: sqlalchemy.engine.Connection | None = None
self._metadata = metadata
self._temp_tables: set[str] = set()
self._metadata = metadata

def __repr__(self) -> str:
# Rather than try to reproduce all the parameters used to create
Expand Down Expand Up @@ -540,6 +537,7 @@ def temporary_table(
otherwise, but in that case they probably need to be modified to
support the full range of expected read-only butler behavior.
"""
assert self._metadata is not None, "Static tables must be created before temporary tables"
with self._session() as connection:
table = self._make_temporary_table(connection, spec=spec, name=name)
self._temp_tables.add(table.key)
Expand All @@ -549,6 +547,7 @@ def temporary_table(
with self._transaction():
table.drop(connection)
self._temp_tables.remove(table.key)
self._metadata.remove_table(table.name)

@contextmanager
def _session(self) -> Iterator[sqlalchemy.engine.Connection]:
Expand Down Expand Up @@ -760,39 +759,30 @@ def declareStaticTables(self, *, create: bool) -> Iterator[StaticTablesContext]:
"""
if create and not self.isWriteable():
raise ReadOnlyDatabaseError(f"Cannot create tables in read-only database {self}.")
self._metadata = sqlalchemy.MetaData(schema=self.namespace)
try:
with self._transaction() as (_, connection):
context = StaticTablesContext(self, connection)
if create and context._tableNames:
# Looks like database is already initialized, to avoid
# danger of modifying/destroying valid schema we refuse to
# do anything in this case
raise SchemaAlreadyDefinedError(f"Cannot create tables in non-empty database {self}.")
yield context
for table, foreignKey in context._foreignKeys:
table.append_constraint(foreignKey)
if create:
if (
self.namespace is not None
and self.namespace not in context._inspector.get_schema_names()
):
connection.execute(sqlalchemy.schema.CreateSchema(self.namespace))
# In our tables we have columns that make use of sqlalchemy
# Sequence objects. There is currently a bug in sqlalchemy
# that causes a deprecation warning to be thrown on a
# property of the Sequence object when the repr for the
# sequence is created. Here a filter is used to catch these
# deprecation warnings when tables are created.
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=sqlalchemy.exc.SADeprecationWarning)
self._metadata.create_all(connection)
# call all initializer methods sequentially
for init in context._initializers:
init(self)
except BaseException:
self._metadata = None
raise
self._metadata = DatabaseMetadata(self.namespace)
with self._transaction() as (_, connection):
context = StaticTablesContext(self, connection)
if create and context._tableNames:
# Looks like database is already initialized, to avoid
# danger of modifying/destroying valid schema we refuse to
# do anything in this case
raise SchemaAlreadyDefinedError(f"Cannot create tables in non-empty database {self}.")
yield context
if create:
if self.namespace is not None and self.namespace not in context._inspector.get_schema_names():
connection.execute(sqlalchemy.schema.CreateSchema(self.namespace))
# In our tables we have columns that make use of sqlalchemy
# Sequence objects. There is currently a bug in sqlalchemy
# that causes a deprecation warning to be thrown on a
# property of the Sequence object when the repr for the
# sequence is created. Here a filter is used to catch these
# deprecation warnings when tables are created.
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=sqlalchemy.exc.SADeprecationWarning)
self._metadata.create_all(connection)
# call all initializer methods sequentially
for init in context._initializers:
init(self)

@abstractmethod
def isWriteable(self) -> bool:
Expand Down Expand Up @@ -1141,9 +1131,8 @@ def ensureTableExists(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema
raise ReadOnlyDatabaseError(
f"Table {name} does not exist, and cannot be created because database {self} is read-only."
)
table = self._convertTableSpec(name, spec, self._metadata)
for foreignKeySpec in spec.foreignKeys:
table.append_constraint(self._convertForeignKeySpec(name, foreignKeySpec, self._metadata))

table = self._metadata.add_table(self, name, spec)
try:
with self._transaction() as (_, connection):
table.create(connection)
Expand Down Expand Up @@ -1192,7 +1181,7 @@ def getExistingTable(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema.
"""
assert self._metadata is not None, "Static tables must be declared before dynamic tables."
name = self._mangleTableName(name)
table = self._metadata.tables.get(name if self.namespace is None else f"{self.namespace}.{name}")
table = self._metadata.get_table(name)
if table is not None:
if spec.fields.names != set(table.columns.keys()):
raise DatabaseConflictError(
Expand All @@ -1206,10 +1195,7 @@ def getExistingTable(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema.
)
if name in inspector.get_table_names(schema=self.namespace):
_checkExistingTableDefinition(name, spec, inspector.get_columns(name, schema=self.namespace))
table = self._convertTableSpec(name, spec, self._metadata)
for foreignKeySpec in spec.foreignKeys:
table.append_constraint(self._convertForeignKeySpec(name, foreignKeySpec, self._metadata))
return table
return self._metadata.add_table(self, name, spec)
return table

def _make_temporary_table(
Expand Down Expand Up @@ -1244,19 +1230,16 @@ def _make_temporary_table(
"""
if name is None:
name = f"tmp_{uuid.uuid4().hex}"
metadata = self._metadata
if metadata is None:
if self._metadata is None:
raise RuntimeError("Cannot create temporary table before static schema is defined.")
table = self._convertTableSpec(
name, spec, metadata, prefixes=["TEMPORARY"], schema=sqlalchemy.schema.BLANK_SCHEMA, **kwargs
table = self._metadata.add_table(
self, name, spec, prefixes=["TEMPORARY"], schema=sqlalchemy.schema.BLANK_SCHEMA, **kwargs
)
if table.key in self._temp_tables and table.key != name:
raise ValueError(
f"A temporary table with name {name} (transformed to {table.key} by "
"Database) already exists."
)
for foreignKeySpec in spec.foreignKeys:
table.append_constraint(self._convertForeignKeySpec(name, foreignKeySpec, metadata))
with self._transaction():
table.create(connection)
return table
Expand Down Expand Up @@ -2010,3 +1993,103 @@ def apply_any_aggregate(self, column: sqlalchemy.ColumnElement[Any]) -> sqlalche
"""An object that can be used to shrink field names to fit within the
identifier limit of the database engine (`NameShrinker`).
"""


class DatabaseMetadata:
    """Wrapper around SqlAlchemy MetaData object to ensure threadsafety.

    Parameters
    ----------
    namespace : `str` or `None`
        Name of the schema or namespace this instance is associated with.

    Notes
    -----
    `sqlalchemy.MetaData` is documented to be threadsafe for reads, but not
    with concurrent modifications.  We add tables dynamically at runtime,
    and the MetaData object is shared by all Database instances sharing
    the same connection pool.

    Every method takes the same internal lock, so table definitions are
    added, looked up, removed and created one caller at a time.
    """

    def __init__(self, namespace: str | None) -> None:
        # Single lock guarding both the sqlalchemy MetaData and our own
        # name -> table mapping.
        self._lock = Lock()
        self._metadata = sqlalchemy.MetaData(schema=namespace)
        # Tables registered through `add_table`, keyed by the name that was
        # passed to `add_table`.
        self._tables: dict[str, sqlalchemy.Table] = {}

    def add_table(
        self, db: Database, name: str, spec: ddl.TableSpec, **kwargs: Any
    ) -> sqlalchemy.schema.Table:
        """Add a new table to the MetaData object, returning its sqlalchemy
        representation.

        This does not physically create the table in the database -- it only
        sets up its definition.

        Parameters
        ----------
        db : `Database`
            Database connection associated with the table definition.
        name : `str`
            The name of the table.
        spec : `ddl.TableSpec`
            The specification of the table.
        **kwargs
            Additional keyword arguments to forward to the
            `sqlalchemy.schema.Table` constructor.

        Returns
        -------
        table : `sqlalchemy.schema.Table`
            The created table.  If a table was already added under ``name``,
            that existing definition is returned and ``spec`` is not
            consulted.

        Notes
        -----
        If attaching the foreign key constraints fails, the partially-built
        table is removed from the underlying MetaData before the exception
        propagates, so that a later call with the same name can start from a
        clean state.
        """
        with self._lock:
            if (table := self._tables.get(name)) is not None:
                return table

            table = db._convertTableSpec(name, spec, self._metadata, **kwargs)
            try:
                for foreignKeySpec in spec.foreignKeys:
                    table.append_constraint(
                        db._convertForeignKeySpec(name, foreignKeySpec, self._metadata)
                    )
            except BaseException:
                # The table was built against self._metadata and is presumably
                # already registered there, but it is not yet in self._tables.
                # Leaving it behind would make every later attempt to add a
                # table with this name fail, so undo the registration before
                # re-raising.
                self._metadata.remove(table)
                raise

            self._tables[name] = table
            return table

    def get_table(self, name: str) -> sqlalchemy.schema.Table | None:
        """Return the definition of a table that was previously added to this
        MetaData object.

        Parameters
        ----------
        name : `str`
            Name of the table.

        Returns
        -------
        table : `sqlalchemy.schema.Table` or `None`
            The table definition, or `None` if the table is not known to this
            MetaData instance.
        """
        with self._lock:
            return self._tables.get(name)

    def remove_table(self, name: str) -> None:
        """Remove a table that was previously added to this MetaData object.

        Removing a name that is not known to this instance is a no-op.

        Parameters
        ----------
        name : `str`
            Name of the table.
        """
        with self._lock:
            table = self._tables.pop(name, None)
            if table is not None:
                self._metadata.remove(table)

    def create_all(self, connection: sqlalchemy.engine.Connection) -> None:
        """Create all tables known to this MetaData object in the database.

        Same as `sqlalchemy.MetaData.create_all`.

        Parameters
        ----------
        connection : `sqlalchemy.engine.Connection`
            Database connection that will be used to create tables.
        """
        with self._lock:
            self._metadata.create_all(connection)
2 changes: 1 addition & 1 deletion tests/test_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def isEmptyDatabaseActuallyWriteable(database: SqliteDatabase) -> bool:
"a", ddl.TableSpec(fields=[ddl.FieldSpec("b", dtype=sqlalchemy.Integer, primaryKey=True)])
)
# Drop created table so that schema remains empty.
database._metadata.drop_all(database._engine, tables=[table])
database._metadata._metadata.drop_all(database._engine, tables=[table])
return True
except Exception:
return False
Expand Down

0 comments on commit 28f3e9e

Please sign in to comment.