Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add typing information to the device handler. #8407

Merged
merged 7 commits into from
Oct 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8407.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add typing information to the device handler.
1 change: 1 addition & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ files =
synapse/federation,
synapse/handlers/auth.py,
synapse/handlers/cas_handler.py,
synapse/handlers/device.py,
synapse/handlers/directory.py,
synapse/handlers/events.py,
synapse/handlers/federation.py,
Expand Down
89 changes: 54 additions & 35 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple

from synapse.api import errors
from synapse.api.constants import EventTypes
Expand All @@ -29,8 +29,10 @@
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import (
Collection,
JsonDict,
StreamToken,
UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
)
Expand All @@ -42,13 +44,16 @@

from ._base import BaseHandler

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)

MAX_DEVICE_DISPLAY_NAME_LEN = 100


class DeviceWorkerHandler(BaseHandler):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.hs = hs
Expand Down Expand Up @@ -106,7 +111,9 @@ async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:

@trace
@measure_func("device.get_user_ids_changed")
async def get_user_ids_changed(self, user_id: str, from_token: StreamToken):
async def get_user_ids_changed(
self, user_id: str, from_token: StreamToken
) -> JsonDict:
"""Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in.
"""
Expand Down Expand Up @@ -222,16 +229,16 @@ async def get_user_ids_changed(self, user_id: str, from_token: StreamToken):
possibly_joined = possibly_changed & users_who_share_room
possibly_left = (possibly_changed | possibly_left) - users_who_share_room
else:
possibly_joined = []
possibly_left = []
possibly_joined = set()
possibly_left = set()

result = {"changed": list(possibly_joined), "left": list(possibly_left)}

log_kv(result)

return result

async def on_federation_query_user_devices(self, user_id):
async def on_federation_query_user_devices(self, user_id: str) -> JsonDict:
stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
user_id
)
Expand All @@ -250,7 +257,7 @@ async def on_federation_query_user_devices(self, user_id):


class DeviceHandler(DeviceWorkerHandler):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.federation_sender = hs.get_federation_sender()
Expand All @@ -265,7 +272,7 @@ def __init__(self, hs):

hs.get_distributor().observe("user_left_room", self.user_left_room)

def _check_device_name_length(self, name: str):
def _check_device_name_length(self, name: Optional[str]):
"""
Checks whether a device name is longer than the maximum allowed length.

Expand All @@ -284,21 +291,23 @@ def _check_device_name_length(self, name: str):
)

async def check_device_registered(
self, user_id, device_id, initial_device_display_name=None
):
self,
user_id: str,
device_id: Optional[str],
initial_device_display_name: Optional[str] = None,
) -> str:
"""
If the given device has not been registered, register it with the
supplied display name.

If no device_id is supplied, we make one up.

