Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add a type hint for get_device_handler() and deal with the fallout. #14055

Merged
merged 16 commits into from
Nov 22, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/14055.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to `HomeServer`.
6 changes: 5 additions & 1 deletion synapse/handlers/deactivate_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import TYPE_CHECKING, Optional

from synapse.api.errors import SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import Codes, Requester, UserID, create_requester

Expand All @@ -32,7 +33,10 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.hs = hs
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
# This can only be instantiated on the main process.
device_handler = hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
self._device_handler = device_handler
self._room_member_handler = hs.get_room_member_handler()
self._identity_handler = hs.get_identity_handler()
self._profile_handler = hs.get_profile_handler()
Expand Down
39 changes: 37 additions & 2 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@


class DeviceWorkerHandler:
device_list_updater: "DeviceListWorkerUpdater"

def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.hs = hs
Expand All @@ -76,6 +78,8 @@ def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self._msc3852_enabled = hs.config.experimental.msc3852_enabled

self.device_list_updater = DeviceListWorkerUpdater(hs)

@trace
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
"""
Expand Down Expand Up @@ -127,7 +131,7 @@ async def get_device(self, user_id: str, device_id: str) -> JsonDict:
@cancellable
async def get_device_changes_in_shared_rooms(
self, user_id: str, room_ids: Collection[str], from_token: StreamToken
) -> Collection[str]:
) -> Set[str]:
"""Get the set of users whose devices have changed who share a room with
the given user.
"""
Expand Down Expand Up @@ -320,6 +324,8 @@ async def handle_room_un_partial_stated(self, room_id: str) -> None:


class DeviceHandler(DeviceWorkerHandler):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the difference between DeviceHandler and DeviceWorkerHandler now? Has this changed as a result of your patch here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved get_dehydrated_device from DeviceHandler to DeviceWorkerHandler. The only other change is that their device_list_updater are now DeviceListUpdater or DeviceListWorkerUpdater, respectively. The latter has an overridden method (user_device_resync) to call across replication to the main process.

device_list_updater: "DeviceListUpdater"

def __init__(self, hs: "HomeServer"):
super().__init__(hs)

Expand Down Expand Up @@ -858,7 +864,36 @@ def _update_device_from_client_ips(
)


class DeviceListUpdater:
class DeviceListWorkerUpdater:
"Handles incoming device list updates from federation and contacts the main process over replication"

def __init__(self, hs: "HomeServer"):
from synapse.replication.http.devices import (
ReplicationUserDevicesResyncRestServlet,
)

self._user_device_resync_client = (
ReplicationUserDevicesResyncRestServlet.make_client(hs)
)

async def user_device_resync(
self, user_id: str, mark_failed_as_stale: bool = True
) -> Optional[JsonDict]:
"""Fetches all devices for a user and updates the device cache with them.

