Skip to content

Commit

Permalink
Add locking to more safely delete state groups
Browse files Browse the repository at this point in the history
Currently we don't really have anything that stops us from deleting a
state group while an in-flight event still references it. This is a fairly
rare race currently, but we want to be able to more aggressively delete
state groups, so it is important to address this to ensure that the
database remains valid.

See the class docstring of the new data store for an explanation of how
this works.
  • Loading branch information
erikjohnston committed Jan 29, 2025
1 parent 95a85b1 commit d5e89a3
Show file tree
Hide file tree
Showing 12 changed files with 1,039 additions and 46 deletions.
18 changes: 14 additions & 4 deletions synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ class FederationEventHandler:
def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self._store = hs.get_datastores().main
self._state_store = hs.get_datastores().state
self._state_deletion_store = hs.get_datastores().state_deletion
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state

Expand Down Expand Up @@ -580,7 +582,9 @@ async def process_remote_join(
room_version.identifier,
state_maps_to_resolve,
event_map=None,
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_deletion_store
),
)
)
else:
Expand Down Expand Up @@ -1179,7 +1183,9 @@ async def _compute_event_context_with_maybe_missing_prevs(
room_version,
state_maps,
event_map={event_id: event},
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_deletion_store
),
)

except Exception as e:
Expand Down Expand Up @@ -1874,7 +1880,9 @@ async def _check_event_auth(
room_version,
[local_state_id_map, claimed_auth_events_id_map],
event_map=None,
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_deletion_store
),
)
)
else:
Expand Down Expand Up @@ -2014,7 +2022,9 @@ async def _check_for_soft_fail(
room_version,
state_sets,
event_map=None,
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_deletion_store
),
)
)
else:
Expand Down
61 changes: 55 additions & 6 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,13 @@
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure, measure_func
from synapse.util.stringutils import shortstr

if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.controllers import StateStorageController
from synapse.storage.databases.main import DataStore
from synapse.storage.databases.state.deletion import StateDeletionDataStore

logger = logging.getLogger(__name__)
metrics_logger = logging.getLogger("synapse.state.metrics")
Expand Down Expand Up @@ -194,6 +196,8 @@ def __init__(self, hs: "HomeServer"):
self._storage_controllers = hs.get_storage_controllers()
self._events_shard_config = hs.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
self._state_store = hs.get_datastores().state
self._state_deletion_store = hs.get_datastores().state_deletion

