From cacd4fd7bd40465732fc302a69efa39dcb5eb118 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 8 Nov 2024 16:41:24 +0000 Subject: [PATCH] Fix MSC4222 returning full state (#17915) There was a bug that meant we would return the full state of the room on incremental syncs when using lazy loaded members and there were no entries in the timeline. This was due to trying to use `state_filter or state_filter.all()` as a short hand for handling `None` case, however `state_filter` implements `__bool__` so if the state filter was empty it would be set to full. c.f. MSC4222 and #17888 --- changelog.d/17915.bugfix | 1 + synapse/handlers/message.py | 4 +- synapse/handlers/sync.py | 2 +- synapse/storage/controllers/state.py | 56 ++++++++++++------- synapse/storage/databases/main/state.py | 8 +-- synapse/storage/databases/state/bg_updates.py | 4 +- synapse/storage/databases/state/store.py | 3 +- synapse/types/state.py | 12 +++- tests/handlers/test_sync.py | 32 +++++++++++ 9 files changed, 91 insertions(+), 31 deletions(-) create mode 100644 changelog.d/17915.bugfix diff --git a/changelog.d/17915.bugfix b/changelog.d/17915.bugfix new file mode 100644 index 00000000000..a5d82e486db --- /dev/null +++ b/changelog.d/17915.bugfix @@ -0,0 +1 @@ +Fix experimental support for [MSC4222](https://github.com/matrix-org/matrix-spec-proposals/pull/4222) where we would return the full state on incremental syncs when using lazy loaded members and there were no new events in the timeline. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 204965afeec..df3010ecf68 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -196,7 +196,9 @@ async def get_state_events( AuthError (403) if the user doesn't have permission to view members of this room. """ - state_filter = state_filter or StateFilter.all() + if state_filter is None: + state_filter = StateFilter.all() + user_id = requester.user.to_string() if at_token: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index df9a088063f..350c3fa09a1 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1520,7 +1520,7 @@ async def _compute_state_delta_for_incremental_sync( if sync_config.use_state_after: delta_state_ids: MutableStateMap[str] = {} - if members_to_fetch is not None: + if members_to_fetch: # We're lazy-loading, so the client might need some more member # events to understand the events in this timeline. So we always # fish out all the member events corresponding to the timeline diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index b50eb8868ec..f28f5d7e039 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -234,8 +234,11 @@ async def get_state_for_events( RuntimeError if we don't have a state group for one or more of the events (ie they are outliers or unknown) """ + if state_filter is None: + state_filter = StateFilter.all() + await_full_state = True - if state_filter and not state_filter.must_await_full_state(self._is_mine_id): + if not state_filter.must_await_full_state(self._is_mine_id): await_full_state = False event_to_groups = await self.get_state_group_for_events( @@ -244,7 +247,7 @@ async def get_state_for_events( groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups( - groups, state_filter or StateFilter.all() + groups, state_filter ) state_event_map = await self.stores.main.get_events( @@ -292,10 +295,11 @@ async def get_state_ids_for_events( RuntimeError if we don't have a state group for one or more of the events (ie they are outliers or unknown) """ - if ( - await_full_state - and state_filter - and not state_filter.must_await_full_state(self._is_mine_id) + if state_filter is None: + state_filter = StateFilter.all() + + if await_full_state and not state_filter.must_await_full_state( + self._is_mine_id ): # Full state is not required if the state filter is restrictive enough. await_full_state = False @@ -306,7 +310,7 @@ async def get_state_ids_for_events( groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups( - groups, state_filter or StateFilter.all() + groups, state_filter ) event_to_state = { @@ -335,9 +339,10 @@ async def get_state_for_event( RuntimeError if we don't have a state group for the event (ie it is an outlier or is unknown) """ - state_map = await self.get_state_for_events( - [event_id], state_filter or StateFilter.all() - ) + if state_filter is None: + state_filter = StateFilter.all() + + state_map = await self.get_state_for_events([event_id], state_filter) return state_map[event_id] @trace @@ -365,9 +370,12 @@ async def get_state_ids_for_event( RuntimeError if we don't have a state group for the event (ie it is an outlier or is unknown) """ + if state_filter is None: + state_filter = StateFilter.all() + state_map = await self.get_state_ids_for_events( [event_id], - state_filter or StateFilter.all(), + state_filter, await_full_state=await_full_state, ) return state_map[event_id] @@ -388,9 +396,12 @@ async def get_state_after_event( at the event and `state_filter` is not satisfied by partial state. Defaults to `True`. """ + if state_filter is None: + state_filter = StateFilter.all() + state_ids = await self.get_state_ids_for_event( event_id, - state_filter=state_filter or StateFilter.all(), + state_filter=state_filter, await_full_state=await_full_state, ) @@ -426,6 +437,9 @@ async def get_state_ids_at( at the last event in the room before `stream_position` and `state_filter` is not satisfied by partial state. Defaults to `True`. """ + if state_filter is None: + state_filter = StateFilter.all() + # FIXME: This gets the state at the latest event before the stream ordering, # which might not be the same as the "current state" of the room at the time # of the stream token if there were multiple forward extremities at the time. @@ -442,7 +456,7 @@ async def get_state_ids_at( if last_event_id: state = await self.get_state_after_event( last_event_id, - state_filter=state_filter or StateFilter.all(), + state_filter=state_filter, await_full_state=await_full_state, ) @@ -500,9 +514,10 @@ async def get_state_for_groups( Returns: Dict of state group to state map. """ - return await self.stores.state._get_state_for_groups( - groups, state_filter or StateFilter.all() - ) + if state_filter is None: + state_filter = StateFilter.all() + + return await self.stores.state._get_state_for_groups(groups, state_filter) @trace @tag_args @@ -583,12 +598,13 @@ async def get_current_state_ids( Returns: The current state of the room. """ - if await_full_state and ( - not state_filter or state_filter.must_await_full_state(self._is_mine_id) - ): + if state_filter is None: + state_filter = StateFilter.all() + + if await_full_state and state_filter.must_await_full_state(self._is_mine_id): await self._partial_state_room_tracker.await_full_state(room_id) - if state_filter and not state_filter.is_full(): + if state_filter is not None and not state_filter.is_full(): return await self.stores.main.get_partial_filtered_current_state_ids( room_id, state_filter ) diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 42b3638e1c8..788f7d1e325 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -572,10 +572,10 @@ async def get_partial_filtered_current_state_ids( Returns: Map from type/state_key to event ID. """ + if state_filter is None: + state_filter = StateFilter.all() - where_clause, where_args = ( - state_filter or StateFilter.all() - ).make_sql_filter_clause() + where_clause, where_args = (state_filter).make_sql_filter_clause() if not where_clause: # We delegate to the cached version @@ -584,7 +584,7 @@ async def get_partial_filtered_current_state_ids( def _get_filtered_current_state_ids_txn( txn: LoggingTransaction, ) -> StateMap[str]: - results = StateMapWrapper(state_filter=state_filter or StateFilter.all()) + results = StateMapWrapper(state_filter=state_filter) sql = """ SELECT type, state_key, event_id FROM current_state_events diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index ea7d8199a7d..f7824cba0f2 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -112,8 +112,8 @@ def _get_state_groups_from_groups_txn( Returns: Map from state_group to a StateMap at that point. """ - - state_filter = state_filter or StateFilter.all() + if state_filter is None: + state_filter = StateFilter.all() results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups} diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 875dba33496..f7a59c8992d 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -284,7 +284,8 @@ async def _get_state_for_groups( Returns: Dict of state group to state map. """ - state_filter = state_filter or StateFilter.all() + if state_filter is None: + state_filter = StateFilter.all() member_filter, non_member_filter = state_filter.get_member_split() diff --git a/synapse/types/state.py b/synapse/types/state.py index 67d1c3fe972..e641215f184 100644 --- a/synapse/types/state.py +++ b/synapse/types/state.py @@ -68,15 +68,23 @@ class StateFilter: include_others: bool = False def __attrs_post_init__(self) -> None: - # If `include_others` is set we canonicalise the filter by removing - # wildcards from the types dictionary if self.include_others: + # If `include_others` is set we canonicalise the filter by removing + # wildcards from the types dictionary + # this is needed to work around the fact that StateFilter is frozen object.__setattr__( self, "types", immutabledict({k: v for k, v in self.types.items() if v is not None}), ) + else: + # Otherwise we remove entries where the value is the empty set. + object.__setattr__( + self, + "types", + immutabledict({k: v for k, v in self.types.items() if v is None or v}), + ) @staticmethod def all() -> "StateFilter": diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 1960d2f0e10..9dd0e98971b 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -1262,3 +1262,35 @@ def test_incremental_sync_multiple_deltas(self) -> None: ) ) self.assertEqual(state[("m.test_event", "")], second_state["event_id"]) + + def test_incremental_sync_lazy_loaded_no_timeline(self) -> None: + """Test that lazy-loading with an empty timeline doesn't return the full + state. + + There was a bug where an empty state filter would cause the DB to return + the full state, rather than an empty set. + """ + user = self.register_user("user", "password") + tok = self.login("user", "password") + + # Create a room as the user and set some custom state. + joined_room = self.helper.create_room_as(user, tok=tok) + + since_token = self.hs.get_event_sources().get_current_token() + end_stream_token = self.hs.get_event_sources().get_current_token() + + state = self.get_success( + self.sync_handler._compute_state_delta_for_incremental_sync( + room_id=joined_room, + sync_config=generate_sync_config(user, use_state_after=True), + batch=TimelineBatch( + prev_batch=end_stream_token, events=[], limited=True + ), + since_token=since_token, + end_token=end_stream_token, + members_to_fetch=set(), + timeline_state={}, + ) + ) + + self.assertEqual(state, {})