Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Convert events worker database to async/await. (#8071)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Aug 18, 2020
1 parent acfb7c3 commit f40645e
Show file tree
Hide file tree
Showing 12 changed files with 106 additions and 97 deletions.
1 change: 1 addition & 0 deletions changelog.d/8071.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
2 changes: 1 addition & 1 deletion synapse/event_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def check(
Args:
room_version_obj: the version of the room
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
auth_events: the existing room state.
Raises:
AuthError if the checks fail
Expand Down
16 changes: 5 additions & 11 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1777,9 +1777,7 @@ async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase
"""Returns the state at the event. i.e. not including said event.
"""

event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
event = await self.store.get_event(event_id, check_room_id=room_id)

state_groups = await self.state_store.get_state_groups(room_id, [event_id])

Expand All @@ -1805,9 +1803,7 @@ async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event.
"""
event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
event = await self.store.get_event(event_id, check_room_id=room_id)

state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])

Expand Down Expand Up @@ -2155,9 +2151,9 @@ async def _check_for_soft_fail(
auth_types = auth_types_for_event(event)
current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]

current_auth_events = await self.store.get_events(current_state_ids)
auth_events_map = await self.store.get_events(current_state_ids)
current_auth_events = {
(e.type, e.state_key): e for e in current_auth_events.values()
(e.type, e.state_key): e for e in auth_events_map.values()
}

try:
Expand All @@ -2173,9 +2169,7 @@ async def on_query_auth(
if not in_room:
raise AuthError(403, "Host not in room.")

event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
event = await self.store.get_event(event_id, check_room_id=room_id)

# Just go through and process each event in `remote_auth_chain`. We
# don't want to fall into the trap of `missing` being wrong.
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,7 @@ async def persist_and_notify_client_event(
allow_none=True,
)

is_admin_redaction = (
is_admin_redaction = bool(
original_event and event.sender != original_event.sender
)

Expand Down Expand Up @@ -1080,8 +1080,8 @@ def is_inviter_member_event(e):
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
auth_events_map = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()}

room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/room_member.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ async def _can_guest_join(

guest_access = await self.store.get_event(guest_access_id)

return (
return bool(
guest_access
and guest_access.content
and "guest_access" in guest_access.content
Expand Down
2 changes: 1 addition & 1 deletion synapse/spam_checker_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,5 @@ def get_state_events_in_room(self, room_id: str, types: tuple) -> defer.Deferred
state_ids = yield self._store.get_filtered_current_state_ids(
room_id=room_id, state_filter=StateFilter.from_types(types)
)
state = yield self._store.get_events(state_ids.values())
state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
return state.values()
2 changes: 1 addition & 1 deletion synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ def get_events(self, event_ids, allow_rejected=False):
allow_rejected (bool): If True return rejected events.
Returns:
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event.
"""

return self.store.get_events(
Expand Down
30 changes: 14 additions & 16 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
def get_auth_chain(self, event_ids, include_given=False):
async def get_auth_chain(self, event_ids, include_given=False):
"""Get auth events for given event_ids. The events *must* be state events.
Args:
Expand All @@ -40,9 +40,10 @@ def get_auth_chain(self, event_ids, include_given=False):
Returns:
list of events
"""
return self.get_auth_chain_ids(
event_ids = await self.get_auth_chain_ids(
event_ids, include_given=include_given
).addCallback(self.get_events_as_list)
)
return await self.get_events_as_list(event_ids)

def get_auth_chain_ids(
self,
Expand Down Expand Up @@ -459,7 +460,7 @@ def get_forward_extremeties_for_room_txn(txn):
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)

def get_backfill_events(self, room_id, event_list, limit):
async def get_backfill_events(self, room_id, event_list, limit):
"""Get a list of Events for a given topic that occurred before (and
including) the events in event_list. Return a list of max size `limit`
Expand All @@ -469,17 +470,15 @@ def get_backfill_events(self, room_id, event_list, limit):
event_list (list)
limit (int)
"""
return (
self.db_pool.runInteraction(
"get_backfill_events",
self._get_backfill_events,
room_id,
event_list,
limit,
)
.addCallback(self.get_events_as_list)
.addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
event_ids = await self.db_pool.runInteraction(
"get_backfill_events",
self._get_backfill_events,
room_id,
event_list,
limit,
)
events = await self.get_events_as_list(event_ids)
return sorted(events, key=lambda e: -e.depth)

def _get_backfill_events(self, txn, room_id, event_list, limit):
logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
Expand Down Expand Up @@ -540,8 +539,7 @@ async def get_missing_events(self, room_id, earliest_events, latest_events, limi
latest_events,
limit,
)
events = await self.get_events_as_list(ids)
return events
return await self.get_events_as_list(ids)

def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):

Expand Down
Loading

0 comments on commit f40645e

Please sign in to comment.