Args:
user_id (str): @user:id
device_id (str | None): device id supplied by client
initial_device_display_name (str | None): device display name from
client
user_id: @user:id
device_id: device id supplied by client
initial_device_display_name: device display name from client
Returns:
str: device id (generated if none was supplied)
device id (generated if none was supplied)
"""

self._check_device_name_length(initial_device_display_name)
Expand All @@ -317,15 +326,15 @@ async def check_device_registered(
# times in case of a clash.
attempts = 0
while attempts < 5:
device_id = stringutils.random_string(10).upper()
new_device_id = stringutils.random_string(10).upper()
new_device = await self.store.store_device(
user_id=user_id,
device_id=device_id,
device_id=new_device_id,
initial_device_display_name=initial_device_display_name,
)
if new_device:
await self.notify_device_update(user_id, [device_id])
return device_id
await self.notify_device_update(user_id, [new_device_id])
return new_device_id
attempts += 1

raise errors.StoreError(500, "Couldn't generate a device ID.")
Expand Down Expand Up @@ -434,7 +443,9 @@ async def update_device(self, user_id: str, device_id: str, content: dict) -> No

@trace
@measure_func("notify_device_update")
async def notify_device_update(self, user_id, device_ids):
async def notify_device_update(
self, user_id: str, device_ids: Collection[str]
) -> None:
"""Notify that a user's device(s) has changed. Pokes the notifier, and
remote servers if the user is local.
"""
Expand All @@ -446,7 +457,7 @@ async def notify_device_update(self, user_id, device_ids):
user_id
)

hosts = set()
hosts = set() # type: Set[str]
if self.hs.is_mine_id(user_id):
hosts.update(get_domain_from_id(u) for u in users_who_share_room)
hosts.discard(self.server_name)
Expand Down Expand Up @@ -498,7 +509,7 @@ async def notify_user_signature_update(

self.notifier.on_new_event("device_list_key", position, users=[from_user_id])

async def user_left_room(self, user, room_id):
async def user_left_room(self, user: UserID, room_id: str) -> None:
user_id = user.to_string()
room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
Expand Down Expand Up @@ -586,15 +597,17 @@ async def rehydrate_device(
return {"success": True}


def _update_device_from_client_ips(device, client_ips):
def _update_device_from_client_ips(
device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]]
) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]), {})
device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})


class DeviceListUpdater:
"Handles incoming device list updates from federation and updates the DB"

def __init__(self, hs, device_handler):
def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
self.store = hs.get_datastore()
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
Expand All @@ -603,7 +616,9 @@ def __init__(self, hs, device_handler):
self._remote_edu_linearizer = Linearizer(name="remote_device_list")

# user_id -> list of updates waiting to be handled.
self._pending_updates = {}
self._pending_updates = (
{}
) # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]]

# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious
Expand All @@ -626,7 +641,9 @@ def __init__(self, hs, device_handler):
)

@trace
async def incoming_device_list_update(self, origin, edu_content):
async def incoming_device_list_update(
self, origin: str, edu_content: JsonDict
) -> None:
"""Called on incoming device list update from federation. Responsible
for parsing the EDU and adding to pending updates list.
"""
Expand Down Expand Up @@ -687,7 +704,7 @@ async def incoming_device_list_update(self, origin, edu_content):
await self._handle_device_updates(user_id)

@measure_func("_incoming_device_list_update")
async def _handle_device_updates(self, user_id):
async def _handle_device_updates(self, user_id: str) -> None:
"Actually handle pending updates."

with (await self._remote_edu_linearizer.queue(user_id)):
Expand Down Expand Up @@ -735,7 +752,9 @@ async def _handle_device_updates(self, user_id):
stream_id for _, stream_id, _, _ in pending_updates
)

async def _need_to_do_resync(self, user_id, updates):
async def _need_to_do_resync(
self, user_id: str, updates: Iterable[Tuple[str, str, Iterable[str], JsonDict]]
) -> bool:
"""Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta.
"""
Expand Down Expand Up @@ -766,7 +785,7 @@ async def _need_to_do_resync(self, user_id, updates):
return False

@trace
async def _maybe_retry_device_resync(self):
async def _maybe_retry_device_resync(self) -> None:
"""Retry to resync device lists that are out of sync, except if another retry is
in progress.
"""
Expand Down Expand Up @@ -809,7 +828,7 @@ async def _maybe_retry_device_resync(self):

async def user_device_resync(
self, user_id: str, mark_failed_as_stale: bool = True
) -> Optional[dict]:
) -> Optional[JsonDict]:
"""Fetches all devices for a user and updates the device cache with them.

Args:
Expand All @@ -833,7 +852,7 @@ async def user_device_resync(
# it later.
await self.store.mark_remote_user_device_cache_as_stale(user_id)

return
return None
except (RequestSendFailed, HttpResponseException) as e:
logger.warning(
"Failed to handle device list update for %s: %s", user_id, e,
Expand All @@ -850,12 +869,12 @@ async def user_device_resync(
# next time we get a device list update for this user_id.
# This makes it more likely that the device lists will
# eventually become consistent.
return
return None
except FederationDeniedError as e:
set_tag("error", True)
log_kv({"reason": "FederationDeniedError"})
logger.info(e)
return
return None
except Exception as e:
set_tag("error", True)
log_kv(
Expand All @@ -868,7 +887,7 @@ async def user_device_resync(
# it later.
await self.store.mark_remote_user_device_cache_as_stale(user_id)

return
return None
log_kv({"result": result})
stream_id = result["stream_id"]
devices = result["devices"]
Expand Down Expand Up @@ -929,7 +948,7 @@ async def process_cross_signing_key_update(
user_id: str,
master_key: Optional[Dict[str, Any]],
self_signing_key: Optional[Dict[str, Any]],
) -> list:
) -> List[str]:
"""Process the given new master and self-signing key for the given remote user.

Args:
Expand Down
6 changes: 3 additions & 3 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,7 @@ def __init__(self, database: DatabasePool, db_conn, hs):
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)

async def store_device(
self, user_id: str, device_id: str, initial_device_display_name: str
self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
) -> bool:
"""Ensure the given device is known; add it to the store if not

Expand Down Expand Up @@ -1029,7 +1029,7 @@ async def update_device(
)

async def update_remote_device_list_cache_entry(
self, user_id: str, device_id: str, content: JsonDict, stream_id: int
self, user_id: str, device_id: str, content: JsonDict, stream_id: str
) -> None:
"""Updates a single device in the cache of a remote user's devicelist.

Expand Down Expand Up @@ -1057,7 +1057,7 @@ def _update_remote_device_list_cache_entry_txn(
user_id: str,
device_id: str,
content: JsonDict,
stream_id: int,
stream_id: str,
) -> None:
if content.get("deleted"):
self.db_pool.simple_delete_txn(
Expand Down