self._update_current_state_client = (
ReplicationUpdateCurrentStateRestServlet.make_client(hs)
Expand Down Expand Up @@ -475,7 +479,10 @@ async def compute_event_context(
@trace
@measure_func()
async def resolve_state_groups_for_events(
self, room_id: str, event_ids: StrCollection, await_full_state: bool = True
self,
room_id: str,
event_ids: StrCollection,
await_full_state: bool = True,
) -> _StateCacheEntry:
"""Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
Expand Down Expand Up @@ -511,6 +518,17 @@ async def resolve_state_groups_for_events(
) = await self._state_storage_controller.get_state_group_delta(
state_group_id
)

if prev_group:
# Ensure that we still have the prev group, and ensure we don't
# delete it while we're persisting the event.
missing_state_group = await self._state_deletion_store.check_state_groups_and_bump_deletion(
{prev_group}
)
if missing_state_group:
prev_group = None
delta_ids = None

return _StateCacheEntry(
state=None,
state_group=state_group_id,
Expand All @@ -531,7 +549,9 @@ async def resolve_state_groups_for_events(
room_version,
state_to_resolve,
None,
state_res_store=StateResolutionStore(self.store),
state_res_store=StateResolutionStore(
self.store, self._state_deletion_store
),
)
return result

Expand Down Expand Up @@ -663,14 +683,42 @@ async def resolve_state_groups(
async with self.resolve_linearizer.queue(group_names):
cache = self._state_cache.get(group_names, None)
if cache:
return cache
# Check that the returned cache entry doesn't point to deleted
# state groups.
state_groups_to_check = set()
if cache.state_group is not None:
state_groups_to_check.add(cache.state_group)

if cache.prev_group is not None:
state_groups_to_check.add(cache.prev_group)

missing_state_groups = await state_res_store.state_deletion_store.check_state_groups_and_bump_deletion(
state_groups_to_check
)

if not missing_state_groups:
return cache
else:
# There are missing state groups, so let's remove the stale
# entry and continue as if it was a cache miss.
self._state_cache.pop(group_names, None)

logger.info(
"Resolving state for %s with groups %s",
room_id,
list(group_names),
)

# We double check that none of the state groups have been deleted.
# They shouldn't be as all these state groups should be referenced.
missing_state_groups = await state_res_store.state_deletion_store.check_state_groups_and_bump_deletion(
group_names
)
if missing_state_groups:
raise Exception(
f"State groups have been deleted: {shortstr(missing_state_groups)}"
)

state_groups_histogram.observe(len(state_groups_ids))

new_state = await self.resolve_events_with_store(
Expand Down Expand Up @@ -884,7 +932,8 @@ class StateResolutionStore:
in well defined way.
"""

store: "DataStore"
main_store: "DataStore"
state_deletion_store: "StateDeletionDataStore"

def get_events(
self, event_ids: StrCollection, allow_rejected: bool = False
Expand All @@ -899,7 +948,7 @@ def get_events(
An awaitable which resolves to a dict from event_id to event.
"""

return self.store.get_events(
return self.main_store.get_events(
event_ids,
redact_behaviour=EventRedactBehaviour.as_is,
get_prev_content=False,
Expand All @@ -920,4 +969,4 @@ def get_auth_chain_difference(
An awaitable that resolves to a set of event IDs.
"""

return self.store.get_auth_chain_difference(room_id, state_sets)
return self.main_store.get_auth_chain_difference(room_id, state_sets)
32 changes: 21 additions & 11 deletions synapse/storage/controllers/persist_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ def __init__(
# store for now.
self.main_store = stores.main
self.state_store = stores.state
self._state_deletion_store = stores.state_deletion

assert stores.persist_events
self.persist_events_store = stores.persist_events
Expand Down Expand Up @@ -549,7 +550,9 @@ async def _calculate_current_state(self, room_id: str) -> StateMap[str]:
room_version,
state_maps_by_state_group,
event_map=None,
state_res_store=StateResolutionStore(self.main_store),
state_res_store=StateResolutionStore(
self.main_store, self._state_deletion_store
),
)

return await res.get_state(self._state_controller, StateFilter.all())
Expand Down Expand Up @@ -635,15 +638,20 @@ async def _persist_event_batch(
room_id, [e for e, _ in chunk]
)

await self.persist_events_store._persist_events_and_state_updates(
room_id,
chunk,
state_delta_for_room=state_delta_for_room,
new_forward_extremities=new_forward_extremities,
use_negative_stream_ordering=backfilled,
inhibit_local_membership_updates=backfilled,
new_event_links=new_event_links,
)
# Stop the state groups from being deleted while we're persisting
# them.
async with self._state_deletion_store.persisting_state_group_references(
events_and_contexts
):
await self.persist_events_store._persist_events_and_state_updates(
room_id,
chunk,
state_delta_for_room=state_delta_for_room,
new_forward_extremities=new_forward_extremities,
use_negative_stream_ordering=backfilled,
inhibit_local_membership_updates=backfilled,
new_event_links=new_event_links,
)

return replaced_events

Expand Down Expand Up @@ -965,7 +973,9 @@ async def _get_new_state_after_events(
room_version,
state_groups,
events_map,
state_res_store=StateResolutionStore(self.main_store),
state_res_store=StateResolutionStore(
self.main_store, self._state_deletion_store
),
)

state_resolutions_during_persistence.inc()
Expand Down
10 changes: 8 additions & 2 deletions synapse/storage/databases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from synapse.storage.database import DatabasePool, make_conn
from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.databases.state import StateGroupDataStore
from synapse.storage.databases.state.deletion import StateDeletionDataStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database

Expand All @@ -49,12 +50,14 @@ class Databases(Generic[DataStoreT]):
main
state
persist_events
state_deletion
"""

databases: List[DatabasePool]
main: "DataStore" # FIXME: https://github.com/matrix-org/synapse/issues/11165: actually an instance of `main_store_class`
state: StateGroupDataStore
persist_events: Optional[PersistEventsStore]
state_deletion: StateDeletionDataStore

def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"):
# Note we pass in the main store class here as workers use a different main
Expand All @@ -63,6 +66,7 @@ def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"):
self.databases = []
main: Optional[DataStoreT] = None
state: Optional[StateGroupDataStore] = None
state_deletion: Optional[StateDeletionDataStore] = None
persist_events: Optional[PersistEventsStore] = None

for database_config in hs.config.database.databases:
Expand Down Expand Up @@ -114,7 +118,8 @@ def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"):
if state:
raise Exception("'state' data store already configured")

state = StateGroupDataStore(database, db_conn, hs)
state_deletion = StateDeletionDataStore(database, db_conn, hs)
state = StateGroupDataStore(database, db_conn, hs, state_deletion)

db_conn.commit()

Expand All @@ -135,11 +140,12 @@ def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"):
if not main:
raise Exception("No 'main' database configured")

if not state:
if not state or not state_deletion:
raise Exception("No 'state' database configured")

# We use local variables here to ensure that the databases do not have
# optional types.
self.main = main # type: ignore[assignment]
self.state = state
self.persist_events = persist_events
self.state_deletion = state_deletion
Loading

0 comments on commit d5e89a3

Please sign in to comment.