This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Type annotations in `synapse.storage.databases.main.devices` #13025

Merged 22 commits on Jun 15, 2022
1 change: 1 addition & 0 deletions changelog.d/13025.misc
@@ -0,0 +1 @@
+Add type annotations to `synapse.storage.databases.main.devices`.
1 change: 0 additions & 1 deletion mypy.ini
@@ -27,7 +27,6 @@ exclude = (?x)
^(
|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/cache.py
-|synapse/storage/databases/main/devices.py
|synapse/storage/schema/

|tests/api/test_auth.py
3 changes: 1 addition & 2 deletions synapse/replication/slave/storage/devices.py
@@ -19,13 +19,12 @@
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.devices import DeviceWorkerStore
-from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore

if TYPE_CHECKING:
from synapse.server import HomeServer


-class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
+class SlavedDeviceStore(DeviceWorkerStore, BaseSlavedStore):
def __init__(
self,
database: DatabasePool,
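For readers following the inheritance shuffle: dropping `EndToEndKeyWorkerStore` from `SlavedDeviceStore`'s bases works because `DeviceWorkerStore` now inherits it, so it is already in the MRO. A minimal sketch with hypothetical class names:

```python
# Once Worker inherits Keys, Keys is already in the MRO of anything
# deriving from Worker, so listing it again is redundant.
class Keys:
    def get_key(self) -> str:
        return "key"


class Worker(Keys):  # DeviceWorkerStore now subclasses EndToEndKeyWorkerStore
    pass


class Slaved(Worker):  # so SlavedDeviceStore need not list Keys again
    pass


print([cls.__name__ for cls in Slaved.__mro__])  # ['Slaved', 'Worker', 'Keys', 'object']
print(Slaved().get_key())  # key
```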
1 change: 1 addition & 0 deletions synapse/storage/databases/main/__init__.py
@@ -197,6 +197,7 @@ def __init__(
self._min_stream_order_on_start = self.get_room_min_stream_ordering()

def get_device_stream_token(self) -> int:
+        # TODO: shouldn't this be moved to `DeviceWorkerStore`?
return self._device_list_id_gen.get_current_token()

async def get_users(self) -> List[JsonDict]:
51 changes: 33 additions & 18 deletions synapse/storage/databases/main/devices.py
@@ -28,6 +28,8 @@
cast,
)

+from typing_extensions import Literal

from synapse.api.constants import EduTypes
from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
@@ -44,6 +46,8 @@
LoggingTransaction,
make_tuple_comparison_clause,
)
+from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.storage.types import Cursor
from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -65,7 +69,7 @@
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"


-class DeviceWorkerStore(SQLBaseStore):
+class DeviceWorkerStore(EndToEndKeyWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -74,7 +78,9 @@ def __init__(
):
super().__init__(database, db_conn, hs)

-        device_list_max = self._device_list_id_gen.get_current_token()
+        # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a
+        # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker).
+        device_list_max = self._device_list_id_gen.get_current_token()  # type: ignore[attr-defined]
device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
db_conn,
"device_lists_stream",
@@ -339,8 +345,9 @@ async def get_device_updates_by_remote(
# following this stream later.
last_processed_stream_id = from_stream_id

-        query_map = {}
-        cross_signing_keys_by_user = {}
+        # A map of (user ID, device ID) to (stream ID, context).
+        query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]] = {}
+        cross_signing_keys_by_user: Dict[str, Dict[str, object]] = {}
for user_id, device_id, update_stream_id, update_context in updates:
# Calculate the remaining length budget.
# Note that, for now, each entry in `cross_signing_keys_by_user`
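These annotations follow the usual mypy idiom for empty literals: `{}` on its own gives the checker nothing to infer from, so the intended key/value shape is spelled out (and doubles as documentation). A tiny illustration, with a hypothetical entry:

```python
from typing import Dict, Optional, Tuple

# Without an annotation, mypy cannot infer a useful type for `{}`.
query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]] = {}
query_map[("@alice:example.org", "DEVICEID")] = (42, None)  # hypothetical entry
```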
@@ -596,7 +603,7 @@ def _mark_as_sent_devices_by_remote_txn(
txn=txn,
table="device_lists_outbound_last_success",
key_names=("destination", "user_id"),
-            key_values=((destination, user_id) for user_id, _ in rows),
+            key_values=[(destination, user_id) for user_id, _ in rows],
value_names=("stream_id",),
value_values=((stream_id,) for _, stream_id in rows),
)
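Swapping the generator expression for a list comprehension matters to the type checker: a generator is a one-shot `Iterator` with no `__len__`, so it does not satisfy a `Collection`-typed parameter. A sketch with a hypothetical `upsert_many` (not Synapse's actual signature):

```python
from typing import Collection, Tuple


def upsert_many(key_values: Collection[Tuple[str, str]]) -> None:
    # A Collection is sized and can be traversed more than once.
    print(len(key_values))
    for _ in key_values:
        pass
    for _ in key_values:  # a second pass still sees every row
        pass


rows = [("@alice:example.org", "1"), ("@bob:example.org", "2")]
upsert_many([("remote.example", user_id) for user_id, _ in rows])  # list: accepted
# upsert_many(("remote.example", u) for u, _ in rows)  # generator: mypy rejects
```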
@@ -621,7 +628,9 @@ async def add_user_signature_change_to_streams(
The new stream ID.
"""

-        async with self._device_list_id_gen.get_next() as stream_id:
+        # TODO: this looks like it's _writing_. Should this be on DeviceStore rather
+        # than DeviceWorkerStore?
+        async with self._device_list_id_gen.get_next() as stream_id:  # type: ignore[attr-defined]
await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
@@ -686,7 +695,7 @@ async def get_user_devices_from_cache(
} - users_needing_resync
user_ids_not_in_cache = user_ids - user_ids_in_cache

-        results = {}
+        results: Dict[str, Dict[str, JsonDict]] = {}
for user_id, device_id in query_list:
if user_id not in user_ids_in_cache:
continue
@@ -727,7 +736,7 @@ async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]
def get_cached_device_list_changes(
self,
from_key: int,
-    ) -> Optional[Set[str]]:
+    ) -> Optional[List[str]]:
"""Get set of users whose devices have changed since `from_key`, or None
if that information is not in our cache.
"""
@@ -737,7 +746,7 @@ async def get_users_whose_devices_changed(
async def get_users_whose_devices_changed(
self,
from_key: int,
-        user_ids: Optional[Iterable[str]] = None,
+        user_ids: Optional[Collection[str]] = None,
to_key: Optional[int] = None,
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
@@ -757,6 +766,7 @@ async def get_users_whose_devices_changed(
"""
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
+        user_ids_to_check: Optional[Collection[str]]
if user_ids is None:
# Get set of all users that have had device list changes since 'from_key'
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
@@ -772,7 +782,7 @@
return set()

def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
-            changes = set()
+            changes: Set[str] = set()

stream_id_where_clause = "stream_id > ?"
sql_args = [from_key]
@@ -788,6 +798,9 @@ def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
"""

# Query device changes with a batch of users at a time
+            # Assertion for mypy's benefit; see also
+            # https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
+            assert user_ids_to_check is not None
for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
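The assert works around the mypy limitation linked above: narrowing established in the outer function does not carry into a closure, because mypy cannot rule out reassignment before the inner function runs. A minimal reproduction (hypothetical function names):

```python
from typing import Collection, Optional


def get_changed(user_ids: Optional[Collection[str]]) -> int:
    if user_ids is None:
        user_ids = ("@everyone:example.org",)
    # user_ids is narrowed to Collection[str] out here...

    def txn() -> int:
        # ...but inside the closure mypy forgets the narrowing, so the
        # inner function must re-assert it before use.
        assert user_ids is not None
        return len(user_ids)

    return txn()


print(get_changed(None))  # 1
```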
@@ -854,7 +867,9 @@ async def get_all_device_list_changes_for_remotes(
if last_id == current_id:
return [], current_id, False

-        def _get_all_device_list_changes_for_remotes(txn):
+        def _get_all_device_list_changes_for_remotes(
+            txn: Cursor,
+        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
# This query Does The Right Thing where it'll correctly apply the
# bounds to the inner queries.
sql = """
@@ -913,7 +928,7 @@ async def get_device_list_last_stream_id_for_remotes(
desc="get_device_list_last_stream_id_for_remotes",
)

-        results = {user_id: None for user_id in user_ids}
+        results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
results.update({row["user_id"]: row["stream_id"] for row in rows})

return results
@@ -1346,9 +1361,9 @@ def __init__(

# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
-        self.device_id_exists_cache = LruCache(
-            cache_name="device_id_exists", max_size=10000
-        )
+        self.device_id_exists_cache: LruCache[
+            Tuple[str, str], Literal[True]
+        ] = LruCache(cache_name="device_id_exists", max_size=10000)

async def store_device(
self,
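The `Literal[True]` value type records that this cache is used like a set: an entry is only ever `True`, and a miss means "unknown", not "the device does not exist". A hedged sketch with a stand-in cache class (Synapse's `LruCache` is likewise generic over key and value types, but has a richer API):

```python
from typing import Dict, Generic, Optional, Tuple, TypeVar

from typing_extensions import Literal

K = TypeVar("K")
V = TypeVar("V")


class TinyCache(Generic[K, V]):
    """A stand-in for LruCache, just enough to show the annotation."""

    def __init__(self) -> None:
        self._data: Dict[K, V] = {}

    def set(self, key: K, value: V) -> None:
        self._data[key] = value

    def get(self, key: K) -> Optional[V]:
        return self._data.get(key)


# Set-like presence cache: the only storable value is True.
device_id_exists: TinyCache[Tuple[str, str], Literal[True]] = TinyCache()
device_id_exists.set(("@alice:example.org", "DEVICEID"), True)
print(device_id_exists.get(("@alice:example.org", "DEVICEID")))  # True
```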
@@ -1660,7 +1675,7 @@ def add_device_changes_txn(
context,
)

-        async with self._device_list_id_gen.get_next_mult(
+        async with self._device_list_id_gen.get_next_mult(  # type: ignore[attr-defined]
len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
@@ -1713,7 +1728,7 @@ def _add_device_outbound_poke_to_stream_txn(
device_ids: Iterable[str],
hosts: Collection[str],
stream_ids: List[int],
-        context: Dict[str, str],
+        context: Optional[Dict[str, str]],
) -> None:
for host in hosts:
txn.call_after(
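Widening `context` to `Optional` reflects that callers can pass `None`; mypy then insists on a guard before the value is used as a dict. A tiny sketch with a hypothetical helper:

```python
import json
from typing import Dict, Optional


def encode_context(context: Optional[Dict[str, str]]) -> Optional[str]:
    # With the Optional annotation, mypy requires a None check before
    # treating `context` as a dict.
    if context is None:
        return None
    return json.dumps(context)


print(encode_context(None))           # None
print(encode_context({"ts": "123"}))  # {"ts": "123"}
```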
@@ -1884,7 +1899,7 @@ def add_device_list_outbound_pokes_txn(
[],
)

-        async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
+        async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:  # type: ignore[attr-defined]
return await self.db_pool.runInteraction(
"add_device_list_outbound_pokes",
add_device_list_outbound_pokes_txn,