Skip to content

Commit

Permalink
Merge pull request #1132 from lsst/tickets/DM-47770
Browse files Browse the repository at this point in the history
DM-47770: Fix threadsafety of sqlalchemy MetaData access
  • Loading branch information
dhirving authored Jan 6, 2025
2 parents 1ba37b0 + 952a2f8 commit eaace2a
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 64 deletions.
4 changes: 2 additions & 2 deletions python/lsst/daf/butler/registry/databases/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from ..._named import NamedValueAbstractSet
from ..._timespan import Timespan
from ...timespan_database_representation import TimespanDatabaseRepresentation
from ..interfaces import Database
from ..interfaces import Database, DatabaseMetadata


class PostgresqlDatabase(Database):
Expand Down Expand Up @@ -124,7 +124,7 @@ def _init(
namespace: str | None = None,
writeable: bool = True,
dbname: str,
metadata: sqlalchemy.schema.MetaData | None,
metadata: DatabaseMetadata | None,
pg_version: tuple[int, int],
) -> None:
# Initialization logic shared between ``__init__`` and ``clone``.
Expand Down
4 changes: 2 additions & 2 deletions python/lsst/daf/butler/registry/databases/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@

from ... import ddl
from ..._named import NamedValueAbstractSet
from ..interfaces import Database, StaticTablesContext
from ..interfaces import Database, DatabaseMetadata, StaticTablesContext


