Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Correctly handle unpersisted events when calculating auth chain difference. #8827

Merged
merged 8 commits into from
Dec 2, 2020
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions synapse/state/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from synapse.api.errors import AuthError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
from synapse.types import Collection, MutableStateMap, StateMap
from synapse.util import Clock

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -257,16 +257,19 @@ async def _get_auth_chain_difference(
# previously. This is not the case for any events in the `event_map`, and so
# we need to manually handle those events.
#
# We do this by calculating the auth chain difference based on events in
# `event_map` and adding that to the result from calling
# `get_auth_chain_difference` with state sets where we've replaced
# references to events in `event_map` with their auth events (recursively).
# We do this by:
# 1. calculating the auth chain difference for the state sets based on the
# events in `event_map` alone
# 2. replacing any events in the state_sets that are also in `event_map`
# with their auth events (recursively), and then calling
# `store.get_auth_chain_difference` as normal
# 3. adding the results of 1 and 2 together.

# Map from each event ID in `event_map` to a set holding that event's own ID,
# its auth event IDs, and the auth event IDs of any of those auth events that
# also appear in `event_map`. This is the intersection of the event's auth
# chain with the events in the `event_map` *plus* their auth event IDs.
events_to_auth_chain = {}
events_to_auth_chain = {} # type: Dict[str, Set[str]]
for event in event_map.values():
chain = {event.event_id}
events_to_auth_chain[event.event_id] = chain
Expand All @@ -284,10 +287,15 @@ async def _get_auth_chain_difference(
#
# Note: If the `event_map` is empty (which is the common case), we can do a
# much simpler calculation.
difference_from_event_map = set()
if event_map:
state_sets_ids = []
unpersisted_set_ids = []
# The list of state sets to pass to the store. This is the same as
# `state_sets` except with unpersisted events stripped out and replaced
# with persisted events in their auth chain.
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
state_sets_ids = [] # type: List[Set[str]]

# List of sets of the unpersisted event IDs reachable (by their auth
# chain) from each state set.
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
unpersisted_set_ids = [] # type: List[Set[str]]

for state_set in state_sets:
set_ids = set() # type: Set[str]
Expand All @@ -300,7 +308,7 @@ async def _get_auth_chain_difference(
event_chain = events_to_auth_chain.get(event_id)
if event_chain is not None:
# We have an event in `event_map`. We add all the auth
# events that it reference's (that aren't also in `event_map`).
# events that it references (that aren't also in `event_map`).
set_ids.update(e for e in event_chain if e not in event_map)

# We also add the full chain of unpersisted event IDs
Expand All @@ -311,13 +319,14 @@ async def _get_auth_chain_difference(
set_ids.add(event_id)

# The auth chain difference of the unpersisted events of the state sets
# is calcualted by taking the difference between the union and
# is calculated by taking the difference between the union and
# intersection.
union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])

difference_from_event_map = union - intersection
difference_from_event_map = union - intersection # type: Collection[str]
else:
difference_from_event_map = ()
state_sets_ids = [set(state_set.values()) for state_set in state_sets]

difference = await state_res_store.get_auth_chain_difference(state_sets_ids)
Expand Down