Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Fix a bug introduced in Synapse v1.50.0rc1 whereby outbound federatio…
Browse files Browse the repository at this point in the history
…n could fail because too many EDUs were produced for device updates. (#11730)

Co-authored-by: David Robertson <davidr@element.io>
  • Loading branch information
reivilibre and David Robertson authored Jan 13, 2022
1 parent 22abfca commit b602ba1
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 17 deletions.
1 change: 1 addition & 0 deletions changelog.d/11730.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a bug introduced in Synapse v1.50.0rc1 whereby outbound federation could fail because too many EDUs were produced for device updates.
94 changes: 78 additions & 16 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ async def get_devices_by_auth_provider_session_id(
@trace
async def get_device_updates_by_remote(
self, destination: str, from_stream_id: int, limit: int
) -> Tuple[int, List[Tuple[str, dict]]]:
) -> Tuple[int, List[Tuple[str, JsonDict]]]:
"""Get a stream of device updates to send to the given remote server.
Args:
Expand All @@ -200,9 +200,10 @@ async def get_device_updates_by_remote(
limit: Maximum number of device updates to return
Returns:
A mapping from the current stream id (ie, the stream id of the last
update included in the response), and the list of updates, where
each update is a pair of EDU type and EDU contents.
- The current stream id (i.e. the stream id of the last update included
in the response); and
- The list of updates, where each update is a pair of EDU type and
EDU contents.
"""
now_stream_id = self.get_device_stream_token()

Expand All @@ -221,6 +222,9 @@ async def get_device_updates_by_remote(
limit,
)

# We need to ensure `updates` doesn't grow too big.
# Currently: `len(updates) <= limit`.

# Return an empty list if there are no updates
if not updates:
return now_stream_id, []
Expand Down Expand Up @@ -277,40 +281,88 @@ async def get_device_updates_by_remote(
query_map = {}
cross_signing_keys_by_user = {}
for user_id, device_id, update_stream_id, update_context in updates:
if (
# Calculate the remaining length budget.
# Note that, for now, each entry in `cross_signing_keys_by_user`
# gives rise to two device updates in the result, so those cost twice
# as much (and are the whole reason we need to separately calculate
# the budget; we know len(updates) <= limit otherwise!)
# N.B. len() on dicts is cheap since they store their size.
remaining_length_budget = limit - (
len(query_map) + 2 * len(cross_signing_keys_by_user)
)
assert remaining_length_budget >= 0

is_master_key_update = (
user_id in master_key_by_user
and device_id == master_key_by_user[user_id]["device_id"]
):
result = cross_signing_keys_by_user.setdefault(user_id, {})
result["master_key"] = master_key_by_user[user_id]["key_info"]
elif (
)
is_self_signing_key_update = (
user_id in self_signing_key_by_user
and device_id == self_signing_key_by_user[user_id]["device_id"]
)

is_cross_signing_key_update = (
is_master_key_update or is_self_signing_key_update
)

if (
is_cross_signing_key_update
and user_id not in cross_signing_keys_by_user
):
# This will give rise to 2 device updates.
# If we don't have the budget, stop here!
if remaining_length_budget < 2:
break

if is_master_key_update:
result = cross_signing_keys_by_user.setdefault(user_id, {})
result["master_key"] = master_key_by_user[user_id]["key_info"]
elif is_self_signing_key_update:
result = cross_signing_keys_by_user.setdefault(user_id, {})
result["self_signing_key"] = self_signing_key_by_user[user_id][
"key_info"
]
else:
key = (user_id, device_id)

if key not in query_map and remaining_length_budget < 1:
# We don't have space for a new entry
break

previous_update_stream_id, _ = query_map.get(key, (0, None))

if update_stream_id > previous_update_stream_id:
# FIXME If this overwrites an older update, this discards the
# previous OpenTracing context.
# It might make it harder to track down issues using OpenTracing.
# If there's a good reason why it doesn't matter, a comment here
# about that would not hurt.
query_map[key] = (update_stream_id, update_context)

# As this update has been added to the response, advance the stream
# position.
last_processed_stream_id = update_stream_id

# In the worst case scenario, each update is for a distinct user and is
# added either to the query_map or to cross_signing_keys_by_user,
# but not both:
# len(query_map) + len(cross_signing_keys_by_user) <= len(updates) here,
# so len(query_map) + len(cross_signing_keys_by_user) <= limit.

results = await self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)

# add the updated cross-signing keys to the results list
# len(results) <= len(query_map) here,
# so len(results) + len(cross_signing_keys_by_user) <= limit.

# Add the updated cross-signing keys to the results list
for user_id, result in cross_signing_keys_by_user.items():
result["user_id"] = user_id
results.append(("m.signing_key_update", result))
# also send the unstable version
# FIXME: remove this when enough servers have upgraded
# and remove the length budgeting above.
results.append(("org.matrix.signing_key_update", result))

return last_processed_stream_id, results
Expand All @@ -322,7 +374,7 @@ def _get_device_updates_by_remote_txn(
from_stream_id: int,
now_stream_id: int,
limit: int,
):
) -> List[Tuple[str, str, int, Optional[str]]]:
"""Return device update information for a given remote destination
Args:
Expand All @@ -333,7 +385,11 @@ def _get_device_updates_by_remote_txn(
limit: Maximum number of device updates to return
Returns:
List: List of device updates
List: List of device update tuples:
- user_id
- device_id
- stream_id
- opentracing_context
"""
# get the list of device updates that need to be sent
sql = """
Expand All @@ -357,15 +413,21 @@ async def _get_device_update_edus_by_remote(
Args:
destination: The host the device updates are intended for
from_stream_id: The minimum stream_id to filter updates by, exclusive
query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
user_id/device_id to update stream_id and the relevant json-encoded
opentracing context
query_map: Dictionary mapping (user_id, device_id) to
(update stream_id, the relevant json-encoded opentracing context)
Returns:
List of objects representing an device update EDU
List of objects representing a device update EDU.
Postconditions:
The returned list has a length not exceeding that of the query_map:
len(result) <= len(query_map)
"""
devices = (
await self.get_e2e_device_keys_and_signatures(
# Because these are (user_id, device_id) tuples with all
# device_ids not being None, the returned list's length will not
# exceed that of query_map.
query_map.keys(),
include_all_devices=True,
include_deleted_devices=True,
Expand Down
112 changes: 111 additions & 1 deletion tests/storage/test_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def test_get_device_updates_by_remote_can_limit_properly(self):
self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
)

# Get all device updates ever meant for this remote
# Get device updates meant for this remote
next_stream_id, device_updates = self.get_success(
self.store.get_device_updates_by_remote("somehost", -1, limit=3)
)
Expand Down Expand Up @@ -155,6 +155,116 @@ def test_get_device_updates_by_remote_can_limit_properly(self):
# Check the newly-added device_ids are contained within these updates
self._check_devices_in_updates(device_ids, device_updates)

# Check there are no more device updates left.
_, device_updates = self.get_success(
self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
)
self.assertEqual(device_updates, [])

def test_get_device_updates_by_remote_cross_signing_key_updates(
    self,
) -> None:
    """
    Checks that `get_device_updates_by_remote` keeps the returned list
    within the requested limit when cross-signing key updates are mixed
    in with ordinary device updates.

    Current behaviour is that the two cross-signing key update EDUs
    (stable- and unstable-prefixed) are always emitted as a pair, even
    if that leaves an earlier batch one EDU short of the limit.
    """
    local_user = "@user_id:test"

    # Sanity check: this test only makes sense for a local user.
    assert self.hs.is_mine_id(
        local_user
    ), "Test not valid: this MXID should be considered local"

    # Publish a master key and a self-signing key for the user.
    master_key = {
        "keys": {
            "ed25519:fakeMaster": "aaafakefakefake1AAAAAAAAAAAAAAAAAAAAAAAAAAA="
        },
        "signatures": {
            local_user: {
                "ed25519:fake2": "aaafakefakefake2AAAAAAAAAAAAAAAAAAAAAAAAAAA="
            }
        },
    }
    self.get_success(
        self.store.set_e2e_cross_signing_key(local_user, "master", master_key)
    )

    self_signing_key = {
        "keys": {
            "ed25519:fakeSelfSigning": "aaafakefakefake3AAAAAAAAAAAAAAAAAAAAAAAAAAA="
        },
        "signatures": {
            local_user: {
                "ed25519:fake4": "aaafakefakefake4AAAAAAAAAAAAAAAAAAAAAAAAAAA="
            }
        },
    }
    self.get_success(
        self.store.set_e2e_cross_signing_key(
            local_user, "self_signing", self_signing_key
        )
    )

    # Announce updates for two ordinary devices plus the two public
    # cross-signing keys (which share the device-ID namespace), with
    # sequential `stream_id`s.
    device_ids = [
        "device_id1",
        "device_id2",
        "fakeMaster",
        "fakeSelfSigning",
    ]
    self.get_success(
        self.store.add_device_change_to_streams(
            local_user, device_ids, ["somehost"]
        )
    )

    # First batch: although the limit is 3, the pair of
    # (unstable-prefixed & unprefixed) signing-key EDUs for `fakeMaster`
    # and `fakeSelfSigning` cannot be split across batches, so only the
    # updates for `device_id1` and `device_id2` fit — a short batch of 2.
    next_stream_id, updates = self.get_success(
        self.store.get_device_updates_by_remote("somehost", -1, limit=3)
    )
    self.assertEqual(len(updates), 2, updates)
    self._check_devices_in_updates(device_ids[:2], updates)

    # Second batch: the master-key and self-signing-key changes are
    # merged into one signing-key update, which is emitted twice — once
    # with the stable `m.`-prefixed type and once with the
    # unstable-prefixed type (a temporary backwards-compatibility
    # arrangement).
    next_stream_id, updates = self.get_success(
        self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
    )
    self.assertEqual(len(updates), 2, updates)
    self.assertEqual(updates[0][0], "m.signing_key_update", updates[0])
    self.assertEqual(
        updates[1][0], "org.matrix.signing_key_update", updates[1]
    )

    # Nothing further should be pending for this destination.
    _, updates = self.get_success(
        self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
    )
    self.assertEqual(updates, [])

def _check_devices_in_updates(self, expected_device_ids, device_updates):
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))
Expand Down

0 comments on commit b602ba1

Please sign in to comment.