Skip to content

Commit

Permalink
Refactor resolve_state_groups_for_events to not pull out full state…
Browse files Browse the repository at this point in the history
… when no state resolution happens. (matrix-org#12775)
  • Loading branch information
H-Shay authored May 18, 2022
1 parent 3d8839c commit 19d79b6
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 23 deletions.
1 change: 1 addition & 0 deletions changelog.d/12775.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor `resolve_state_groups_for_events` to not pull out full state when no state resolution happens.
35 changes: 19 additions & 16 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,6 @@ async def compute_event_context(
#
# first of all, figure out the state before the event
#

if old_state:
# if we're given the state before the event, then we use that
state_ids_before_event: StateMap[str] = {
Expand Down Expand Up @@ -419,33 +418,37 @@ async def resolve_state_groups_for_events(
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)

# map from state group id to the state in that state group (where
# 'state' is a map from state key to event id)
# dict[int, dict[(str, str), str]]
state_groups_ids = await self.state_store.get_state_groups_ids(
room_id, event_ids
)

if len(state_groups_ids) == 0:
return _StateCacheEntry(state={}, state_group=None)
elif len(state_groups_ids) == 1:
name, state_list = list(state_groups_ids.items()).pop()
state_groups = await self.state_store.get_state_group_for_events(event_ids)

prev_group, delta_ids = await self.state_store.get_state_group_delta(name)
state_group_ids = state_groups.values()

# check if each event has same state group id, if so there's no state to resolve
state_group_ids_set = set(state_group_ids)
if len(state_group_ids_set) == 1:
(state_group_id,) = state_group_ids_set
state = await self.state_store.get_state_for_groups(state_group_ids_set)
prev_group, delta_ids = await self.state_store.get_state_group_delta(
state_group_id
)
return _StateCacheEntry(
state=state_list,
state_group=name,
state=state[state_group_id],
state_group=state_group_id,
prev_group=prev_group,
delta_ids=delta_ids,
)
elif len(state_group_ids_set) == 0:
return _StateCacheEntry(state={}, state_group=None)

room_version = await self.store.get_room_version_id(room_id)

state_to_resolve = await self.state_store.get_state_for_groups(
state_group_ids_set
)

result = await self._state_resolution_handler.resolve_state_groups(
room_id,
room_version,
state_groups_ids,
state_to_resolve,
None,
state_res_store=StateResolutionStore(self.store),
)
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/state/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def _get_state_for_group_using_cache(
group: int,
state_filter: StateFilter,
) -> Tuple[MutableStateMap[str], bool]:
"""Checks if group is in cache. See `_get_state_for_groups`
"""Checks if group is in cache. See `get_state_for_groups`
Args:
cache: the state group cache to use
Expand Down
12 changes: 6 additions & 6 deletions synapse/storage/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ async def get_state_groups_ids(
if not event_ids:
return {}

event_to_groups = await self._get_state_group_for_events(event_ids)
event_to_groups = await self.get_state_group_for_events(event_ids)

groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(groups)
Expand All @@ -602,7 +602,7 @@ async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
Returns:
Resolves to a map of (type, state_key) -> event_id
"""
group_to_state = await self._get_state_for_groups((state_group,))
group_to_state = await self.get_state_for_groups((state_group,))

return group_to_state[state_group]

Expand Down Expand Up @@ -675,7 +675,7 @@ async def get_state_for_events(
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
event_to_groups = await self._get_state_group_for_events(event_ids)
event_to_groups = await self.get_state_group_for_events(event_ids)

groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
Expand Down Expand Up @@ -716,7 +716,7 @@ async def get_state_ids_for_events(
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
event_to_groups = await self._get_state_group_for_events(event_ids)
event_to_groups = await self.get_state_group_for_events(event_ids)

groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
Expand Down Expand Up @@ -774,7 +774,7 @@ async def get_state_ids_for_event(
)
return state_map[event_id]

def _get_state_for_groups(
def get_state_for_groups(
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Awaitable[Dict[int, MutableStateMap[str]]]:
"""Gets the state at each of a list of state groups, optionally
Expand All @@ -792,7 +792,7 @@ def _get_state_for_groups(
groups, state_filter or StateFilter.all()
)

async def _get_state_group_for_events(
async def get_state_group_for_events(
self,
event_ids: Collection[str],
await_full_state: bool = True,
Expand Down
13 changes: 13 additions & 0 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,19 @@ def register_event_id_state_group(self, event_id, state_group):
async def get_room_version_id(self, room_id):
return RoomVersions.V1.identifier

async def get_state_group_for_events(self, event_ids):
res = {}
for event in event_ids:
res[event] = self._event_to_state_group[event]
return res

async def get_state_for_groups(self, groups):
res = {}
for group in groups:
state = self._group_to_state[group]
res[group] = state
return res


class DictObj(dict):
def __init__(self, **kwargs):
Expand Down

0 comments on commit 19d79b6

Please sign in to comment.