Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add a type hint for get_device_handler() and deal with the fallout. #14055

Merged
merged 16 commits into from
Nov 22, 2022
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/14055.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to `HomeServer`.
4 changes: 4 additions & 0 deletions synapse/handlers/deactivate_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import TYPE_CHECKING, Optional

from synapse.api.errors import SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import Codes, Requester, UserID, create_requester

Expand Down Expand Up @@ -76,6 +77,9 @@ async def deactivate_account(
True if identity server supports removing threepids, otherwise False.
"""

# This can only be called on the main process.
assert isinstance(self._device_handler, DeviceHandler)

# Check if this user can be deactivated
if not await self._third_party_rules.check_can_deactivate_user(
user_id, by_admin
Expand Down
39 changes: 37 additions & 2 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@


class DeviceWorkerHandler:
device_list_updater: "DeviceListWorkerUpdater"

def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.hs = hs
Expand All @@ -76,6 +78,8 @@ def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self._msc3852_enabled = hs.config.experimental.msc3852_enabled

self.device_list_updater = DeviceListWorkerUpdater(hs)

@trace
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
"""
Expand Down Expand Up @@ -127,7 +131,7 @@ async def get_device(self, user_id: str, device_id: str) -> JsonDict:
@cancellable
async def get_device_changes_in_shared_rooms(
self, user_id: str, room_ids: Collection[str], from_token: StreamToken
) -> Collection[str]:
) -> Set[str]:
"""Get the set of users whose devices have changed who share a room with
the given user.
"""
Expand Down Expand Up @@ -320,6 +324,8 @@ async def handle_room_un_partial_stated(self, room_id: str) -> None:


class DeviceHandler(DeviceWorkerHandler):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the difference between DeviceHandler and DeviceWorkerHandler now? Has this changed as a result of your patch here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved get_dehydrated_device from DeviceHandler to DeviceWorkerHandler. The only other change is that their device_list_updater are now DeviceListUpdater or DeviceListWorkerUpdater, respectively. The latter has an overridden method (user_device_resync) to call across replication to the main process.

device_list_updater: "DeviceListUpdater"

def __init__(self, hs: "HomeServer"):
super().__init__(hs)

Expand Down Expand Up @@ -858,7 +864,36 @@ def _update_device_from_client_ips(
)


class DeviceListUpdater:
class DeviceListWorkerUpdater:
"Handles incoming device list updates from federation and contacts the main process over replication"

def __init__(self, hs: "HomeServer"):
from synapse.replication.http.devices import (
ReplicationUserDevicesResyncRestServlet,
)

self._user_device_resync_client = (
ReplicationUserDevicesResyncRestServlet.make_client(hs)
)

async def user_device_resync(
self, user_id: str, mark_failed_as_stale: bool = True
) -> Optional[JsonDict]:
"""Fetches all devices for a user and updates the device cache with them.

Args:
user_id: The user's id whose device_list will be updated.
mark_failed_as_stale: Whether to mark the user's device list as stale
if the attempt to resync failed.
Returns:
A dict with device info as under the "devices" in the result of this
request:
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
"""
return await self._user_device_resync_client(user_id=user_id)


class DeviceListUpdater(DeviceListWorkerUpdater):
"Handles incoming device list updates from federation and updates the DB"

def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
Expand Down
61 changes: 32 additions & 29 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@

