Skip to content

Commit

Permalink
Revert "Reduce device lists replication traffic. (#17333)"
Browse files Browse the repository at this point in the history
This reverts commit cf711ac.
  • Loading branch information
erikjohnston authored Jun 25, 2024
1 parent 6e8af83 commit 20dea47
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 89 deletions.
1 change: 0 additions & 1 deletion changelog.d/17333.misc

This file was deleted.

19 changes: 7 additions & 12 deletions synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,19 +114,13 @@ async def on_rdata(
"""
all_room_ids: Set[str] = set()
if stream_name == DeviceListsStream.NAME:
if any(not row.is_signature and not row.hosts_calculated for row in rows):
if any(row.entity.startswith("@") and not row.is_signature for row in rows):
prev_token = self.store.get_device_stream_token()
all_room_ids = await self.store.get_all_device_list_changes(
prev_token, token
)
self.store.device_lists_in_rooms_have_changed(all_room_ids, token)

# If we're sending federation we need to update the device lists
# outbound pokes stream change cache with updated hosts.
if self.send_handler and any(row.hosts_calculated for row in rows):
hosts = await self.store.get_destinations_for_device(token)
self.store.device_lists_outbound_pokes_have_changed(hosts, token)

self.store.process_replication_rows(stream_name, instance_name, token, rows)
# NOTE: this must be called after process_replication_rows to ensure any
# cache invalidations are first handled before any stream ID advances.
Expand Down Expand Up @@ -439,11 +433,12 @@ async def process_replication_rows(
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
if any(row.hosts_calculated for row in rows):
hosts = await self.store.get_destinations_for_device(token)
await self.federation_sender.send_device_messages(
hosts, immediate=False
)
hosts = {
row.entity
for row in rows
if not row.entity.startswith("@") and not row.is_signature
}
await self.federation_sender.send_device_messages(hosts, immediate=False)

elif stream_name == ToDeviceStream.NAME:
# The to_device stream includes stuff to be pushed to both local
Expand Down
12 changes: 4 additions & 8 deletions synapse/replication/tcp/streams/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,14 +549,10 @@ class DeviceListsStream(_StreamFromIdGen):

@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceListsStreamRow:
user_id: str
entity: str
# Indicates that a user has signed their own device with their user-signing key
is_signature: bool

# Indicates if this is a notification that we've calculated the hosts we
# need to send the update to.
hosts_calculated: bool

NAME = "device_lists"
ROW_TYPE = DeviceListsStreamRow

Expand Down Expand Up @@ -598,13 +594,13 @@ async def _update_function(
upper_limit_token = min(upper_limit_token, signatures_to_token)

device_updates = [
(stream_id, (entity, False, hosts))
for stream_id, (entity, hosts) in device_updates
(stream_id, (entity, False))
for stream_id, (entity,) in device_updates
if stream_id <= upper_limit_token
]

signatures_updates = [
(stream_id, (entity, True, False))
(stream_id, (entity, True))
for stream_id, (entity,) in signatures_updates
if stream_id <= upper_limit_token
]
Expand Down
93 changes: 35 additions & 58 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,24 +164,22 @@ def __init__(
prefilled_cache=user_signature_stream_prefill,
)

self._device_list_federation_stream_cache = None
if hs.should_send_federation():
(
device_list_federation_prefill,
device_list_federation_list_id,
) = self.db_pool.get_cache_dict(
db_conn,
"device_lists_outbound_pokes",
entity_column="destination",
stream_column="stream_id",
max_value=device_list_max,
limit=10000,
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache",
device_list_federation_list_id,
prefilled_cache=device_list_federation_prefill,
)
(
device_list_federation_prefill,
device_list_federation_list_id,
) = self.db_pool.get_cache_dict(
db_conn,
"device_lists_outbound_pokes",
entity_column="destination",
stream_column="stream_id",
max_value=device_list_max,
limit=10000,
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache",
device_list_federation_list_id,
prefilled_cache=device_list_federation_prefill,
)

if hs.config.worker.run_background_tasks:
self._clock.looping_call(
Expand Down Expand Up @@ -209,29 +207,22 @@ def _invalidate_caches_for_devices(
) -> None:
for row in rows:
if row.is_signature:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
self._user_signature_stream_cache.entity_has_changed(row.entity, token)
continue

# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
if not row.hosts_calculated:
self._device_list_stream_cache.entity_has_changed(row.user_id, token)
self.get_cached_devices_for_user.invalidate((row.user_id,))
self._get_cached_user_device.invalidate((row.user_id,))
self.get_device_list_last_stream_id_for_remote.invalidate(
(row.user_id,)
)
if row.entity.startswith("@"):
self._device_list_stream_cache.entity_has_changed(row.entity, token)
self.get_cached_devices_for_user.invalidate((row.entity,))
self._get_cached_user_device.invalidate((row.entity,))
self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))

def device_lists_outbound_pokes_have_changed(
self, destinations: StrCollection, token: int
) -> None:
assert self._device_list_federation_stream_cache is not None

for destination in destinations:
self._device_list_federation_stream_cache.entity_has_changed(
destination, token
)
else:
self._device_list_federation_stream_cache.entity_has_changed(
row.entity, token
)

def device_lists_in_rooms_have_changed(
self, room_ids: StrCollection, token: int
Expand Down Expand Up @@ -372,11 +363,6 @@ async def get_device_updates_by_remote(
EDU contents.
"""
now_stream_id = self.get_device_stream_token()
if from_stream_id == now_stream_id:
return now_stream_id, []

if self._device_list_federation_stream_cache is None:
raise Exception("Func can only be used on federation senders")

has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
Expand Down Expand Up @@ -1032,10 +1018,10 @@ def _get_all_device_list_changes_for_remotes(
# This query Does The Right Thing where it'll correctly apply the
# bounds to the inner queries.
sql = """
SELECT stream_id, user_id, hosts FROM (
SELECT stream_id, user_id, false AS hosts FROM device_lists_stream
SELECT stream_id, entity FROM (
SELECT stream_id, user_id AS entity FROM device_lists_stream
UNION ALL
SELECT DISTINCT stream_id, user_id, true AS hosts FROM device_lists_outbound_pokes
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
Expand Down Expand Up @@ -1591,14 +1577,6 @@ def get_device_list_changes_in_room_txn(
get_device_list_changes_in_room_txn,
)

async def get_destinations_for_device(self, stream_id: int) -> StrCollection:
return await self.db_pool.simple_select_onecol(
table="device_lists_outbound_pokes",
keyvalues={"stream_id": stream_id},
retcol="destination",
desc="get_destinations_for_device",
)


class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(
Expand Down Expand Up @@ -2134,13 +2112,12 @@ def _add_device_outbound_poke_to_stream_txn(
stream_ids: List[int],
context: Optional[Dict[str, str]],
) -> None:
if self._device_list_federation_stream_cache:
for host in hosts:
txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed,
host,
stream_ids[-1],
)
for host in hosts:
txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed,
host,
stream_ids[-1],
)

now = self._clock.time_msec()
stream_id_iterator = iter(stream_ids)
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ def process_replication_rows(
if stream_name == DeviceListsStream.NAME:
for row in rows:
assert isinstance(row, DeviceListsStream.DeviceListsStreamRow)
if not row.hosts_calculated:
if row.entity.startswith("@"):
self._get_e2e_device_keys_for_federation_query_inner.invalidate(
(row.user_id,)
(row.entity,)
)

super().process_replication_rows(stream_name, instance_name, token, rows)
Expand Down
8 changes: 0 additions & 8 deletions tests/storage/test_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,6 @@ class DeviceStoreTestCase(HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main

def default_config(self) -> JsonDict:
config = super().default_config()

# We 'enable' federation otherwise `get_device_updates_by_remote` will
# throw an exception.
config["federation_sender_instances"] = ["master"]
return config

def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None:
"""Add a device list change for the given device to
`device_lists_outbound_pokes` table.
Expand Down

0 comments on commit 20dea47

Please sign in to comment.