Args:
user_id: The user's id whose device_list will be updated.
mark_failed_as_stale: Whether to mark the user's device list as stale
if the attempt to resync failed.
Returns:
A dict with device info as under the "devices" in the result of this
request:
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
"""
return await self._user_device_resync_client(user_id=user_id)


class DeviceListUpdater(DeviceListWorkerUpdater):
"Handles incoming device list updates from federation and updates the DB"

def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
Expand Down
61 changes: 32 additions & 29 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@

from synapse.api.constants import EduTypes
from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import (
JsonDict,
UserID,
Expand All @@ -55,27 +55,23 @@ def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
self.clock = hs.get_clock()

self._edu_updater = SigningKeyEduUpdater(hs, self)

federation_registry = hs.get_federation_registry()

self._is_master = hs.config.worker.worker_app is None
if not self._is_master:
self._user_device_resync_client = (
ReplicationUserDevicesResyncRestServlet.make_client(hs)
)
else:
is_master = hs.config.worker.worker_app is None
if is_master:
edu_updater = SigningKeyEduUpdater(hs)

# Only register this edu handler on master as it requires writing
# device updates to the db
federation_registry.register_edu_handler(
EduTypes.SIGNING_KEY_UPDATE,
self._edu_updater.incoming_signing_key_update,
edu_updater.incoming_signing_key_update,
)
# also handle the unstable version
# FIXME: remove this when enough servers have upgraded
federation_registry.register_edu_handler(
EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
self._edu_updater.incoming_signing_key_update,
edu_updater.incoming_signing_key_update,
)

# doesn't really work as part of the generic query API, because the
Expand Down Expand Up @@ -318,14 +314,13 @@ async def _query_devices_for_destination(
# probably be tracking their device lists. However, we haven't
# done an initial sync on the device list so we do it now.
try:
if self._is_master:
resync_results = await self.device_handler.device_list_updater.user_device_resync(
resync_results = (
await self.device_handler.device_list_updater.user_device_resync(
user_id
)
else:
resync_results = await self._user_device_resync_client(
user_id=user_id
)
)
if resync_results is None:
raise ValueError("Device resync failed")

# Add the device keys to the results.
user_devices = resync_results["devices"]
Expand Down Expand Up @@ -576,6 +571,8 @@ async def claim_client_keys(destination: str) -> None:
async def upload_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
) -> JsonDict:
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

time_now = self.clock.time_msec()

Expand Down Expand Up @@ -703,6 +700,8 @@ async def upload_signing_keys_for_user(
user_id: the user uploading the keys
keys: the signing keys
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

# if a master key is uploaded, then check it. Otherwise, load the
# stored master key, to check signatures on other keys
Expand Down Expand Up @@ -794,6 +793,9 @@ async def upload_signatures_for_device_keys(
Raises:
SynapseError: if the signatures dict is not valid.
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

failures = {}

# signatures to be stored. Each item will be a SignatureListItem
Expand Down Expand Up @@ -1171,6 +1173,9 @@ async def _retrieve_cross_signing_keys_for_remote_user(
A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
If the key cannot be retrieved, all values in the tuple will instead be None.
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

try:
remote_result = await self.federation.query_user_devices(
user.domain, user.to_string()
Expand Down Expand Up @@ -1367,11 +1372,14 @@ class SignatureListItem:
class SigningKeyEduUpdater:
"""Handles incoming signing key updates from federation and updates the DB"""

def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
self.e2e_keys_handler = e2e_keys_handler

device_handler = hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
self._device_handler = device_handler

self._remote_edu_linearizer = Linearizer(name="remote_signing_key")

Expand Down Expand Up @@ -1416,9 +1424,6 @@ async def _handle_signing_key_updates(self, user_id: str) -> None:
user_id: the user whose updates we are processing
"""

device_handler = self.e2e_keys_handler.device_handler
device_list_updater = device_handler.device_list_updater

async with self._remote_edu_linearizer.queue(user_id):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
Expand All @@ -1430,13 +1435,11 @@ async def _handle_signing_key_updates(self, user_id: str) -> None:
logger.info("pending updates: %r", pending_updates)