from synapse.api.constants import EduTypes
from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import (
JsonDict,
UserID,
Expand All @@ -56,27 +56,23 @@ def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
self.clock = hs.get_clock()

self._edu_updater = SigningKeyEduUpdater(hs, self)

federation_registry = hs.get_federation_registry()

self._is_master = hs.config.worker.worker_app is None
if not self._is_master:
self._user_device_resync_client = (
ReplicationUserDevicesResyncRestServlet.make_client(hs)
)
else:
is_master = hs.config.worker.worker_app is None
if is_master:
edu_updater = SigningKeyEduUpdater(hs)

# Only register this edu handler on master as it requires writing
# device updates to the db
federation_registry.register_edu_handler(
EduTypes.SIGNING_KEY_UPDATE,
self._edu_updater.incoming_signing_key_update,
edu_updater.incoming_signing_key_update,
)
# also handle the unstable version
# FIXME: remove this when enough servers have upgraded
federation_registry.register_edu_handler(
EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
self._edu_updater.incoming_signing_key_update,
edu_updater.incoming_signing_key_update,
)

# doesn't really work as part of the generic query API, because the
Expand Down Expand Up @@ -319,14 +315,13 @@ async def _query_devices_for_destination(
# probably be tracking their device lists. However, we haven't
# done an initial sync on the device list so we do it now.
try:
if self._is_master:
resync_results = await self.device_handler.device_list_updater.user_device_resync(
resync_results = (
await self.device_handler.device_list_updater.user_device_resync(
user_id
)
else:
resync_results = await self._user_device_resync_client(
user_id=user_id
)
)
if resync_results is None:
raise ValueError("Device resync failed")

# Add the device keys to the results.
user_devices = resync_results["devices"]
Expand Down Expand Up @@ -605,6 +600,8 @@ async def claim_client_keys(destination: str) -> None:
async def upload_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
) -> JsonDict:
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

time_now = self.clock.time_msec()

Expand Down Expand Up @@ -732,6 +729,8 @@ async def upload_signing_keys_for_user(
user_id: the user uploading the keys
keys: the signing keys
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

# if a master key is uploaded, then check it. Otherwise, load the
# stored master key, to check signatures on other keys
Expand Down Expand Up @@ -823,6 +822,9 @@ async def upload_signatures_for_device_keys(
Raises:
SynapseError: if the signatures dict is not valid.
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

failures = {}

# signatures to be stored. Each item will be a SignatureListItem
Expand Down Expand Up @@ -1200,6 +1202,9 @@ async def _retrieve_cross_signing_keys_for_remote_user(
A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
If the key cannot be retrieved, all values in the tuple will instead be None.
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

try:
remote_result = await self.federation.query_user_devices(
user.domain, user.to_string()
Expand Down Expand Up @@ -1396,11 +1401,14 @@ class SignatureListItem:
class SigningKeyEduUpdater:
"""Handles incoming signing key updates from federation and updates the DB"""

def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
self.e2e_keys_handler = e2e_keys_handler

device_handler = hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
self._device_handler = device_handler

self._remote_edu_linearizer = Linearizer(name="remote_signing_key")

Expand Down Expand Up @@ -1445,9 +1453,6 @@ async def _handle_signing_key_updates(self, user_id: str) -> None:
user_id: the user whose updates we are processing
"""

device_handler = self.e2e_keys_handler.device_handler
device_list_updater = device_handler.device_list_updater

async with self._remote_edu_linearizer.queue(user_id):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
Expand All @@ -1459,13 +1464,11 @@ async def _handle_signing_key_updates(self, user_id: str) -> None:
logger.info("pending updates: %r", pending_updates)

for master_key, self_signing_key in pending_updates:
new_device_ids = (
await device_list_updater.process_cross_signing_key_update(
user_id,
master_key,
self_signing_key,
)
new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update(
user_id,
master_key,
self_signing_key,
)
device_ids = device_ids + new_device_ids

await device_handler.notify_device_update(user_id, device_ids)
await self._device_handler.notify_device_update(user_id, device_ids)
4 changes: 4 additions & 0 deletions synapse/handlers/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from synapse.appservice import ApplicationService
from synapse.config.server import is_threepid_reserved
from synapse.handlers.device import DeviceHandler
from synapse.http.servlet import assert_params_in_dict
from synapse.replication.http.login import RegisterDeviceReplicationServlet
from synapse.replication.http.register import (
Expand Down Expand Up @@ -841,6 +842,9 @@ class and RegisterDeviceReplicationServlet.
refresh_token = None
refresh_token_id = None

# This can only run on the main process.
assert isinstance(self.device_handler, DeviceHandler)

registered_device_id = await self.device_handler.check_device_registered(
user_id,
device_id,
Expand Down
6 changes: 5 additions & 1 deletion synapse/handlers/set_password.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import TYPE_CHECKING, Optional

from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.types import Requester

if TYPE_CHECKING:
Expand All @@ -29,7 +30,10 @@ class SetPasswordHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
# This can only be instantiated on the main process.
device_handler = hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
self._device_handler = device_handler

async def set_password(
self,
Expand Down
9 changes: 9 additions & 0 deletions synapse/handlers/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.config.sso import SsoAttributeRequirement
from synapse.handlers.device import DeviceHandler
from synapse.handlers.register import init_counters_for_auth_provider
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
Expand Down Expand Up @@ -1035,13 +1036,21 @@ async def revoke_sessions_for_provider_session_id(
) -> None:
"""Revoke any devices and in-flight logins tied to a provider session.

Can only be called from the main process.

Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
auth_provider_session_id: The session ID from the provider to logout
expected_user_id: The user we're expecting to logout. If set, it will ignore
sessions belonging to other users and log an error.
"""

# It is expected that this is the main process.
assert isinstance(
self._device_handler, DeviceHandler
), "revoking SSO sessions can only be called on the main process"

# Invalidate any running user-mapping sessions
to_delete = []
for session_id, session in self._username_mapping_sessions.items():
Expand Down
10 changes: 9 additions & 1 deletion synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
ON_LOGGED_OUT_CALLBACK,
AuthHandler,
)
from synapse.handlers.device import DeviceHandler
from synapse.handlers.push_rules import RuleSpec, check_actions
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
Expand Down Expand Up @@ -207,6 +208,7 @@ def __init__(self, hs: "HomeServer", auth_handler: AuthHandler) -> None:
self._registration_handler = hs.get_registration_handler()
self._send_email_handler = hs.get_send_email_handler()
self._push_rules_handler = hs.get_push_rules_handler()
self._device_handler = hs.get_device_handler()
self.custom_template_dir = hs.config.server.custom_template_directory

try:
Expand Down Expand Up @@ -784,6 +786,8 @@ def invalidate_access_token(
) -> Generator["defer.Deferred[Any]", Any, None]:
"""Invalidate an access token for a user

Can only be called from the main process.
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

Added in Synapse v0.25.0.

Args:
Expand All @@ -796,6 +800,10 @@ def invalidate_access_token(
Raises:
synapse.api.errors.AuthError: the access token is invalid
"""
assert isinstance(
self._device_handler, DeviceHandler
), "invalidate_access_token can only be called on the main process"

# see if the access token corresponds to a device
user_info = yield defer.ensureDeferred(
self._auth.get_user_by_access_token(access_token)
Expand All @@ -805,7 +813,7 @@ def invalidate_access_token(
if device_id:
# delete the device, which will also delete its access tokens
yield defer.ensureDeferred(
self._hs.get_device_handler().delete_devices(user_id, [device_id])
self._device_handler.delete_devices(user_id, [device_id])
)
else:
# no associated device. Just delete the access token.
Expand Down
11 changes: 8 additions & 3 deletions synapse/replication/http/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Optional, Tuple

from twisted.web.server import Request

Expand Down Expand Up @@ -63,7 +63,12 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.device_list_updater = hs.get_device_handler().device_list_updater
from synapse.handlers.device import DeviceHandler

handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_list_updater = handler.device_list_updater

self.store = hs.get_datastores().main
self.clock = hs.get_clock()

Expand All @@ -73,7 +78,7 @@ async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override

async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
) -> Tuple[int, JsonDict]:
) -> Tuple[int, Optional[JsonDict]]:
user_devices = await self.device_list_updater.user_device_resync(user_id)

return 200, user_devices
Expand Down
Loading