Skip to content

Commit

Permalink
Skip waiting for full state if a StateFilter does not require it (matrix-org#12498)
Browse files Browse the repository at this point in the history

If `StateFilter` specifies a state set which we will have regardless of
state-syncing, then we may as well return it immediately.
  • Loading branch information
richvdh authored May 18, 2022
1 parent 0fce474 commit d38c73e
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 4 deletions.
1 change: 1 addition & 0 deletions changelog.d/12498.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Preparation for faster-room-join work: return subsets of room state which we already have, immediately.
63 changes: 59 additions & 4 deletions synapse/storage/state.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,6 +16,7 @@
from typing import (
TYPE_CHECKING,
Awaitable,
Callable,
Collection,
Dict,
Iterable,
Expand Down Expand Up @@ -532,6 +534,44 @@ def approx_difference(self, other: "StateFilter") -> "StateFilter":
new_all, new_excludes, new_wildcards, new_concrete_keys
)

def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool:
    """Decide whether this filter can only be answered once full state is known.

    A filter that is completely satisfied by partial state does not need to
    block on await_full_state before its result can be returned.

    Args:
        is_mine_id: a callable which confirms whether a given state_key is the
            mxid of a local user

    Returns:
        True if the caller must wait for full state; False if the state this
        filter selects can be returned straight away.
    """

    # TODO(faster_joins): it's not entirely clear that this is safe. There may be
    #   circumstances in which we hand back a piece of state which, once the state
    #   has been resynced, turns out to be invalid. For example: if the sender of
    #   a piece of state wasn't actually in the room, that state clearly shouldn't
    #   have been returned.
    #   We should at least add some tests around this to see what happens.

    if EventTypes.Member in self.types:
        wanted_members = self.types[EventTypes.Member]

        # None means we are looking for *all* membership events, so we have to
        # wait.
        if wanted_members is None:
            return True

        # Specific memberships were requested: we only have to wait if at least
        # one of them is for a remote user. any() stops at the first such key.
        return any(not is_mine_id(user_id) for user_id in wanted_members)

    # Membership events were not requested explicitly, so the answer is the
    # value of 'include_others'.
    return self.include_others


_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True)
_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
Expand All @@ -544,6 +584,7 @@ class StateGroupStorage:
"""High level interface to fetching state for event."""

def __init__(self, hs: "HomeServer", stores: "Databases"):
self._is_mine_id = hs.is_mine_id
self.stores = stores
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)

Expand Down Expand Up @@ -675,7 +716,13 @@ async def get_state_for_events(
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
event_to_groups = await self.get_state_group_for_events(event_ids)
await_full_state = True
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
await_full_state = False

event_to_groups = await self.get_state_group_for_events(
event_ids, await_full_state=await_full_state
)

groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
Expand All @@ -699,7 +746,9 @@ async def get_state_for_events(
return {event: event_to_state[event] for event in event_ids}

async def get_state_ids_for_events(
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
self,
event_ids: Collection[str],
state_filter: Optional[StateFilter] = None,
) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
Expand All @@ -716,7 +765,13 @@ async def get_state_ids_for_events(
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
event_to_groups = await self.get_state_group_for_events(event_ids)
await_full_state = True
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
await_full_state = False

event_to_groups = await self.get_state_group_for_events(
event_ids, await_full_state=await_full_state
)

groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
Expand Down Expand Up @@ -802,7 +857,7 @@ async def get_state_group_for_events(
Args:
event_ids: events to get state groups for
await_full_state: if true, will block if we do not yet have complete
state at this event.
state at these events.
"""
if await_full_state:
await self._partial_state_events_tracker.await_full_state(event_ids)
Expand Down

0 comments on commit d38c73e

Please sign in to comment.