Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Factor out common code for persisting fetched auth events #10896

Merged
1 change: 1 addition & 0 deletions changelog.d/10896.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Clean up some of the federation event authentication code for clarity.
2 changes: 0 additions & 2 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,8 +501,6 @@ async def get_event_auth(
destination, auth_chain, outlier=True, room_version=room_version
)

signed_auth.sort(key=lambda e: e.depth)

return signed_auth

def _is_unknown_endpoint(
Expand Down
103 changes: 48 additions & 55 deletions synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,7 +1107,7 @@ async def _get_events_and_persist(

room_version = await self._store.get_room_version(room_id)

event_map: Dict[str, EventBase] = {}
events: List[EventBase] = []

async def get_event(event_id: str) -> None:
with nested_logging_context(event_id):
Expand All @@ -1125,8 +1125,7 @@ async def get_event(event_id: str) -> None:
event_id,
)
return

event_map[event.event_id] = event
events.append(event)

except Exception as e:
logger.warning(
Expand All @@ -1137,11 +1136,29 @@ async def get_event(event_id: str) -> None:
)

await concurrently_execute(get_event, event_ids, 5)
logger.info("Fetched %i events of %i requested", len(event_map), len(event_ids))
logger.info("Fetched %i events of %i requested", len(events), len(event_ids))
await self._auth_and_persist_fetched_events(destination, room_id, events)

async def _auth_and_persist_fetched_events(
self, origin: str, room_id: str, events: Iterable[EventBase]
) -> None:
"""Persist the events fetched by _get_events_and_persist or _get_remote_auth_chain_for_event

The events to be persisted must be outliers.

We first sort the events to make sure that we process each event's auth_events
before the event itself, and then auth and persist them.

Notifies about the events where appropriate.

Params:
origin: where the events came from
room_id: the room that the events are meant to be in (though this has
not yet been checked)
events: the events that have been fetched
"""
event_map = {event.event_id: event for event in events}

# we now need to auth the events in an order which ensures that each event's
# auth_events are authed before the event itself.
#
# XXX: it might be possible to kick this process off in parallel with fetching
# the events.
while event_map:
Expand All @@ -1168,30 +1185,26 @@ async def get_event(event_id: str) -> None:
"Persisting %i of %i remaining events", len(roots), len(event_map)
)

await self._auth_and_persist_fetched_events(destination, room_id, roots)
await self._auth_and_persist_fetched_events_inner(origin, room_id, roots)
richvdh marked this conversation as resolved.
Show resolved Hide resolved

for ev in roots:
del event_map[ev.event_id]

async def _auth_and_persist_fetched_events(
async def _auth_and_persist_fetched_events_inner(
self, origin: str, room_id: str, fetched_events: Collection[EventBase]
) -> None:
"""Persist the events fetched by _get_events_and_persist.

The events should not depend on one another, e.g. this should be used to persist
a bunch of outliers, but not a chunk of individual events that depend
on each other for state calculations.
"""Helper for _auth_and_persist_fetched_events

We also assume that all of the auth events for all of the events have already
been persisted.
Persists a batch of events where we have (theoretically) already persisted all
of their auth events.

Notifies about the events where appropriate.

Params:
origin: where the events came from
room_id: the room that the events are meant to be in (though this has
not yet been checked)
event_id: map from event_id -> event for the fetched events
fetched_events: the events to persist
"""
# get all the auth events for all the events in this batch. By now, they should
# have been persisted.
Expand Down Expand Up @@ -1605,53 +1618,33 @@ async def _get_remote_auth_chain_for_event(
event_id: the event for which we are lacking auth events
"""
try:
remote_auth_chain = await self._federation_client.get_event_auth(
destination, room_id, event_id
)
remote_event_map = {
e.event_id: e
for e in await self._federation_client.get_event_auth(
destination, room_id, event_id
)
}
except RequestSendFailed as e1:
# The other side isn't around or doesn't implement the
# endpoint, so let's just bail out.
logger.info("Failed to get event auth from remote: %s", e1)
return

seen_remotes = await self._store.have_seen_events(
room_id, [e.event_id for e in remote_auth_chain]
)
logger.info("/event_auth returned %i events", len(remote_event_map))

for auth_event in remote_auth_chain:
if auth_event.event_id in seen_remotes:
continue
# `event` may be returned, but we should not yet process it.
remote_event_map.pop(event_id, None)

if auth_event.event_id == event_id:
continue
# nor should we reprocess any events we have already seen.
seen_remotes = await self._store.have_seen_events(
room_id, remote_event_map.keys()
)
for s in seen_remotes:
remote_event_map.pop(s, None)

try:
auth_ids = auth_event.auth_event_ids()
auth = {
(e.type, e.state_key): e
for e in remote_auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create
}
auth_event.internal_metadata.outlier = True

logger.debug(
"_check_event_auth %s missing_auth: %s",
event_id,
auth_event.event_id,
)
missing_auth_event_context = EventContext.for_outlier()
missing_auth_event_context = await self._check_event_auth(
destination,
auth_event,
missing_auth_event_context,
claimed_auth_event_map=auth,
)
await self.persist_events_and_notify(
room_id,
[(auth_event, missing_auth_event_context)],
)
except AuthError:
pass
await self._auth_and_persist_fetched_events(
destination, room_id, remote_event_map.values()
)

async def _update_context_for_auth_events(
self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
Expand Down
7 changes: 6 additions & 1 deletion tests/handlers/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,12 @@ def test_backfill_floating_outlier_membership_auth(self):
async def get_event_auth(
destination: str, room_id: str, event_id: str
) -> List[EventBase]:
return auth_events
return [
event_from_pdu_json(
ae.get_pdu_json(), room_version=room_version, outlier=True
)
for ae in auth_events
]

self.handler.federation_client.get_event_auth = get_event_auth

Expand Down