This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Support MSC3814: Dehydrated Devices Part 2 #16010

Merged
merged 14 commits on Aug 8, 2023
1 change: 1 addition & 0 deletions changelog.d/16010.misc
@@ -0,0 +1 @@
Update dehydrated devices implementation.
137 changes: 132 additions & 5 deletions synapse/handlers/device.py
@@ -26,6 +26,8 @@
Tuple,
)

from canonicaljson import encode_canonical_json

from synapse.api import errors
from synapse.api.constants import EduTypes, EventTypes
from synapse.api.errors import (
@@ -385,6 +387,7 @@ def __init__(self, hs: "HomeServer"):
self.federation_sender = hs.get_federation_sender()
self._account_data_handler = hs.get_account_data_handler()
self._storage_controllers = hs.get_storage_controllers()
self.db_pool = hs.get_datastores().main.db_pool

self.device_list_updater = DeviceListUpdater(hs, self)

@@ -656,15 +659,17 @@ async def store_dehydrated_device(
device_id: Optional[str],
device_data: JsonDict,
initial_device_display_name: Optional[str] = None,
keys_for_device: Optional[JsonDict] = None,
) -> str:
"""Store a dehydrated device for a user. If the user had a previous
dehydrated device, it is removed.
"""Store a dehydrated device for a user, optionally storing the keys associated with
it as well. If the user had a previous dehydrated device, it is removed.

Args:
user_id: the user that we are storing the device for
device_id: device id supplied by client
device_data: the dehydrated device information
initial_device_display_name: The display name to use for the device
keys_for_device: keys for the dehydrated device
Returns:
device id of the dehydrated device
"""
Expand All @@ -673,13 +678,135 @@ async def store_dehydrated_device(
device_id,
initial_device_display_name,
)
-        old_device_id = await self.store.store_dehydrated_device(
-            user_id, device_id, device_data
-        )

+        time_now = self.clock.time_msec()
+
+        if keys_for_device:
+            keys = await self._check_and_prepare_keys_for_dehydrated_device(
+                user_id, device_id, keys_for_device
+            )
+            old_device_id = await self.store.store_dehydrated_device(
+                user_id, device_id, device_data, time_now, keys
+            )
+        else:
+            old_device_id = await self.store.store_dehydrated_device(
+                user_id, device_id, device_data, time_now
+            )

if old_device_id is not None:
await self.delete_devices(user_id, [old_device_id])

return device_id

async def _check_and_prepare_keys_for_dehydrated_device(
self, user_id: str, device_id: str, keys: JsonDict
) -> dict:
Member: Aren't we creating the dehydrated device now? How can it already have keys? Aren't devices unique per user/device ID?

Contributor Author: This is creating/storing the dehydrated device. As to how it can already have keys, I actually don't know; per the MSC, the keys were to be uploaded over the /keys/upload endpoint, implying that they already exist: "After the dehydrated device is uploaded, the client will upload the encryption keys using POST /keys/upload/{device_id}, where the device_id parameter is the device ID given in the response to PUT /dehydrated_device".

#15929 changed this so that the key upload was integrated into the PUT endpoint, and this PR further refines the key upload so that storing the device and storing the keys happen in a single database transaction, addressing @poljar's concerns about storing the dehydrated device and its keys being atomic.

Reviewer: Uploading the public keys at the time of the dehydrated device creation tries to address this race: matrix-org/matrix-spec-proposals#3814 (comment).

As to how it can already have keys: if we remember what a dehydrated device is, it becomes quite clear. A dehydrated device is the private identity and one-time keys of a device; of course they are encrypted, so the server can't access them.

Once we realize this, it becomes quite natural that we would want to upload the public parts of those same keys to the server at the same time. I think that doing it in a separate POST /keys/upload/ call is a mistake and quite weird from an API point of view, as the above-mentioned race condition showcases.

Member: OK, but part of this function checks if we already have keys uploaded for this device (and bails if they're not exactly equal to the new keys). I still don't follow the order of events that would let you upload keys before the device is created.

Contributor Author: This could be me being overzealous here: I basically made sure that all the checks that happen in the /keys/upload pathway also happen in this pathway, but maybe that's not necessary and we can just store the keys without checking them?

Member: Part of it would be straightforward: move _set_e2e_device_keys_txn to a method (instead of an inner function) and call it from both places. Although it seems this entire function is just a copy of upload_keys_for_user -- can we just call that instead?

Member: The way we use the API, they are guaranteed to be unique. Furthermore, since we only upload keys once for a dehydrated device, there's little chance of getting this wrong. I think if this is the case we can probably call the store methods that set the values directly and avoid the complicated logic of reusing keys.

Contributor Author:

> call the store methods that set the values directly

The reason I didn't do this was that my understanding of what was being asked was that storing the keys and storing the device needs to all happen in the same transaction, i.e. in _store_dehydrated_device_txn. Wouldn't calling the store methods open a separate transaction?

Member: They would need to be refactored to take a transaction as the first parameter, and current callers would call that also.

Contributor Author: Right, I think this is sorted now; sorry for the confusion.
"""
        Check whether any of the provided keys are duplicates, raising if they are,
        and prepare the keys for insertion into the DB.

Args:
user_id: user to store keys for
device_id: the dehydrated device to store keys for
            keys: the keys to store (device_keys, one_time_keys, and/or fallback_keys)

Returns:
            keys that have been checked for duplicates and are ready to be inserted
            into the DB
"""
keys_to_return: dict = {}
device_keys = keys.get("device_keys", None)
if device_keys:
old_key_json = await self.db_pool.simple_select_one_onecol(
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="key_json",
allow_none=True,
)

# In py3 we need old_key_json to match new_key_json type. The DB
# returns unicode while encode_canonical_json returns bytes.
new_device_key_json = encode_canonical_json(device_keys).decode("utf-8")

if old_key_json == new_device_key_json:
raise SynapseError(
400,
f"Device key for user_id: {user_id}, device_id {device_id} already stored.",
)

keys_to_return["device_keys"] = new_device_key_json

one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
# import this here to avoid a circular import
from synapse.handlers.e2e_keys import _one_time_keys_match

# make a list of (alg, id, key) tuples
key_list = []
for key_id, key_obj in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append((algorithm, key_id, key_obj))

# First we check if we have already persisted any of the keys.
existing_key_map = await self.store.get_e2e_one_time_keys(
user_id, device_id, [k_id for _, k_id, _ in key_list]
)

            new_one_time_keys = []  # Keys that we need to insert: (alg, id, json) tuples.
for algorithm, key_id, key in key_list:
ex_json = existing_key_map.get((algorithm, key_id), None)
if ex_json:
if not _one_time_keys_match(ex_json, key):
raise SynapseError(
400,
(
"One time key %s:%s already exists. "
"Old key: %s; new key: %r"
)
% (algorithm, key_id, ex_json, key),
)
else:
new_one_time_keys.append(
(algorithm, key_id, encode_canonical_json(key).decode("ascii"))
)
keys_to_return["one_time_keys"] = new_one_time_keys

fallback_keys = keys.get("fallback_keys", None)
if fallback_keys:
new_fallback_keys = {}
# there should actually only be one item in the dict but we iterate nevertheless -
# see _set_e2e_fallback_keys_txn
for key_id, fallback_key in fallback_keys.items():
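                # split "<algorithm>:<key_id>" on the first colon only, since the
                # key ID may itself contain a colon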
algorithm, key_id = key_id.split(":", 1)
old_key_json = await self.db_pool.simple_select_one_onecol(
table="e2e_fallback_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
},
retcol="key_json",
allow_none=True,
)

new_fallback_key_json = encode_canonical_json(fallback_key).decode(
"utf-8"
)

                # If the uploaded key is identical to the current fallback key,
                # reject it: accepting a re-upload would mark the key as unused
                # even if it had already been used.
if old_key_json == new_fallback_key_json:
raise SynapseError(
400, f"Fallback key {old_key_json} already exists."
)
# TODO: should this be an update? it assumes that there will only be one fallback key
new_fallback_keys[f"{algorithm}:{key_id}"] = fallback_key
keys_to_return["fallback_keys"] = new_fallback_keys
return keys_to_return
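The duplicate checks above hinge on canonical JSON: the DB stores the canonicalised unicode form, so a re-upload compares equal regardless of key order or whitespace. A standalone sketch with illustrative values, using the same canonicaljson library this file imports:

from canonicaljson import encode_canonical_json

stored_key_json = '{"algorithms":["m.olm.v1.curve25519-aes-sha2"],"device_id":"ABC"}'  # as read from the DB
uploaded = {"device_id": "ABC", "algorithms": ["m.olm.v1.curve25519-aes-sha2"]}  # different key order

# encode_canonical_json returns bytes with sorted keys and no whitespace, hence
# the .decode("utf-8") before comparing against the DB's unicode value
new_key_json = encode_canonical_json(uploaded).decode("utf-8")
assert new_key_json == stored_key_json  # a duplicate: the handler raises a 400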

async def rehydrate_device(
self, user_id: str, access_token: str, device_id: str
) -> dict:
13 changes: 0 additions & 13 deletions synapse/handlers/devicemessage.py
@@ -367,19 +367,6 @@ async def get_events_for_dehydrated_device(
errcode=Codes.INVALID_PARAM,
)

-        # if we have a since token, delete any to-device messages before that token
-        # (since we now know that the device has received them)
-        deleted = await self.store.delete_messages_for_device(
-            user_id, device_id, since_stream_id
-        )
-        logger.debug(
-            "Deleted %d to-device messages up to %d for user_id %s device_id %s",
-            deleted,
-            since_stream_id,
-            user_id,
-            device_id,
-        )
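With this block removed, fetching events for a dehydrated device no longer deletes earlier to-device messages as a side effect; the updated test below confirms that a repeated fetch still returns the same message.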

to_token = self.event_sources.get_current_token().to_device_key

messages, stream_id = await self.store.get_messages_for_device(
9 changes: 1 addition & 8 deletions synapse/rest/client/devices.py
@@ -536,7 +536,6 @@ class Config:
async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
submission = parse_and_validate_json_object_from_request(request, self.PutBody)
requester = await self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()

device_info = submission.dict()
if "device_keys" not in device_info.keys():
@@ -545,18 +544,12 @@ async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"Device key(s) not found, these must be provided.",
)

-        # TODO: Those two operations, creating a device and storing the
-        # device's keys should be atomic.
device_id = await self.device_handler.store_dehydrated_device(
requester.user.to_string(),
submission.device_id,
submission.device_data.dict(),
submission.initial_device_display_name,
-        )
-
-        # TODO: Do we need to do something with the result here?
-        await self.key_uploader(
-            user_id=user_id, device_id=submission.device_id, keys=submission.dict()
+            device_info,
)

return 200, {"device_id": device_id}
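For context, a hedged sketch of the resulting client flow, assuming MSC3814's unstable endpoint prefix and with all field values as placeholders: a single PUT now carries both the device data and its public keys, which the server stores atomically.

import requests

resp = requests.put(
    "https://homeserver.example/_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
    headers={"Authorization": "Bearer <access_token>"},
    json={
        "device_data": {"algorithm": "<pickle algorithm>", "device_pickle": "<encrypted pickle>"},
        "device_keys": {"user_id": "@alice:example.org", "device_id": "<device id>"},
        "one_time_keys": {},
        "fallback_keys": {},
    },
    timeout=10,
)
resp.raise_for_status()
print(resp.json()["device_id"])  # ID of the stored dehydrated device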
73 changes: 71 additions & 2 deletions synapse/storage/databases/main/devices.py
@@ -1188,8 +1188,66 @@ async def get_dehydrated_device(
)

def _store_dehydrated_device_txn(
-        self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_id: str,
+        device_data: str,
+        time: int,
+        keys: Optional[JsonDict] = None,
) -> Optional[str]:
# TODO: make keys non-optional once support for msc2697 is dropped
if keys:
device_keys = keys.get("device_keys", None)
if device_keys:
self.db_pool.simple_upsert_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
values={"ts_added_ms": time, "key_json": device_keys},
)

one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
self.db_pool.simple_insert_many_txn(
txn,
table="e2e_one_time_keys_json",
keys=(
"user_id",
"device_id",
"algorithm",
"key_id",
"ts_added_ms",
"key_json",
),
values=[
(user_id, device_id, algorithm, key_id, time, json_bytes)
for algorithm, key_id, json_bytes in one_time_keys
],
)
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)

fallback_keys = keys.get("fallback_keys", None)
if fallback_keys:
for key_id, fallback_key in fallback_keys.items():
algorithm, key_id = key_id.split(":", 1)
self.db_pool.simple_upsert_txn(
txn,
table="e2e_fallback_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
},
values={
"key_id": key_id,
"key_json": json_encoder.encode(fallback_key),
"used": False,
},
)

old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn,
table="dehydrated_devices",
@@ -1203,26 +1261,37 @@ def _store_dehydrated_device_txn(
keyvalues={"user_id": user_id},
values={"device_id": device_id, "device_data": device_data},
)

return old_device_id

async def store_dehydrated_device(
-        self, user_id: str, device_id: str, device_data: JsonDict
+        self,
+        user_id: str,
+        device_id: str,
+        device_data: JsonDict,
+        time_now: int,
+        keys: Optional[dict] = None,
) -> Optional[str]:
"""Store a dehydrated device for a user.

Args:
user_id: the user that we are storing the device for
device_id: the ID of the dehydrated device
device_data: the dehydrated device information
            time_now: the current time in milliseconds at the time of the request
            keys: keys for the dehydrated device
Returns:
device id of the user's previous dehydrated device, if any
"""

return await self.db_pool.runInteraction(
"store_dehydrated_device_txn",
self._store_dehydrated_device_txn,
user_id,
device_id,
json_encoder.encode(device_data),
time_now,
keys,
)
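Why threading keys into the txn matters: everything done inside a single runInteraction callback commits or rolls back as one unit. A standalone sqlite3 sketch (illustrative schema, not Synapse's) of that property:

import sqlite3

def store_device_and_keys(conn, device_id, key_jsons):
    with conn:  # one transaction: commits on success, rolls back on any error
        conn.execute("INSERT INTO devices (device_id) VALUES (?)", (device_id,))
        conn.executemany(
            "INSERT INTO keys (device_id, key_json) VALUES (?, ?)",
            [(device_id, k) for k in key_jsons],
        )

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE devices (device_id TEXT PRIMARY KEY)")
conn.execute("CREATE TABLE keys (device_id TEXT, key_json TEXT)")
store_device_and_keys(conn, "DEHYDRATED", ['{"key": "<base64>"}'])
# had the key insert failed, the devices row would have been rolled back too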

async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
9 changes: 5 additions & 4 deletions tests/handlers/test_device.py
@@ -566,15 +566,16 @@ def test_dehydrate_v2_and_fetch_events(self) -> None:
self.assertEqual(len(res["events"]), 1)
self.assertEqual(res["events"][0]["content"]["body"], "foo")

-        # Fetch the message of the dehydrated device again, which should return nothing
-        # and delete the old messages
+        # Fetch the message of the dehydrated device again, which should return
+        # the same message as it has not been deleted
res = self.get_success(
self.message_handler.get_events_for_dehydrated_device(
requester=requester,
device_id=stored_dehydrated_device_id,
-                since_token=res["next_batch"],
+                since_token=None,
limit=10,
)
)
self.assertTrue(len(res["next_batch"]) > 1)
-        self.assertEqual(len(res["events"]), 0)
+        self.assertEqual(len(res["events"]), 1)
+        self.assertEqual(res["events"][0]["content"]["body"], "foo")