Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Correctly handle unpersisted events when calculating auth chain difference. #8827

Merged
merged 8 commits into from
Dec 2, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8827.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug where we might not correctly calculate the current state for rooms with multiple extremities.
73 changes: 70 additions & 3 deletions synapse/state/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,76 @@ async def _get_auth_chain_difference(
Set of event IDs
"""

difference = await state_res_store.get_auth_chain_difference(
[set(state_set.values()) for state_set in state_sets]
)
# The `StateResolutionStore.get_auth_chain_difference` function assumes that
# all events passed to it (and their auth chains) have been persisted
# previously. This is not the case for any events in the `event_map`, and so
# we need to manually handle those events.
#
# We do this by calculating the auth chain difference based on events in
# `event_map` and adding that to the result from calling
# `get_auth_chain_difference` with state sets where we've replaced
# references to events in `event_map` with their auth events (recursively).
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved

# Map from event ID in `event_map` to their auth event IDs, and their auth
# event IDs if they appear in the `event_map`. This is the intersection of
# the event's auth chain with the events in the `event_map` *plus* their
# auth event IDs.
events_to_auth_chain = {}
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
for event in event_map.values():
chain = {event.event_id}
events_to_auth_chain[event.event_id] = chain

to_search = [event]
while to_search:
for auth_id in to_search.pop().auth_event_ids():
chain.add(auth_id)
auth_event = event_map.get(auth_id)
if auth_event:
to_search.append(auth_event)

# We now a) calculate the auth chain difference for the unpersisted events
# and b) work out the state sets to pass to the store.
#
# Note: If the `event_map` is empty (which is the common case), we can do a
# much simpler calculation.
Comment on lines +288 to +289
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you may as well stick the events_to_auth_chain logic inside the if event_map condition too.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found it harder to follow with all the extra indents TBH. I don't think it has much of a performance impact since it's just a case of creating an empty dict and iterating over an empty dict.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah I'm not worried about the performance impact. It just feels odd to me that we've decided we're going to optimise part of the algorithm for the "event_map is empty" case, and not the other half - for me that makes it harder to follow.

I don't feel strongly though: up to you.

difference_from_event_map = set()
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
if event_map:
state_sets_ids = []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again, some type annotations might be helpful, along with a comment about what it's going to contain

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added comments, but I struggled a bit expressing what I mean in words.

unpersisted_set_ids = []
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved

for state_set in state_sets:
set_ids = set() # type: Set[str]
state_sets_ids.append(set_ids)

unpersisted_ids = set() # type: Set[str]
unpersisted_set_ids.append(unpersisted_ids)

for event_id in state_set.values():
event_chain = events_to_auth_chain.get(event_id)
if event_chain is not None:
# We have an event in `event_map`. We add all the auth
# events that it references (that aren't also in `event_map`).
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
set_ids.update(e for e in event_chain if e not in event_map)

# We also add the full chain of unpersisted event IDs
# referenced by this state set, so that we can work out the
# auth chain difference of the unpersisted events.
unpersisted_ids.update(e for e in event_chain if e in event_map)
else:
set_ids.add(event_id)

# The auth chain difference of the unpersisted events of the state sets
# is calculated by taking the difference between the union and
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
# intersections.
union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])

difference_from_event_map = union - intersection
else:
state_sets_ids = [set(state_set.values()) for state_set in state_sets]

difference = await state_res_store.get_auth_chain_difference(state_sets_ids)
difference.update(difference_from_event_map)

return difference

Expand Down
128 changes: 127 additions & 1 deletion tests/state/test_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
from synapse.api.room_versions import RoomVersions
from synapse.event_auth import auth_types_for_event
from synapse.events import make_event_from_dict
from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
from synapse.state.v2 import (
_get_auth_chain_difference,
lexicographical_topological_sort,
resolve_events_with_store,
)
from synapse.types import EventID

from tests import unittest
Expand Down Expand Up @@ -587,6 +591,128 @@ def test_event_map_none(self):
self.assert_dict(self.expected_combined_state, state)


class AuthChainDifferenceTestCase(unittest.TestCase):
    """Checks that `_get_auth_chain_difference` copes with unpersisted events.

    In every test the "persisted" events are the ones handed to the
    `TestStateResolutionStore`, while the "unpersisted" events are only
    supplied through the event map argument of `_get_auth_chain_difference`.
    """

    def _member_event(self, event_id, first_arg_ids):
        """Build a membership event sent by ALICE with an empty state key.

        `first_arg_ids` is forwarded as the first argument of
        `FakeEvent.to_event`; the second argument is always an empty list.
        NOTE(review): judging by the diagrams in the tests below, the first
        argument is presumably the list of auth event IDs -- confirm against
        `FakeEvent.to_event`.
        """
        fake = FakeEvent(
            id=event_id,
            sender=ALICE,
            type=EventTypes.Member,
            state_key="",
            content={},
        )
        return fake.to_event(first_arg_ids, [])

    def _difference(self, state_sets, unpersisted, persisted):
        """Run `_get_auth_chain_difference` and return its (already resolved) result.

        Args:
            state_sets: list of state maps (key -> event ID) to compare.
            unpersisted: map of event ID -> event for events missing from the store.
            persisted: map of event ID -> event used to back the test store.
        """
        store = TestStateResolutionStore(persisted)
        pending = _get_auth_chain_difference(state_sets, unpersisted, store)
        return self.successResultOf(defer.ensureDeferred(pending))

    def test_simple(self):
        # A straight chain where only the newest event is unpersisted:
        #
        #       Unpersisted | Persisted
        #                   |
        #                C -|-> B -> A
        ev_a = self._member_event("A", [])
        ev_b = self._member_event("B", [ev_a.event_id])
        ev_c = self._member_event("C", [ev_b.event_id])

        persisted = {ev_a.event_id: ev_a, ev_b.event_id: ev_b}
        unpersisted = {ev_c.event_id: ev_c}

        state_sets = [
            {"a": ev_a.event_id, "b": ev_b.event_id},
            {"c": ev_c.event_id},
        ]

        result = self._difference(state_sets, unpersisted, persisted)

        self.assertEqual(result, {ev_c.event_id})

    def test_multiple_unpersisted_chain(self):
        # A straight chain where the two newest events are unpersisted:
        #
        #       Unpersisted | Persisted
        #                   |
        #           D -> C -|-> B -> A
        ev_a = self._member_event("A", [])
        ev_b = self._member_event("B", [ev_a.event_id])
        ev_c = self._member_event("C", [ev_b.event_id])
        ev_d = self._member_event("D", [ev_c.event_id])

        persisted = {ev_a.event_id: ev_a, ev_b.event_id: ev_b}
        unpersisted = {ev_c.event_id: ev_c, ev_d.event_id: ev_d}

        state_sets = [
            {"a": ev_a.event_id, "b": ev_b.event_id},
            {"c": ev_c.event_id, "d": ev_d.event_id},
        ]

        result = self._difference(state_sets, unpersisted, persisted)

        self.assertEqual(result, {ev_d.event_id, ev_c.event_id})

    def test_unpersisted_events_different_sets(self):
        # Several unpersisted events spread over different branches, with the
        # two state sets each referencing a different branch tip:
        #
        #       Unpersisted | Persisted
        #                   |
        #          D --> C -|-> B -> A
        #          E ----^ -|---^
        #                   |
        ev_a = self._member_event("A", [])
        ev_b = self._member_event("B", [ev_a.event_id])
        ev_c = self._member_event("C", [ev_b.event_id])
        ev_d = self._member_event("D", [ev_c.event_id])
        ev_e = self._member_event("E", [ev_c.event_id, ev_b.event_id])

        persisted = {ev_a.event_id: ev_a, ev_b.event_id: ev_b}
        unpersisted = {
            ev_c.event_id: ev_c,
            ev_d.event_id: ev_d,
            ev_e.event_id: ev_e,
        }

        state_sets = [
            {"a": ev_a.event_id, "b": ev_b.event_id, "e": ev_e.event_id},
            {"c": ev_c.event_id, "d": ev_d.event_id},
        ]

        result = self._difference(state_sets, unpersisted, persisted)

        # C is reachable from both state sets, so only D and E differ.
        self.assertEqual(result, {ev_d.event_id, ev_e.event_id})


def pairwise(iterable):
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
a, b = itertools.tee(iterable)
Expand Down
5 changes: 5 additions & 0 deletions tests/storage/test_event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,11 @@ def insert_event(txn, event_id, stream_ordering):
)
self.assertSetEqual(difference, {"a", "b", "c"})

difference = self.get_success(
self.store.get_auth_chain_difference([{"a", "c"}, {"b", "c"}])
)
self.assertSetEqual(difference, {"a", "b"})

difference = self.get_success(
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
)
Expand Down