Skip to content

Commit

Permalink
Convert some of the federation handler methods to async/await. (matri…
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Apr 24, 2020
1 parent 69a1ac0 commit 33bceb7
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
1 change: 1 addition & 0 deletions changelog.d/7338.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert some federation handler code to async/await.
49 changes: 24 additions & 25 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
ours = await self.state_store.get_state_groups_ids(room_id, seen)

# state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list(ours.values()) # type: list[StateMap[str]]
state_maps = list(ours.values()) # type: List[StateMap[str]]

# we don't need this any more, let's delete it.
del ours
Expand Down Expand Up @@ -1694,16 +1694,15 @@ async def on_send_leave_request(self, origin, pdu):

return None

@defer.inlineCallbacks
def get_state_for_pdu(self, room_id, event_id):
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
"""Returns the state at the event. i.e. not including said event.
"""

event = yield self.store.get_event(
event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)

state_groups = yield self.state_store.get_state_groups(room_id, [event_id])
state_groups = await self.state_store.get_state_groups(room_id, [event_id])

if state_groups:
_, state = list(iteritems(state_groups)).pop()
Expand All @@ -1714,7 +1713,7 @@ def get_state_for_pdu(self, room_id, event_id):
if "replaces_state" in event.unsigned:
prev_id = event.unsigned["replaces_state"]
if prev_id != event.event_id:
prev_event = yield self.store.get_event(prev_id)
prev_event = await self.store.get_event(prev_id)
results[(event.type, event.state_key)] = prev_event
else:
del results[(event.type, event.state_key)]
Expand All @@ -1724,15 +1723,14 @@ def get_state_for_pdu(self, room_id, event_id):
else:
return []

@defer.inlineCallbacks
def get_state_ids_for_pdu(self, room_id, event_id):
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event.
"""
event = yield self.store.get_event(
event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)

state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id])
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])

if state_groups:
_, state = list(state_groups.items()).pop()
Expand All @@ -1751,49 +1749,50 @@ def get_state_ids_for_pdu(self, room_id, event_id):
else:
return []

@defer.inlineCallbacks
@log_function
def on_backfill_request(self, origin, room_id, pdu_list, limit):
in_room = yield self.auth.check_host_in_room(room_id, origin)
async def on_backfill_request(
self, origin: str, room_id: str, pdu_list: List[str], limit: int
) -> List[EventBase]:
in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")

# Synapse asks for 100 events per backfill request. Do not allow more.
limit = min(limit, 100)

events = yield self.store.get_backfill_events(room_id, pdu_list, limit)
events = await self.store.get_backfill_events(room_id, pdu_list, limit)

events = yield filter_events_for_server(self.storage, origin, events)
events = await filter_events_for_server(self.storage, origin, events)

return events

@defer.inlineCallbacks
@log_function
def get_persisted_pdu(self, origin, event_id):
async def get_persisted_pdu(
self, origin: str, event_id: str
) -> Optional[EventBase]:
"""Get an event from the database for the given server.
Args:
origin [str]: hostname of server which is requesting the event; we
origin: hostname of server which is requesting the event; we
will check that the server is allowed to see it.
event_id [str]: id of the event being requested
event_id: id of the event being requested
Returns:
Deferred[EventBase|None]: None if we know nothing about the event;
otherwise the (possibly-redacted) event.
None if we know nothing about the event; otherwise the (possibly-redacted) event.
Raises:
AuthError if the server is not currently in the room
"""
event = yield self.store.get_event(
event = await self.store.get_event(
event_id, allow_none=True, allow_rejected=True
)

if event:
in_room = yield self.auth.check_host_in_room(event.room_id, origin)
in_room = await self.auth.check_host_in_room(event.room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")

events = yield filter_events_for_server(self.storage, origin, [event])
events = await filter_events_for_server(self.storage, origin, [event])
event = events[0]
return event
else:
Expand Down Expand Up @@ -2397,7 +2396,7 @@ async def _update_context_for_auth_events(
"""
# exclude the state key of the new event from the current_state in the context.
if event.is_state():
event_key = (event.type, event.state_key)
event_key = (event.type, event.state_key) # type: Optional[Tuple[str, str]]
else:
event_key = None
state_updates = {
Expand Down

0 comments on commit 33bceb7

Please sign in to comment.