def _onSqlite3Connect(
Expand Down Expand Up @@ -109,7 +109,7 @@ def _init(
namespace: str | None = None,
writeable: bool = True,
filename: str | None,
metadata: sqlalchemy.schema.MetaData | None,
metadata: DatabaseMetadata | None,
) -> None:
# Initialization logic shared between ``__init__`` and ``clone``.
super().__init__(origin=origin, engine=engine, namespace=namespace, metadata=metadata)
Expand Down
173 changes: 118 additions & 55 deletions python/lsst/daf/butler/registry/interfaces/_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

__all__ = [
"Database",
"DatabaseMetadata",
"ReadOnlyDatabaseError",
"DatabaseConflictError",
"DatabaseInsertMode",
Expand All @@ -46,6 +47,7 @@
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, Sequence
from contextlib import contextmanager
from threading import Lock
from typing import Any, cast, final

import astropy.time
Expand Down Expand Up @@ -136,7 +138,6 @@ class StaticTablesContext:

def __init__(self, db: Database, connection: sqlalchemy.engine.Connection):
self._db = db
self._foreignKeys: list[tuple[sqlalchemy.schema.Table, sqlalchemy.schema.ForeignKeyConstraint]] = []
self._inspector = sqlalchemy.inspect(connection)
self._tableNames = frozenset(self._inspector.get_table_names(schema=self._db.namespace))
self._initializers: list[Callable[[Database], None]] = []
Expand Down Expand Up @@ -164,13 +165,9 @@ def addTable(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema.Table:
to be declared in any order even in the presence of foreign key
relationships.
"""
name = self._db._mangleTableName(name)
metadata = self._db._metadata
assert metadata is not None, "Guaranteed by context manager that returns this object."
table = self._db._convertTableSpec(name, spec, metadata)
for foreignKeySpec in spec.foreignKeys:
self._foreignKeys.append((table, self._db._convertForeignKeySpec(name, foreignKeySpec, metadata)))
return table
return metadata.add_table(self._db, name, spec)

def addTableTuple(self, specs: tuple[ddl.TableSpec, ...]) -> tuple[sqlalchemy.schema.Table, ...]:
"""Add a named tuple of tables to the schema, returning their
Expand Down Expand Up @@ -273,15 +270,15 @@ def __init__(
origin: int,
engine: sqlalchemy.engine.Engine,
namespace: str | None = None,
metadata: sqlalchemy.schema.MetaData | None = None,
metadata: DatabaseMetadata | None = None,
):
self.origin = origin
self.name_shrinker = NameShrinker(engine.dialect.max_identifier_length)
self.namespace = namespace
self._engine = engine
self._session_connection: sqlalchemy.engine.Connection | None = None
self._metadata = metadata
self._temp_tables: set[str] = set()
self._metadata = metadata

def __repr__(self) -> str:
# Rather than try to reproduce all the parameters used to create
Expand Down Expand Up @@ -540,6 +537,7 @@ def temporary_table(
otherwise, but in that case they probably need to be modified to
support the full range of expected read-only butler behavior.
"""
assert self._metadata is not None, "Static tables must be created before temporary tables"
with self._session() as connection:
table = self._make_temporary_table(connection, spec=spec, name=name)
self._temp_tables.add(table.key)
Expand All @@ -549,6 +547,7 @@ def temporary_table(
with self._transaction():
table.drop(connection)
self._temp_tables.remove(table.key)
self._metadata.remove_table(table.name)

@contextmanager
def _session(self) -> Iterator[sqlalchemy.engine.Connection]:
Expand Down Expand Up @@ -760,7 +759,8 @@ def declareStaticTables(self, *, create: bool) -> Iterator[StaticTablesContext]:
"""
if create and not self.isWriteable():
raise ReadOnlyDatabaseError(f"Cannot create tables in read-only database {self}.")
self._metadata = sqlalchemy.MetaData(schema=self.namespace)

self._metadata = DatabaseMetadata(self.namespace)
try:
with self._transaction() as (_, connection):
context = StaticTablesContext(self, connection)
Expand All @@ -770,8 +770,6 @@ def declareStaticTables(self, *, create: bool) -> Iterator[StaticTablesContext]:
# do anything in this case
raise SchemaAlreadyDefinedError(f"Cannot create tables in non-empty database {self}.")
yield context
for table, foreignKey in context._foreignKeys:
table.append_constraint(foreignKey)
if create:
if (
self.namespace is not None
Expand Down Expand Up @@ -858,30 +856,6 @@ def expandDatabaseEntityName(self, shrunk: str) -> str:
"""
return shrunk

def _mangleTableName(self, name: str) -> str:
"""Map a logical, user-visible table name to the true table name used
in the database.
The default implementation returns the given name unchanged.
Parameters
----------
name : `str`
Input table name. Should not include a namespace (i.e. schema)
prefix.
Returns
-------
mangled : `str`
Mangled version of the table name (still with no namespace prefix).
Notes
-----
Reimplementations of this method must be idempotent - mangling an
already-mangled name must have no effect.
"""
return name

def _makeColumnConstraints(self, table: str, spec: ddl.FieldSpec) -> list[sqlalchemy.CheckConstraint]:
"""Create constraints based on this spec.
Expand Down Expand Up @@ -974,13 +948,11 @@ def _convertForeignKeySpec(
SQLAlchemy representation of the constraint.
"""
name = self.shrinkDatabaseEntityName(
"_".join(
["fkey", table, self._mangleTableName(spec.table)] + list(spec.target) + list(spec.source)
)
"_".join(["fkey", table, spec.table] + list(spec.target) + list(spec.source))
)
return sqlalchemy.schema.ForeignKeyConstraint(
spec.source,
[f"{self._mangleTableName(spec.table)}.{col}" for col in spec.target],
[f"{spec.table}.{col}" for col in spec.target],
name=name,
ondelete=spec.onDelete,
)
Expand Down Expand Up @@ -1050,7 +1022,6 @@ def _convertTableSpec(
avoid circular dependencies. These are added by higher-level logic in
`ensureTableExists`, `getExistingTable`, and `declareStaticTables`.
"""
name = self._mangleTableName(name)
args: list[sqlalchemy.schema.SchemaItem] = [
self._convertFieldSpec(name, fieldSpec, metadata) for fieldSpec in spec.fields
]
Expand Down Expand Up @@ -1141,9 +1112,8 @@ def ensureTableExists(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema
raise ReadOnlyDatabaseError(
f"Table {name} does not exist, and cannot be created because database {self} is read-only."
)
table = self._convertTableSpec(name, spec, self._metadata)
for foreignKeySpec in spec.foreignKeys:
table.append_constraint(self._convertForeignKeySpec(name, foreignKeySpec, self._metadata))

table = self._metadata.add_table(self, name, spec)
try:
with self._transaction() as (_, connection):
table.create(connection)
Expand Down Expand Up @@ -1191,8 +1161,7 @@ def getExistingTable(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema.
Subclasses may override this method, but usually should not need to.
"""
assert self._metadata is not None, "Static tables must be declared before dynamic tables."
name = self._mangleTableName(name)
table = self._metadata.tables.get(name if self.namespace is None else f"{self.namespace}.{name}")
table = self._metadata.get_table(name)
if table is not None:
if spec.fields.names != set(table.columns.keys()):
raise DatabaseConflictError(
Expand All @@ -1206,10 +1175,7 @@ def getExistingTable(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema.
)
if name in inspector.get_table_names(schema=self.namespace):
_checkExistingTableDefinition(name, spec, inspector.get_columns(name, schema=self.namespace))
table = self._convertTableSpec(name, spec, self._metadata)
for foreignKeySpec in spec.foreignKeys:
table.append_constraint(self._convertForeignKeySpec(name, foreignKeySpec, self._metadata))
return table
return self._metadata.add_table(self, name, spec)
return table

def _make_temporary_table(
Expand Down Expand Up @@ -1244,19 +1210,16 @@ def _make_temporary_table(
"""
if name is None:
name = f"tmp_{uuid.uuid4().hex}"
metadata = self._metadata
if metadata is None:
if self._metadata is None:
raise RuntimeError("Cannot create temporary table before static schema is defined.")
table = self._convertTableSpec(
name, spec, metadata, prefixes=["TEMPORARY"], schema=sqlalchemy.schema.BLANK_SCHEMA, **kwargs
table = self._metadata.add_table(
self, name, spec, prefixes=["TEMPORARY"], schema=sqlalchemy.schema.BLANK_SCHEMA, **kwargs
)
if table.key in self._temp_tables and table.key != name:
raise ValueError(
f"A temporary table with name {name} (transformed to {table.key} by "
"Database) already exists."
)
for foreignKeySpec in spec.foreignKeys:
table.append_constraint(self._convertForeignKeySpec(name, foreignKeySpec, metadata))
with self._transaction():
table.create(connection)
return table
Expand Down Expand Up @@ -2010,3 +1973,103 @@ def apply_any_aggregate(self, column: sqlalchemy.ColumnElement[Any]) -> sqlalche
"""An object that can be used to shrink field names to fit within the
identifier limit of the database engine (`NameShrinker`).
"""


class DatabaseMetadata:
    """Wrapper around SqlAlchemy MetaData object to ensure threadsafety.

    Parameters
    ----------
    namespace : `str` or `None`
        Name of the schema or namespace this instance is associated with.

    Notes
    -----
    `sqlalchemy.MetaData` is documented to be threadsafe for reads, but not
    with concurrent modifications.  We add tables dynamically at runtime,
    and the MetaData object is shared by all Database instances sharing
    the same connection pool.

    Every method takes the same `threading.Lock`, which is not reentrant:
    code that runs while the lock is held (in particular the ``db``
    conversion hooks invoked by `add_table`) must not call back into this
    object.

    Tables are tracked under the ``name`` passed to `add_table`; that same
    string is the key expected by `get_table` and `remove_table`.
    """

    def __init__(self, namespace: str | None) -> None:
        self._lock = Lock()
        self._metadata = sqlalchemy.MetaData(schema=namespace)
        # Our own name -> table index, kept in step with ``self._metadata``
        # so lookups never have to touch the sqlalchemy object.
        self._tables: dict[str, sqlalchemy.Table] = {}

    def add_table(
        self, db: Database, name: str, spec: ddl.TableSpec, **kwargs: Any
    ) -> sqlalchemy.schema.Table:
        """Add a new table to the MetaData object, returning its sqlalchemy
        representation.  This does not physically create the table in the
        database -- it only sets up its definition.

        Parameters
        ----------
        db : `Database`
            Database connection associated with the table definition.
        name : `str`
            The name of the table.
        spec : `ddl.TableSpec`
            The specification of the table.
        **kwargs
            Additional keyword arguments to forward to the
            `sqlalchemy.schema.Table` constructor.

        Returns
        -------
        table : `sqlalchemy.schema.Table`
            The created table, or the previously-registered table if one
            was already added under ``name``.

        Notes
        -----
        If a table was already added under ``name`` it is returned as-is;
        ``spec`` and ``kwargs`` are not compared against it.

        If attaching a foreign key constraint fails, the partially-built
        table is removed from the underlying MetaData before the exception
        propagates, so that a later call with the same name starts from a
        clean state.
        """
        with self._lock:
            if (table := self._tables.get(name)) is not None:
                return table

            table = db._convertTableSpec(name, spec, self._metadata, **kwargs)
            try:
                for foreignKeySpec in spec.foreignKeys:
                    table.append_constraint(db._convertForeignKeySpec(name, foreignKeySpec, self._metadata))
            except BaseException:
                # The table is already registered with the sqlalchemy
                # MetaData at this point but not with ``self._tables``;
                # undo the registration so the two never disagree.
                self._metadata.remove(table)
                raise

            self._tables[name] = table
            return table

    def get_table(self, name: str) -> sqlalchemy.schema.Table | None:
        """Return the definition of a table that was previously added to this
        MetaData object.

        Parameters
        ----------
        name : `str`
            Name of the table.

        Returns
        -------
        table : `sqlalchemy.schema.Table` or `None`
            The table definition, or `None` if the table is not known to this
            MetaData instance.
        """
        with self._lock:
            return self._tables.get(name)

    def remove_table(self, name: str) -> None:
        """Remove a table that was previously added to this MetaData object.

        Parameters
        ----------
        name : `str`
            Name of the table.

        Notes
        -----
        Does nothing if no table is registered under ``name``.
        """
        with self._lock:
            table = self._tables.pop(name, None)
            if table is not None:
                self._metadata.remove(table)

    def create_all(self, connection: sqlalchemy.engine.Connection) -> None:
        """Create all tables known to this MetaData object in the database.
        Same as `sqlalchemy.MetaData.create_all`.

        Parameters
        ----------
        connection : `sqlalchemy.engine.Connection`
            Database connection that will be used to create tables.
        """
        with self._lock:
            self._metadata.create_all(connection)
10 changes: 5 additions & 5 deletions tests/test_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def isEmptyDatabaseActuallyWriteable(database: SqliteDatabase) -> bool:
"a", ddl.TableSpec(fields=[ddl.FieldSpec("b", dtype=sqlalchemy.Integer, primaryKey=True)])
)
# Drop created table so that schema remains empty.
database._metadata.drop_all(database._engine, tables=[table])
database._metadata._metadata.drop_all(database._engine, tables=[table])
return True
except Exception:
return False
Expand Down Expand Up @@ -103,13 +103,13 @@ def testConnection(self):
_, filename = tempfile.mkstemp(dir=self.root, suffix=".sqlite3")
# Create a read-write database by passing in the filename.
rwFromFilename = SqliteDatabase.fromEngine(SqliteDatabase.makeEngine(filename=filename), origin=0)
self.assertEqual(rwFromFilename.filename, filename)
self.assertEqual(os.path.realpath(rwFromFilename.filename), os.path.realpath(filename))
self.assertEqual(rwFromFilename.origin, 0)
self.assertTrue(rwFromFilename.isWriteable())
self.assertTrue(isEmptyDatabaseActuallyWriteable(rwFromFilename))
# Create a read-write database via a URI.
rwFromUri = SqliteDatabase.fromUri(f"sqlite:///{filename}", origin=0)
self.assertEqual(rwFromUri.filename, filename)
self.assertEqual(os.path.realpath(rwFromUri.filename), os.path.realpath(filename))
self.assertEqual(rwFromUri.origin, 0)
self.assertTrue(rwFromUri.isWriteable())
self.assertTrue(isEmptyDatabaseActuallyWriteable(rwFromUri))
Expand All @@ -123,13 +123,13 @@ def testConnection(self):
roFromFilename = SqliteDatabase.fromEngine(
SqliteDatabase.makeEngine(filename=filename), origin=0, writeable=False
)
self.assertEqual(roFromFilename.filename, filename)
self.assertEqual(os.path.realpath(roFromFilename.filename), os.path.realpath(filename))
self.assertEqual(roFromFilename.origin, 0)
self.assertFalse(roFromFilename.isWriteable())
self.assertFalse(isEmptyDatabaseActuallyWriteable(roFromFilename))
# Create a read-only database via a URI.
roFromUri = SqliteDatabase.fromUri(f"sqlite:///{filename}", origin=0, writeable=False)
self.assertEqual(roFromUri.filename, filename)
self.assertEqual(os.path.realpath(roFromUri.filename), os.path.realpath(filename))
self.assertEqual(roFromUri.origin, 0)
self.assertFalse(roFromUri.isWriteable())
self.assertFalse(isEmptyDatabaseActuallyWriteable(roFromUri))
Expand Down

0 comments on commit eaace2a

Please sign in to comment.