for master_key, self_signing_key in pending_updates:
new_device_ids = (
await device_list_updater.process_cross_signing_key_update(
user_id,
master_key,
self_signing_key,
)
new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update(
user_id,
master_key,
self_signing_key,
)
device_ids = device_ids + new_device_ids

await device_handler.notify_device_update(user_id, device_ids)
await self._device_handler.notify_device_update(user_id, device_ids)
4 changes: 4 additions & 0 deletions synapse/handlers/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from synapse.appservice import ApplicationService
from synapse.config.server import is_threepid_reserved
from synapse.handlers.device import DeviceHandler
from synapse.http.servlet import assert_params_in_dict
from synapse.replication.http.login import RegisterDeviceReplicationServlet
from synapse.replication.http.register import (
Expand Down Expand Up @@ -841,6 +842,9 @@ class and RegisterDeviceReplicationServlet.
refresh_token = None
refresh_token_id = None

# This can only run on the main process.
assert isinstance(self.device_handler, DeviceHandler)

registered_device_id = await self.device_handler.check_device_registered(
user_id,
device_id,
Expand Down
6 changes: 5 additions & 1 deletion synapse/handlers/set_password.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import TYPE_CHECKING, Optional

from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.types import Requester

if TYPE_CHECKING:
Expand All @@ -29,7 +30,10 @@ class SetPasswordHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
# This can only be instantiated on the main process.
device_handler = hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
self._device_handler = device_handler

async def set_password(
self,
Expand Down
10 changes: 9 additions & 1 deletion synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
ON_LOGGED_OUT_CALLBACK,
AuthHandler,
)
from synapse.handlers.device import DeviceHandler
from synapse.handlers.push_rules import RuleSpec, check_actions
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
Expand Down Expand Up @@ -207,6 +208,7 @@ def __init__(self, hs: "HomeServer", auth_handler: AuthHandler) -> None:
self._registration_handler = hs.get_registration_handler()
self._send_email_handler = hs.get_send_email_handler()
self._push_rules_handler = hs.get_push_rules_handler()
self._device_handler = hs.get_device_handler()
self.custom_template_dir = hs.config.server.custom_template_directory

try:
Expand Down Expand Up @@ -781,6 +783,8 @@ def invalidate_access_token(
) -> Generator["defer.Deferred[Any]", Any, None]:
"""Invalidate an access token for a user

Can only be called from the main process.
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

Added in Synapse v0.25.0.

Args:
Expand All @@ -793,6 +797,10 @@ def invalidate_access_token(
Raises:
synapse.api.errors.AuthError: the access token is invalid
"""
assert isinstance(
self._device_handler, DeviceHandler
), "invalidate_access_token can only be called on the main process"

# see if the access token corresponds to a device
user_info = yield defer.ensureDeferred(
self._auth.get_user_by_access_token(access_token)
Expand All @@ -802,7 +810,7 @@ def invalidate_access_token(
if device_id:
# delete the device, which will also delete its access tokens
yield defer.ensureDeferred(
self._hs.get_device_handler().delete_devices(user_id, [device_id])
self._device_handler.delete_devices(user_id, [device_id])
)
else:
# no associated device. Just delete the access token.
Expand Down
11 changes: 8 additions & 3 deletions synapse/replication/http/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Optional, Tuple

from twisted.web.server import Request

Expand Down Expand Up @@ -62,7 +62,12 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.device_list_updater = hs.get_device_handler().device_list_updater
from synapse.handlers.device import DeviceHandler

handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_list_updater = handler.device_list_updater

self.store = hs.get_datastores().main
self.clock = hs.get_clock()

Expand All @@ -72,7 +77,7 @@ async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override

async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
) -> Tuple[int, JsonDict]:
) -> Tuple[int, Optional[JsonDict]]:
user_devices = await self.device_list_updater.user_device_resync(user_id)

return 200, user_devices
Expand Down
6 changes: 3 additions & 3 deletions synapse/rest/admin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserTokenRestServlet(hs).register(http_server)
UserRestServletV2(hs).register(http_server)
UsersRestServletV2(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeleteDevicesRestServlet(hs).register(http_server)
UserMediaStatisticsRestServlet(hs).register(http_server)
EventReportDetailRestServlet(hs).register(http_server)
EventReportsRestServlet(hs).register(http_server)
Expand All @@ -280,6 +277,9 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:

# Some servlets only get registered for the main process.
if hs.config.worker.worker_app is None:
DeviceRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeleteDevicesRestServlet(hs).register(http_server)
SendServerNoticeServlet(hs).register(http_server)
BackgroundUpdateEnabledRestServlet(hs).register(http_server)
BackgroundUpdateRestServlet(hs).register(http_server)
Expand Down
Loading