diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index c85deaed562..24684f57190 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -151,6 +151,8 @@ class FederationEventHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() self._store = hs.get_datastores().main + self._state_store = hs.get_datastores().state + self._state_epoch_store = hs.get_datastores().state_epochs self._storage_controllers = hs.get_storage_controllers() self._state_storage_controller = self._storage_controllers.state @@ -580,7 +582,9 @@ async def process_remote_join( room_version.identifier, state_maps_to_resolve, event_map=None, - state_res_store=StateResolutionStore(self._store), + state_res_store=StateResolutionStore( + self._store, self._state_epoch_store + ), ) ) else: @@ -1179,7 +1183,9 @@ async def _compute_event_context_with_maybe_missing_prevs( room_version, state_maps, event_map={event_id: event}, - state_res_store=StateResolutionStore(self._store), + state_res_store=StateResolutionStore( + self._store, self._state_epoch_store + ), ) except Exception as e: @@ -1874,7 +1880,9 @@ async def _check_event_auth( room_version, [local_state_id_map, claimed_auth_events_id_map], event_map=None, - state_res_store=StateResolutionStore(self._store), + state_res_store=StateResolutionStore( + self._store, self._state_epoch_store + ), ) ) else: @@ -2014,7 +2022,9 @@ async def _check_for_soft_fail( room_version, state_sets, event_map=None, - state_res_store=StateResolutionStore(self._store), + state_res_store=StateResolutionStore( + self._store, self._state_epoch_store + ), ) ) else: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 72b291889bb..a92cddc4617 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -59,11 +59,13 @@ from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure, measure_func +from synapse.util.stringutils import shortstr if TYPE_CHECKING: from synapse.server import HomeServer from synapse.storage.controllers import StateStorageController from synapse.storage.databases.main import DataStore + from synapse.storage.databases.state.epochs import StateEpochDataStore logger = logging.getLogger(__name__) metrics_logger = logging.getLogger("synapse.state.metrics") @@ -194,6 +196,8 @@ def __init__(self, hs: "HomeServer"): self._storage_controllers = hs.get_storage_controllers() self._events_shard_config = hs.config.worker.events_shard_config self._instance_name = hs.get_instance_name() + self._state_store = hs.get_datastores().state + self._state_epoch_store = hs.get_datastores().state_epochs self._update_current_state_client = ( ReplicationUpdateCurrentStateRestServlet.make_client(hs) @@ -475,7 +479,10 @@ async def compute_event_context( @trace @measure_func() async def resolve_state_groups_for_events( - self, room_id: str, event_ids: StrCollection, await_full_state: bool = True + self, + room_id: str, + event_ids: StrCollection, + await_full_state: bool = True, ) -> _StateCacheEntry: """Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. @@ -511,6 +518,19 @@ async def resolve_state_groups_for_events( ) = await self._state_storage_controller.get_state_group_delta( state_group_id ) + + if prev_group: + # Ensure that we still have the prev group, and ensure we don't + # delete it while we're persisting the event. + missing_state_group = ( + await self._state_epoch_store.check_state_groups_and_bump_deletion( + {prev_group} + ) + ) + if missing_state_group: + prev_group = None + delta_ids = None + return _StateCacheEntry( state=None, state_group=state_group_id, @@ -531,7 +551,7 @@ async def resolve_state_groups_for_events( room_version, state_to_resolve, None, - state_res_store=StateResolutionStore(self.store), + state_res_store=StateResolutionStore(self.store, self._state_epoch_store), ) return result @@ -663,7 +683,25 @@ async def resolve_state_groups( async with self.resolve_linearizer.queue(group_names): cache = self._state_cache.get(group_names, None) if cache: - return cache + # Check that the returned cache entry doesn't point to deleted + # state groups. + state_groups_to_check = set() + if cache.state_group is not None: + state_groups_to_check.add(cache.state_group) + + if cache.prev_group is not None: + state_groups_to_check.add(cache.prev_group) + + missing_state_groups = await state_res_store.state_epoch_store.check_state_groups_and_bump_deletion( + state_groups_to_check + ) + + if not missing_state_groups: + return cache + else: + # There are missing state groups, so let's remove the stale + # entry and continue as if it was a cache miss. + self._state_cache.pop(group_names, None) logger.info( "Resolving state for %s with groups %s", @@ -671,6 +709,16 @@ async def resolve_state_groups( list(group_names), ) + # We double check that none of the state groups have been deleted. + # They shouldn't be as all these state groups should be referenced. + missing_state_groups = await state_res_store.state_epoch_store.check_state_groups_and_bump_deletion( + group_names + ) + if missing_state_groups: + raise Exception( + f"State groups have been deleted: {shortstr(missing_state_groups)}" + ) + state_groups_histogram.observe(len(state_groups_ids)) new_state = await self.resolve_events_with_store( @@ -884,7 +932,8 @@ class StateResolutionStore: in well defined way. """ - store: "DataStore" + main_store: "DataStore" + state_epoch_store: "StateEpochDataStore" def get_events( self, event_ids: StrCollection, allow_rejected: bool = False @@ -899,7 +948,7 @@ def get_events( An awaitable which resolves to a dict from event_id to event. """ - return self.store.get_events( + return self.main_store.get_events( event_ids, redact_behaviour=EventRedactBehaviour.as_is, get_prev_content=False, @@ -920,4 +969,4 @@ def get_auth_chain_difference( An awaitable that resolves to a set of event IDs. """ - return self.store.get_auth_chain_difference(room_id, state_sets) + return self.main_store.get_auth_chain_difference(room_id, state_sets) diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 879ee9039e1..54c638f2adb 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -332,6 +332,7 @@ def __init__( # store for now. self.main_store = stores.main self.state_store = stores.state + self._state_epoch_store = stores.state_epochs assert stores.persist_events self.persist_events_store = stores.persist_events @@ -549,7 +550,9 @@ async def _calculate_current_state(self, room_id: str) -> StateMap[str]: room_version, state_maps_by_state_group, event_map=None, - state_res_store=StateResolutionStore(self.main_store), + state_res_store=StateResolutionStore( + self.main_store, self._state_epoch_store + ), ) return await res.get_state(self._state_controller, StateFilter.all()) @@ -635,15 +638,20 @@ async def _persist_event_batch( room_id, [e for e, _ in chunk] ) - await self.persist_events_store._persist_events_and_state_updates( - room_id, - chunk, - state_delta_for_room=state_delta_for_room, - new_forward_extremities=new_forward_extremities, - use_negative_stream_ordering=backfilled, - inhibit_local_membership_updates=backfilled, - new_event_links=new_event_links, - ) + # Stop the state groups from being deleted while we're persisting + # them. + async with self._state_epoch_store.persisting_state_group_references( + events_and_contexts + ): + await self.persist_events_store._persist_events_and_state_updates( + room_id, + chunk, + state_delta_for_room=state_delta_for_room, + new_forward_extremities=new_forward_extremities, + use_negative_stream_ordering=backfilled, + inhibit_local_membership_updates=backfilled, + new_event_links=new_event_links, + ) return replaced_events @@ -965,7 +973,9 @@ async def _get_new_state_after_events( room_version, state_groups, events_map, - state_res_store=StateResolutionStore(self.main_store), + state_res_store=StateResolutionStore( + self.main_store, self._state_epoch_store + ), ) state_resolutions_during_persistence.inc() diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index dd9fc01fb0c..98940f56c87 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -26,6 +26,7 @@ from synapse.storage.database import DatabasePool, make_conn from synapse.storage.databases.main.events import PersistEventsStore from synapse.storage.databases.state import StateGroupDataStore +from synapse.storage.databases.state.epochs import StateEpochDataStore from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database @@ -49,12 +50,14 @@ class Databases(Generic[DataStoreT]): main state persist_events + state_epochs """ databases: List[DatabasePool] main: "DataStore" # FIXME: https://github.com/matrix-org/synapse/issues/11165: actually an instance of `main_store_class` state: StateGroupDataStore persist_events: Optional[PersistEventsStore] + state_epochs: StateEpochDataStore def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"): # Note we pass in the main store class here as workers use a different main @@ -63,6 +66,7 @@ def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"): self.databases = [] main: Optional[DataStoreT] = None state: Optional[StateGroupDataStore] = None + state_epochs: Optional[StateEpochDataStore] = None persist_events: Optional[PersistEventsStore] = None for database_config in hs.config.database.databases: @@ -114,7 +118,8 @@ def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"): if state: raise Exception("'state' data store already configured") - state = StateGroupDataStore(database, db_conn, hs) + state_epochs = StateEpochDataStore(database, db_conn, hs) + state = StateGroupDataStore(database, db_conn, hs, state_epochs) db_conn.commit() @@ -135,7 +140,7 @@ def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"): if not main: raise Exception("No 'main' database configured") - if not state: + if not state or not state_epochs: raise Exception("No 'state' database configured") # We use local variables here to ensure that the databases do not have @@ -143,3 +148,4 @@ def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"): self.main = main # type: ignore[assignment] self.state = state self.persist_events = persist_events + self.state_epochs = state_epochs diff --git a/synapse/storage/databases/state/epochs.py b/synapse/storage/databases/state/epochs.py new file mode 100644 index 00000000000..f813f2feea2 --- /dev/null +++ b/synapse/storage/databases/state/epochs.py @@ -0,0 +1,304 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# + + +import contextlib +from typing import ( + TYPE_CHECKING, + AbstractSet, + AsyncIterator, + Collection, + Set, + Tuple, +) + +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) +from synapse.storage.engines import PostgresEngine +from synapse.util.stringutils import shortstr + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class StateEpochDataStore: + """Manages state epochs and checks for state group deletion. + + Deleting state groups is challenging as before we actually delete them we + need to ensure that there are no in-flight events that refer to the state + groups that we want to delete. + + To handle this, we take two approaches. First, before we persist any event + we ensure that the state groups still exist and mark in the + `state_groups_persisting` table that the state group is about to be used. + (Note that we have to have the extra table here as state groups and events + can be in different databases, and thus we can't check for the existence of + state groups in the persist event transaction). Once the event has been + persisted, we can remove the row from `state_groups_persisting`. So long as + we check that table before deleting state groups, we can ensure that we + never persist events that reference deleted state groups, maintaining + database integrity. + + However, we want to avoid throwing exceptions so deep in the process of + persisting events. So we use a concept of `state_epochs`, where we mark + state groups as pending/proposed for deletion and wait for a certain number + epoch increments before performing the deletion. When we come to handle new + events that reference state groups, we check if they are pending deletion + and bump the epoch when they'll be deleted in (to give a chance for the + event to be persisted, or not). + """ + + # How frequently, roughly, to increment epochs. + TIME_BETWEEN_EPOCH_INCREMENTS_MS = 5 * 60 * 1000 + + # The number of epoch increases that must have happened between marking a + # state group as pending and actually deleting it. + NUMBER_EPOCHS_BEFORE_DELETION = 3 + + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + self._clock = hs.get_clock() + self.db_pool = database + self._instance_name = hs.get_instance_name() + + # TODO: Clear from `state_groups_persisting` any holdovers from previous + # running instance. + + if hs.config.worker.run_background_tasks: + # Add a background loop to periodically check if we should bump + # state epoch. + self._clock.looping_call_now( + self._advance_state_epoch, self.TIME_BETWEEN_EPOCH_INCREMENTS_MS / 5 + ) + + @wrap_as_background_process("_advance_state_epoch") + async def _advance_state_epoch(self) -> None: + """Advances the state epoch, checking that we haven't advanced it too + recently. + """ + + now = self._clock.time_msec() + update_if_before_ts = now - self.TIME_BETWEEN_EPOCH_INCREMENTS_MS + + def advance_state_epoch_txn(txn: LoggingTransaction) -> None: + sql = """ + UPDATE state_epoch + SET state_epoch = state_epoch + 1, updated_ts = ? + WHERE updated_ts <= ? + """ + txn.execute(sql, (now, update_if_before_ts)) + + await self.db_pool.runInteraction( + "_advance_state_epoch", advance_state_epoch_txn, db_autocommit=True + ) + + async def check_state_groups_and_bump_deletion( + self, state_groups: AbstractSet[int] + ) -> Collection[int]: + """Checks to make sure that the state groups haven't been deleted, and + if they're pending deletion we delay it (allowing time for any event + that will use them to finish persisting). + + Returns: + The state groups that are missing, if any. + """ + + return await self.db_pool.runInteraction( + "check_state_groups_and_bump_deletion", + self._check_state_groups_and_bump_deletion_txn, + state_groups, + ) + + def _check_state_groups_and_bump_deletion_txn( + self, txn: LoggingTransaction, state_groups: AbstractSet[int] + ) -> Collection[int]: + existing_state_groups = self._get_existing_groups_with_lock(txn, state_groups) + if state_groups - existing_state_groups: + return state_groups - existing_state_groups + + clause, args = make_in_list_sql_clause( + self.db_pool.engine, "state_group", state_groups + ) + sql = f""" + UPDATE state_groups_pending_deletion + SET state_epoch = (SELECT state_epoch FROM state_epoch) + WHERE {clause} + """ + + txn.execute(sql, args) + + return () + + def _get_existing_groups_with_lock( + self, txn: LoggingTransaction, state_groups: Collection[int] + ) -> AbstractSet[int]: + """Return which of the given state groups are in the database, and locks + those rows with `KEY SHARE` to ensure they don't get concurrently + deleted.""" + clause, args = make_in_list_sql_clause(self.db_pool.engine, "id", state_groups) + + sql = f""" + SELECT id FROM state_groups + WHERE {clause} + """ + if isinstance(self.db_pool.engine, PostgresEngine): + # On postgres we add a row level lock to the rows to ensure that we + # conflict with any concurrent DELETEs. `FOR KEY SHARE` lock will + # not conflict with other read + sql += """ + FOR KEY SHARE + """ + + txn.execute(sql, args) + return {state_group for (state_group,) in txn} + + @contextlib.asynccontextmanager + async def persisting_state_group_references( + self, event_and_contexts: Collection[Tuple[EventBase, EventContext]] + ) -> AsyncIterator[None]: + """Wraps the persistence of the given events and contexts, ensuring that + any state groups referenced still exist and that they don't get deleted + during this.""" + + referenced_state_groups: Set[int] = set() + for event, ctx in event_and_contexts: + if ctx.rejected or event.internal_metadata.is_outlier(): + continue + + assert ctx.state_group is not None + + referenced_state_groups.add(ctx.state_group) + + if ctx.state_group_before_event: + referenced_state_groups.add(ctx.state_group_before_event) + + if not referenced_state_groups: + # We don't reference any state groups, so nothing to do + yield + return + + await self.db_pool.runInteraction( + "mark_state_groups_as_used", + self._mark_state_groups_as_used_txn, + referenced_state_groups, + ) + + try: + yield None + finally: + await self.db_pool.simple_delete_many( + table="state_groups_persisting", + column="state_group", + iterable=referenced_state_groups, + keyvalues={"instance_name": self._instance_name}, + desc="persisting_state_group_references_delete", + ) + + def _mark_state_groups_as_used_txn( + self, txn: LoggingTransaction, state_groups: Set[int] + ) -> None: + """Marks the given state groups as used. Also checks that the given + state epoch is not too old.""" + + existing_state_groups = self._get_existing_groups_with_lock(txn, state_groups) + missing_state_groups = state_groups - existing_state_groups + if missing_state_groups: + raise Exception( + f"state groups have been deleted: {shortstr(missing_state_groups)}" + ) + + self.db_pool.simple_delete_many_batch_txn( + txn, + table="state_groups_pending_deletion", + keys=("state_group",), + values=[(state_group,) for state_group in state_groups], + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_persisting", + keys=("state_group", "instance_name"), + values=[(state_group, self._instance_name) for state_group in state_groups], + ) + + def get_state_groups_that_can_be_purged_txn( + self, txn: LoggingTransaction, state_groups: Collection[int] + ) -> Collection[int]: + """Given a set of state groups, return which state groups can be deleted.""" + + if not state_groups: + return state_groups + + if isinstance(self.db_pool.engine, PostgresEngine): + # On postgres we want to lock the rows FOR UPDATE as early as + # possible to help conflicts. + clause, args = make_in_list_sql_clause( + self.db_pool.engine, "id", state_groups + ) + sql = """ + SELECT id FROM state_groups + WHERE {clause} + FOR UPDATE + """ + txn.execute(sql, args) + + current_state_epoch = self.db_pool.simple_select_one_onecol_txn( + txn, + table="state_epoch", + retcol="state_epoch", + keyvalues={}, + ) + + # Check the deletion status in the DB of the given state groups + clause, args = make_in_list_sql_clause( + self.db_pool.engine, column="state_group", iterable=state_groups + ) + + sql = f""" + SELECT state_group, state_epoch FROM ( + SELECT state_group, state_epoch FROM state_groups_pending_deletion + UNION + SELECT state_group, null FROM state_groups_persisting + ) AS s + WHERE {clause} + """ + + txn.execute(sql, args) + + can_delete = set() + for state_group, state_epoch in txn: + if state_epoch is None: + # A null state epoch means that we are currently persisting + # events that reference the state group, so we don't delete + # them. + continue + + if current_state_epoch - state_epoch < self.NUMBER_EPOCHS_BEFORE_DELETION: + # Not enough state epochs have occurred to allow us to delete. + continue + + can_delete.add(state_group) + + return can_delete diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 9944f90015c..b2ef3703c33 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -36,7 +36,10 @@ from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase +from synapse.events.snapshot import ( + UnpersistedEventContext, + UnpersistedEventContextBase, +) from synapse.logging.opentracing import tag_args, trace from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -55,6 +58,7 @@ if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases.state.epochs import StateEpochDataStore logger = logging.getLogger(__name__) @@ -83,8 +87,10 @@ def __init__( database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", + epoch_store: "StateEpochDataStore", ): super().__init__(database, db_conn, hs) + self._epoch_store = epoch_store # Originally the state store used a single DictionaryCache to cache the # event IDs for the state types in a given state group to avoid hammering @@ -467,14 +473,13 @@ def insert_deltas_group_txn( Returns: A list of state groups """ - is_in_db = self.db_pool.simple_select_one_onecol_txn( + + # We need to check that the prev group isn't about to be deleted + is_missing = self._epoch_store._check_state_groups_and_bump_deletion_txn( txn, - table="state_groups", - keyvalues={"id": prev_group}, - retcol="id", - allow_none=True, + {prev_group}, ) - if not is_in_db: + if is_missing: raise Exception( "Trying to persist state with unpersisted prev_group: %r" % (prev_group,) @@ -546,6 +551,7 @@ def insert_deltas_group_txn( for key, state_id in context.state_delta_due_to_event.items() ], ) + return events_and_context return await self.db_pool.runInteraction( @@ -601,14 +607,13 @@ def insert_delta_group_txn( The state group if successfully created, or None if the state needs to be persisted as a full state. """ - is_in_db = self.db_pool.simple_select_one_onecol_txn( + + # We need to check that the prev group isn't about to be deleted + is_missing = self._epoch_store._check_state_groups_and_bump_deletion_txn( txn, - table="state_groups", - keyvalues={"id": prev_group}, - retcol="id", - allow_none=True, + {prev_group}, ) - if not is_in_db: + if is_missing: raise Exception( "Trying to persist state with unpersisted prev_group: %r" % (prev_group,) diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 934e1cccedb..0b54728ea5b 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -19,7 +19,7 @@ # # -SCHEMA_VERSION = 88 # remember to update the list below when updating +SCHEMA_VERSION = 89 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the diff --git a/synapse/storage/schema/state/delta/89/01_state_groups_epochs.sql b/synapse/storage/schema/state/delta/89/01_state_groups_epochs.sql new file mode 100644 index 00000000000..696e2fdf5f4 --- /dev/null +++ b/synapse/storage/schema/state/delta/89/01_state_groups_epochs.sql @@ -0,0 +1,37 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + + +CREATE TABLE IF NOT EXISTS state_epoch ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + state_epoch BIGINT NOT NULL, + updated_ts BIGINT NOT NULL, + CHECK (Lock='X') +); + +INSERT INTO state_epoch (state_epoch, updated_ts) VALUES (0, 0); + +CREATE TABLE IF NOT EXISTS state_groups_pending_deletion ( + state_group BIGINT NOT NULL, + state_epoch BIGINT NOT NULL, + PRIMARY KEY (state_group, state_epoch) +); + +CREATE INDEX state_groups_pending_deletion_epoch ON state_groups_pending_deletion(state_epoch); + + +CREATE TABLE IF NOT EXISTS state_groups_persisting ( + state_group BIGINT NOT NULL, + instance_name TEXT NOT NULL, + PRIMARY KEY (state_group, instance_name) +); diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 5db10fa74c2..213620caae3 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -807,6 +807,7 @@ def test_process_pulled_event_with_rejected_missing_state(self) -> None: OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}" main_store = self.hs.get_datastores().main + epoch_store = self.hs.get_datastores().state_epochs # Create the room. kermit_user_id = self.register_user("kermit", "test") @@ -958,7 +959,7 @@ def test_process_pulled_event_with_rejected_missing_state(self) -> None: bert_member_event.event_id: bert_member_event, rejected_kick_event.event_id: rejected_kick_event, }, - state_res_store=StateResolutionStore(main_store), + state_res_store=StateResolutionStore(main_store, epoch_store), ) ), [bert_member_event.event_id, rejected_kick_event.event_id], @@ -1003,7 +1004,7 @@ def test_process_pulled_event_with_rejected_missing_state(self) -> None: rejected_power_levels_event.event_id, ], event_map={}, - state_res_store=StateResolutionStore(main_store), + state_res_store=StateResolutionStore(main_store, epoch_store), full_conflicted_set=set(), ) ), diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 4cf1a3dc519..5473a8e7698 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -742,7 +742,7 @@ def test_post_room_no_keys(self) -> None: self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(33, channel.resource_usage.db_txn_count) + self.assertEqual(35, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -755,7 +755,7 @@ def test_post_room_initial_state(self) -> None: self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(35, channel.resource_usage.db_txn_count) + self.assertEqual(37, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id diff --git a/tests/test_state.py b/tests/test_state.py index 311a5906935..ec6f84b850e 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -31,7 +31,7 @@ Tuple, cast, ) -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.internet import defer @@ -221,7 +221,16 @@ def walk(self) -> Iterator[EventBase]: class StateTestCase(unittest.TestCase): def setUp(self) -> None: self.dummy_store = _DummyStore() - storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store) + + # Add a dummy epoch store that always retruns that we have all the + # necessary state groups. + dummy_epoch_store = AsyncMock() + dummy_epoch_store.check_state_groups_and_bump_deletion.return_value = [] + + storage_controllers = Mock( + main=self.dummy_store, + state=self.dummy_store, + ) hs = Mock( spec_set=[ "config", @@ -241,7 +250,10 @@ def setUp(self) -> None: ) clock = cast(Clock, MockClock()) hs.config = default_config("tesths", True) - hs.get_datastores.return_value = Mock(main=self.dummy_store) + hs.get_datastores.return_value = Mock( + main=self.dummy_store, + state_epochs=dummy_epoch_store, + ) hs.get_state_handler.return_value = None hs.get_clock.return_value = clock hs.get_macaroon_generator.return_value = MacaroonGenerator(