Fix bug where sync could get stuck when using workers (#17438)
This happened because we serialized the token incorrectly when the instance map
contained entries from before the minimum token.
erikjohnston authored Jul 15, 2024
1 parent d88ba45 commit df11af1
Showing 4 changed files with 138 additions and 10 deletions.
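
To make the failure mode concrete, here is a minimal standalone sketch (not Synapse's actual implementation, which resolves instance names to integer ids via the database). The old serializer skipped instance-map entries at or below the minimum stream position but emitted the `m...~` framing unconditionally, so a token whose entries were all filtered out came out as the unparseable `m5~`:

    def to_string_buggy(stream: int, instance_map: dict[str, int]) -> str:
        # Entries at or below the minimum position are skipped...
        entries = [f"{name}.{pos}" for name, pos in instance_map.items() if pos > stream]
        # ...but the "m" prefix and "~" separator are emitted regardless,
        # yielding e.g. "m5~" when every entry is filtered out.
        return f"m{stream}~" + "~".join(entries)

    def to_string_fixed(stream: int, instance_map: dict[str, int]) -> str:
        entries = [f"{name}.{pos}" for name, pos in instance_map.items() if pos > stream]
        if entries:
            return f"m{stream}~" + "~".join(entries)
        # Fall back to a plain stream token when the map adds no information.
        return f"s{stream}"

    assert to_string_buggy(5, {"worker1": 5}) == "m5~"  # the parser rejects this
    assert to_string_fixed(5, {"worker1": 5}) == "s5"   # round-trips cleanly
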
1 change: 1 addition & 0 deletions changelog.d/17438.bugfix
@@ -0,0 +1 @@
+Fix rare bug where `/sync` would break for a user when using workers with multiple stream writers.
11 changes: 9 additions & 2 deletions synapse/handlers/sliding_sync.py
@@ -699,10 +699,17 @@ async def get_room_membership_for_user_at_to_token(
             instance_to_max_stream_ordering_map[instance_name] = stream_ordering
 
         # Then assemble the `RoomStreamToken`
+        min_stream_pos = min(instance_to_max_stream_ordering_map.values())
         membership_snapshot_token = RoomStreamToken(
             # Minimum position in the `instance_map`
-            stream=min(instance_to_max_stream_ordering_map.values()),
-            instance_map=immutabledict(instance_to_max_stream_ordering_map),
+            stream=min_stream_pos,
+            instance_map=immutabledict(
+                {
+                    instance_name: stream_pos
+                    for instance_name, stream_pos in instance_to_max_stream_ordering_map.items()
+                    if stream_pos > min_stream_pos
+                }
+            ),
         )
 
         # Since we fetched the users room list at some point in time after the from/to
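
A quick sketch of what the new filtering achieves, using made-up writer positions; entries at the minimum position carry no information beyond `stream` itself, so they are dropped:

    positions = {"writer1": 5, "writer2": 5, "writer3": 7}  # hypothetical values
    min_stream_pos = min(positions.values())  # 5
    instance_map = {n: p for n, p in positions.items() if p > min_stream_pos}
    # Writers at the minimum are implied by `stream` and need no entry.
    assert instance_map == {"writer3": 7}
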
65 changes: 57 additions & 8 deletions synapse/types/__init__.py
@@ -20,6 +20,7 @@
 #
 #
 import abc
+import logging
 import re
 import string
 from enum import Enum
@@ -74,6 +75,9 @@
 from synapse.storage.databases.main import DataStore, PurgeEventsStore
 from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
 
+
+logger = logging.getLogger(__name__)
+
 # Define a state map type from type/state_key to T (usually an event ID or
 # event)
 T = TypeVar("T")
@@ -454,6 +458,8 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
     represented by a default `stream` attribute and a map of instance name to
     stream position of any writers that are ahead of the default stream
     position.
+
+    The values in `instance_map` must be greater than the `stream` attribute.
     """
 
     stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True)
@@ -468,6 +474,15 @@
         kw_only=True,
     )
 
+    def __attrs_post_init__(self) -> None:
+        # Enforce that all instances have a value greater than the min stream
+        # position.
+        for i, v in self.instance_map.items():
+            if v <= self.stream:
+                raise ValueError(
+                    f"'instance_map' includes a stream position before the main 'stream' attribute. Instance: {i}"
+                )
+
     @classmethod
     @abc.abstractmethod
     async def parse(cls, store: "DataStore", string: str) -> "Self":
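
As a usage sketch (with a hypothetical writer name), the new check makes such inconsistent tokens unrepresentable at construction time:

    from immutabledict import immutabledict
    from synapse.types import RoomStreamToken

    RoomStreamToken(stream=5, instance_map=immutabledict({"writer1": 6}))  # ok
    RoomStreamToken(stream=5, instance_map=immutabledict({"writer1": 5}))  # ValueError
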
@@ -494,6 +509,9 @@ def copy_and_advance(self, other: "Self") -> "Self":
             for instance in set(self.instance_map).union(other.instance_map)
         }
 
+        # Filter out any redundant entries.
+        instance_map = {i: s for i, s in instance_map.items() if s > max_stream}
+
         return attr.evolve(
             self, stream=max_stream, instance_map=immutabledict(instance_map)
         )
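
For example (a sketch with made-up positions), advancing a token can leave an instance-map entry that the new `stream` already covers; the filter drops it so the result serializes as a plain `s` token instead of a degenerate `m` token:

    from immutabledict import immutabledict
    from synapse.types import RoomStreamToken

    a = RoomStreamToken(stream=3, instance_map=immutabledict({"writer1": 5}))
    b = RoomStreamToken(stream=5)
    # max_stream becomes 5, so the writer1 entry (position 5) is redundant
    # and gets filtered out of the merged token.
    assert a.copy_and_advance(b) == RoomStreamToken(stream=5)
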
@@ -539,10 +557,15 @@ def is_before_or_eq(self, other_token: Self) -> bool:
     def bound_stream_token(self, max_stream: int) -> "Self":
         """Bound the stream positions to a maximum value"""
 
+        min_pos = min(self.stream, max_stream)
         return type(self)(
-            stream=min(self.stream, max_stream),
+            stream=min_pos,
             instance_map=immutabledict(
-                {k: min(s, max_stream) for k, s in self.instance_map.items()}
+                {
+                    k: min(s, max_stream)
+                    for k, s in self.instance_map.items()
+                    if min(s, max_stream) > min_pos
+                }
             ),
         )
 
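A sketch of the bounded behaviour with hypothetical positions: a clamped entry survives only while it remains strictly ahead of the bounded stream position, preserving the class invariant:

    from immutabledict import immutabledict
    from synapse.types import RoomStreamToken

    tok = RoomStreamToken(stream=5, instance_map=immutabledict({"writer1": 8}))
    # Bounding to 6 clamps writer1 to 6, still ahead of stream 5, so it stays.
    assert tok.bound_stream_token(6) == RoomStreamToken(
        stream=5, instance_map=immutabledict({"writer1": 6})
    )
    # Bounding to 5 clamps writer1 to 5, which `stream` already covers.
    assert tok.bound_stream_token(5) == RoomStreamToken(stream=5)
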
@@ -637,6 +660,8 @@ def __attrs_post_init__(self) -> None:
                 "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'."
             )
 
+        super().__attrs_post_init__()
+
     @classmethod
     async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
         try:
@@ -651,6 +676,11 @@ async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
 
             instance_map = {}
             for part in parts[1:]:
+                if not part:
+                    # Handle tokens of the form `m5~`, which were created by
+                    # a bug
+                    continue
+
                 key, value = part.split(".")
                 instance_id = int(key)
                 pos = int(value)
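
This guard is what lets tokens already handed out by buggy servers round-trip again. A simplified, synchronous sketch of the `m...` parsing path (Synapse's real version is async and maps instance ids back to instance names via the store):

    def parse_multi_writer(token: str) -> tuple[int, dict[int, int]]:
        assert token.startswith("m")
        parts = token[1:].split("~")
        stream = int(parts[0])
        instance_map = {}
        for part in parts[1:]:
            if not part:
                # Tolerate "m5~" tokens created by the old serialization bug.
                continue
            key, value = part.split(".")
            instance_map[int(key)] = int(value)
        return stream, instance_map

    assert parse_multi_writer("m5~") == (5, {})  # previously failed with a 400
    assert parse_multi_writer("m5~1.6") == (5, {1: 6})
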
@@ -666,7 +696,10 @@ async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
         except CancelledError:
             raise
         except Exception:
-            pass
+            # We log an exception here as even though this *might* be a client
+            # handing a bad token, it's more likely that Synapse returned a bad
+            # token (and we really want to catch those!).
+            logger.exception("Failed to parse stream token: %r", string)
         raise SynapseError(400, "Invalid room stream token %r" % (string,))
 
     @classmethod
@@ -713,6 +746,8 @@ def get_stream_pos_for_instance(self, instance_name: str) -> int:
         return self.instance_map.get(instance_name, self.stream)
 
     async def to_string(self, store: "DataStore") -> str:
+        """See class level docstring for information about the format."""
+
         if self.topological is not None:
             return "t%d-%d" % (self.topological, self.stream)
         elif self.instance_map:
@@ -727,8 +762,10 @@ async def to_string(self, store: "DataStore") -> str:
                 instance_id = await store.get_id_for_instance(name)
                 entries.append(f"{instance_id}.{pos}")
 
-            encoded_map = "~".join(entries)
-            return f"m{self.stream}~{encoded_map}"
+            if entries:
+                encoded_map = "~".join(entries)
+                return f"m{self.stream}~{encoded_map}"
+            return f"s{self.stream}"
         else:
             return "s%d" % (self.stream,)
 
@@ -756,6 +793,11 @@ async def parse(cls, store: "DataStore", string: str) -> "MultiWriterStreamToken":
 
             instance_map = {}
             for part in parts[1:]:
+                if not part:
+                    # Handle tokens of the form `m5~`, which were created by
+                    # a bug
+                    continue
+
                 key, value = part.split(".")
                 instance_id = int(key)
                 pos = int(value)
@@ -770,10 +812,15 @@ async def parse(cls, store: "DataStore", string: str) -> "MultiWriterStreamToken":
         except CancelledError:
             raise
         except Exception:
-            pass
+            # We log an exception here as even though this *might* be a client
+            # handing a bad token, it's more likely that Synapse returned a bad
+            # token (and we really want to catch those!).
+            logger.exception("Failed to parse stream token: %r", string)
         raise SynapseError(400, "Invalid stream token %r" % (string,))
 
     async def to_string(self, store: "DataStore") -> str:
+        """See class level docstring for information about the format."""
+
         if self.instance_map:
             entries = []
             for name, pos in self.instance_map.items():
@@ -786,8 +833,10 @@ async def to_string(self, store: "DataStore") -> str:
                 instance_id = await store.get_id_for_instance(name)
                 entries.append(f"{instance_id}.{pos}")
 
-            encoded_map = "~".join(entries)
-            return f"m{self.stream}~{encoded_map}"
+            if entries:
+                encoded_map = "~".join(entries)
+                return f"m{self.stream}~{encoded_map}"
+            return str(self.stream)
         else:
             return str(self.stream)
 
71 changes: 71 additions & 0 deletions tests/test_types.py
@@ -19,16 +19,26 @@
 #
 #
 
+from typing import Type
+from unittest import skipUnless
+
+from immutabledict import immutabledict
+from parameterized import parameterized_class
+
 from synapse.api.errors import SynapseError
 from synapse.types import (
+    AbstractMultiWriterStreamToken,
+    MultiWriterStreamToken,
     RoomAlias,
+    RoomStreamToken,
     UserID,
     get_domain_from_id,
     get_localpart_from_id,
     map_username_to_mxid_localpart,
 )
 
 from tests import unittest
+from tests.utils import USE_POSTGRES_FOR_TESTS
 
 
 class IsMineIDTests(unittest.HomeserverTestCase):
@@ -127,3 +137,64 @@ def test_non_ascii(self) -> None:
         # this should work with either a unicode or a bytes
         self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
         self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast")
+
+
+@parameterized_class(
+    ("token_type",),
+    [
+        (MultiWriterStreamToken,),
+        (RoomStreamToken,),
+    ],
+    class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{params_dict['token_type'].__name__}",
+)
+class MultiWriterTokenTestCase(unittest.HomeserverTestCase):
+    """Tests for the different types of multi writer tokens."""
+
+    token_type: Type[AbstractMultiWriterStreamToken]
+
+    def test_basic_token(self) -> None:
+        """Test that a simple stream token can be serialized and unserialized"""
+        store = self.hs.get_datastores().main
+
+        token = self.token_type(stream=5)
+
+        string_token = self.get_success(token.to_string(store))
+
+        if isinstance(token, RoomStreamToken):
+            self.assertEqual(string_token, "s5")
+        else:
+            self.assertEqual(string_token, "5")
+
+        parsed_token = self.get_success(self.token_type.parse(store, string_token))
+        self.assertEqual(parsed_token, token)
+
+    @skipUnless(USE_POSTGRES_FOR_TESTS, "Requires Postgres")
+    def test_instance_map(self) -> None:
+        """Test for stream token with instance map"""
+        store = self.hs.get_datastores().main
+
+        token = self.token_type(stream=5, instance_map=immutabledict({"foo": 6}))
+
+        string_token = self.get_success(token.to_string(store))
+        self.assertEqual(string_token, "m5~1.6")
+
+        parsed_token = self.get_success(self.token_type.parse(store, string_token))
+        self.assertEqual(parsed_token, token)
+
+    def test_instance_map_assertion(self) -> None:
+        """Test that we assert values in the instance map are greater than the
+        min stream position"""
+
+        with self.assertRaises(ValueError):
+            self.token_type(stream=5, instance_map=immutabledict({"foo": 4}))
+
+        with self.assertRaises(ValueError):
+            self.token_type(stream=5, instance_map=immutabledict({"foo": 5}))
+
+    def test_parse_bad_token(self) -> None:
+        """Test that we can parse tokens produced by a bug in Synapse of the
+        form `m5~`"""
+        store = self.hs.get_datastores().main
+
+        parsed_token = self.get_success(self.token_type.parse(store, "m5~"))
+        self.assertEqual(parsed_token, self.token_type(stream=5))