
Add type hints to synapse/storage/databases/main #11984

Merged: 7 commits, Feb 21, 2022
1 change: 1 addition & 0 deletions changelog.d/11984.misc
@@ -0,0 +1 @@
Add missing type hints to storage classes.
3 changes: 0 additions & 3 deletions mypy.ini
@@ -31,14 +31,11 @@ exclude = (?x)
|synapse/storage/databases/main/group_server.py
|synapse/storage/databases/main/metrics.py
|synapse/storage/databases/main/monthly_active_users.py
|synapse/storage/databases/main/presence.py
|synapse/storage/databases/main/purge_events.py
|synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/receipts.py
|synapse/storage/databases/main/roommember.py
|synapse/storage/databases/main/search.py
|synapse/storage/databases/main/state.py
|synapse/storage/databases/main/user_directory.py
|synapse/storage/schema/

|tests/api/test_auth.py
26 changes: 14 additions & 12 deletions synapse/handlers/presence.py
@@ -204,25 +204,27 @@ async def current_state_for_users(
Returns:
dict: `user_id` -> `UserPresenceState`
"""
states = {
user_id: self.user_to_current_state.get(user_id, None)
for user_id in user_ids
}
states = {}
missing = []
for user_id in user_ids:
state = self.user_to_current_state.get(user_id, None)
if state:
states[user_id] = state
else:
missing.append(user_id)

missing = [user_id for user_id, state in states.items() if not state]
if missing:
# There are things not in our in memory cache. Lets pull them out of
# the database.
res = await self.store.get_presence_for_users(missing)
states.update(res)

missing = [user_id for user_id, state in states.items() if not state]
if missing:
new = {
user_id: UserPresenceState.default(user_id) for user_id in missing
}
states.update(new)
self.user_to_current_state.update(new)
for user_id in missing:
# if user has no state in database, create the state
if not res.get(user_id, None):
new = {user_id: UserPresenceState.default(user_id)}
states.update(new)
self.user_to_current_state.update(new)

return states

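The hunk above replaces the old two-pass dict comprehension (which stored None placeholders and filtered them out afterwards) with a single loop that splits cache hits from misses, falls back to the database for the misses, and finally fills in a default presence state for anything still unknown. A minimal sketch of that cache, then database, then default lookup pattern; the names cache, fetch_from_db and make_default are illustrative stand-ins, not Synapse APIs:

from typing import Callable, Dict, Iterable, List

def lookup_states(
    cache: Dict[str, str],
    user_ids: Iterable[str],
    fetch_from_db: Callable[[List[str]], Dict[str, str]],
    make_default: Callable[[str], str],
) -> Dict[str, str]:
    # First pass: serve whatever is already in the in-memory cache.
    states: Dict[str, str] = {}
    missing: List[str] = []
    for user_id in user_ids:
        state = cache.get(user_id)
        if state:
            states[user_id] = state
        else:
            missing.append(user_id)

    if missing:
        # Second pass: fetch the cache misses from the database in one batch.
        from_db = fetch_from_db(missing)
        states.update(from_db)

        # Final pass: anything still unknown gets a default state, which is
        # also written back to the cache so later calls take the fast path.
        for user_id in missing:
            if not from_db.get(user_id):
                default = make_default(user_id)
                states[user_id] = default
                cache[user_id] = default

    return states

print(lookup_states({"@a:x": "online"}, ["@a:x", "@b:x"],
                    lambda ids: {}, lambda uid: "offline"))
# {'@a:x': 'online', '@b:x': 'offline'}

Building the missing list inside the first loop means states never holds None values, which is what allows the dictionary to be typed as a plain mapping of user ID to presence state without any Optional filtering.
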
61 changes: 41 additions & 20 deletions synapse/storage/databases/main/presence.py
@@ -12,15 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast

from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.iterutils import batch_iter
@@ -35,7 +43,7 @@ def __init__(
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
) -> None:
super().__init__(database, db_conn, hs)

# Used by `PresenceStore._get_active_presence()`
@@ -54,11 +62,14 @@ def __init__(
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
) -> None:
super().__init__(database, db_conn, hs)

self._instance_name = hs.get_instance_name()
self._presence_id_gen: AbstractStreamIdGenerator

self._can_persist_presence = (
hs.get_instance_name() in hs.config.worker.writers.presence
self._instance_name in hs.config.worker.writers.presence
)

if isinstance(database.engine, PostgresEngine):
@@ -109,7 +120,9 @@ async def update_presence(self, presence_states) -> Tuple[int, int]:

return stream_orderings[-1], self._presence_id_gen.get_current_token()

def _update_presence_txn(self, txn, stream_orderings, presence_states):
def _update_presence_txn(
self, txn: LoggingTransaction, stream_orderings, presence_states
) -> None:
for stream_id, state in zip(stream_orderings, presence_states):
txn.call_after(
self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
@@ -183,19 +196,23 @@ async def get_all_presence_updates(
if last_id == current_id:
return [], current_id, False

def get_all_presence_updates_txn(txn):
def get_all_presence_updates_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, list]], int, bool]:
sql = """
SELECT stream_id, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts,
status_msg,
currently_active
status_msg, currently_active
FROM presence_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
updates = [(row[0], row[1:]) for row in txn]
updates = cast(
List[Tuple[int, list]],
[(row[0], row[1:]) for row in txn],
)

upper_bound = current_id
limited = False
@@ -210,15 +227,17 @@ def get_all_presence_updates_txn(txn):
)

@cached()
def _get_presence_for_user(self, user_id):
def _get_presence_for_user(self, user_id: str) -> None:
raise NotImplementedError()

@cachedList(
cached_method_name="_get_presence_for_user",
list_name="user_ids",
num_args=1,
)
async def get_presence_for_users(self, user_ids):
async def get_presence_for_users(
self, user_ids: Iterable[str]
) -> Dict[str, UserPresenceState]:
rows = await self.db_pool.simple_select_many_batch(
table="presence_stream",
column="user_id",
@@ -257,7 +276,9 @@ async def should_user_receive_full_presence_with_token(
True if the user should have full presence sent to them, False otherwise.
"""

def _should_user_receive_full_presence_with_token_txn(txn):
def _should_user_receive_full_presence_with_token_txn(
txn: LoggingTransaction,
) -> bool:
sql = """
SELECT 1 FROM users_to_send_full_presence_to
WHERE user_id = ?
Expand All @@ -271,7 +292,7 @@ def _should_user_receive_full_presence_with_token_txn(txn):
_should_user_receive_full_presence_with_token_txn,
)

async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None:
"""Adds to the list of users who should receive a full snapshot of presence
upon their next sync.

@@ -353,10 +374,10 @@ async def get_presence_for_all_users(

return users_to_state

def get_current_presence_token(self):
def get_current_presence_token(self) -> int:
return self._presence_id_gen.get_current_token()

def _get_active_presence(self, db_conn: Connection):
def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
"""Fetch non-offline presence from the database so that we can register
the appropriate time outs.
"""
@@ -379,12 +400,12 @@ def _get_active_presence(self, db_conn: Connection):

return [UserPresenceState(**row) for row in rows]

def take_presence_startup_info(self):
def take_presence_startup_info(self) -> List[UserPresenceState]:
active_on_startup = self._presence_on_startup
self._presence_on_startup = None
self._presence_on_startup = []
return active_on_startup

def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows) -> None:
if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(instance_name, token)
for row in rows:
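
A recurring pattern in the hunks above is declaring the attribute up front as self._presence_id_gen: AbstractStreamIdGenerator, so that mypy accepts either concrete generator assigned in the Postgres and SQLite branches while still seeing the shared interface at every use site. A small, hedged sketch of that idiom outside Synapse; the class names are illustrative stand-ins, not the real id generators:

from abc import ABC, abstractmethod

class AbstractIdGenerator(ABC):
    @abstractmethod
    def get_current_token(self) -> int:
        ...

class MultiWriterIdGen(AbstractIdGenerator):
    def get_current_token(self) -> int:
        return 42

class SingleWriterIdGen(AbstractIdGenerator):
    def get_current_token(self) -> int:
        return 7

class PresenceLikeStore:
    def __init__(self, multi_writer: bool) -> None:
        # Annotate once with the abstract base type; both assignments below
        # then type-check, and callers see only the common interface.
        self._id_gen: AbstractIdGenerator
        if multi_writer:
            self._id_gen = MultiWriterIdGen()
        else:
            self._id_gen = SingleWriterIdGen()

    def get_current_token(self) -> int:
        return self._id_gen.get_current_token()

print(PresenceLikeStore(multi_writer=True).get_current_token())   # 42
print(PresenceLikeStore(multi_writer=False).get_current_token())  # 7

The same file also shows the @cached()/@cachedList pairing: the single-item stub _get_presence_for_user only exists for cache bookkeeping and raises NotImplementedError, which is why it can honestly be annotated as returning None while the batched get_presence_for_users carries the real return type.
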
13 changes: 9 additions & 4 deletions synapse/storage/databases/main/purge_events.py
@@ -13,9 +13,10 @@
# limitations under the License.

import logging
from typing import Any, List, Set, Tuple
from typing import Any, List, Set, Tuple, cast

from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.types import RoomStreamToken
@@ -55,7 +56,11 @@ async def purge_history(
)

def _purge_history_txn(
self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool
self,
txn: LoggingTransaction,
room_id: str,
token: RoomStreamToken,
delete_local_events: bool,
) -> Set[int]:
# Tables that should be pruned:
# event_auth
@@ -273,7 +278,7 @@ def _purge_history_txn(
""",
(room_id,),
)
(min_depth,) = txn.fetchone()
(min_depth,) = cast(Tuple[int], txn.fetchone())

logger.info("[purge] updating room_depth to %d", min_depth)

@@ -318,7 +323,7 @@ async def purge_room(self, room_id: str) -> List[int]:
"purge_room", self._purge_room_txn, room_id
)

def _purge_room_txn(self, txn, room_id: str) -> List[int]:
def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
# First we fetch all the state groups that should be deleted, before
# we delete that information.
txn.execute(
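
The cast(Tuple[int], txn.fetchone()) above exists because DB-API cursors return loosely typed, possibly-None rows, so unpacking the result of an aggregate query fails strict type checking even though the query always yields exactly one row. A hedged, self-contained sketch of the same idea using sqlite3 as a stand-in for Synapse's database layer; the table and column names are illustrative:

import sqlite3
from typing import Tuple, cast

def min_depth_for_room(conn: sqlite3.Connection, room_id: str) -> int:
    cur = conn.execute(
        "SELECT MIN(depth) FROM events WHERE room_id = ?", (room_id,)
    )
    # fetchone() is typed as an optional, loosely typed row, so the cast
    # records the assumption that an aggregate SELECT returns exactly one row.
    (min_depth,) = cast(Tuple[int], cur.fetchone())
    return min_depth

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (room_id TEXT, depth INTEGER)")
conn.executemany(
    "INSERT INTO events VALUES (?, ?)",
    [("!room:example.org", 3), ("!room:example.org", 7)],
)
print(min_depth_for_room(conn, "!room:example.org"))  # 3

The cast changes nothing at runtime; it only documents the invariant for the type checker, which is why it is preferred here over an assert or a broad type: ignore.
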
22 changes: 11 additions & 11 deletions synapse/storage/databases/main/user_directory.py
@@ -58,7 +58,7 @@ def __init__(
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
) -> None:
super().__init__(database, db_conn, hs)

self.server_name = hs.hostname
@@ -234,10 +234,10 @@ def _get_next_batch(
processed_event_count = 0

for room_id, event_count in rooms_to_work_on:
is_in_room = await self.is_host_joined(room_id, self.server_name)
is_in_room = await self.is_host_joined(room_id, self.server_name) # type: ignore[attr-defined]

if is_in_room:
users_with_profile = await self.get_users_in_room_with_profiles(room_id)
users_with_profile = await self.get_users_in_room_with_profiles(room_id) # type: ignore[attr-defined]
# Throw away users excluded from the directory.
users_with_profile = {
user_id: profile
@@ -368,7 +368,7 @@ def _get_next_batch(txn: LoggingTransaction) -> Optional[List[str]]:

for user_id in users_to_work_on:
if await self.should_include_local_user_in_dir(user_id):
profile = await self.get_profileinfo(get_localpart_from_id(user_id))
profile = await self.get_profileinfo(get_localpart_from_id(user_id)) # type: ignore[attr-defined]
await self.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
@@ -397,25 +397,25 @@ async def should_include_local_user_in_dir(self, user: str) -> bool:
# technically it could be DM-able. In the future, this could potentially
# be configurable per-appservice whether the appservice sender can be
# contacted.
if self.get_app_service_by_user_id(user) is not None:
if self.get_app_service_by_user_id(user) is not None: # type: ignore[attr-defined]
return False

# We're opting to exclude appservice users (anyone matching the user
# namespace regex in the appservice registration) even though technically
# they could be DM-able. In the future, this could potentially
# be configurable per-appservice whether the appservice users can be
# contacted.
if self.get_if_app_services_interested_in_user(user):
if self.get_if_app_services_interested_in_user(user): # type: ignore[attr-defined]
# TODO we might want to make this configurable for each app service
return False

# Support users are for diagnostics and should not appear in the user directory.
if await self.is_support_user(user):
if await self.is_support_user(user): # type: ignore[attr-defined]
return False

# Deactivated users aren't contactable, so should not appear in the user directory.
try:
if await self.get_user_deactivated_status(user):
if await self.get_user_deactivated_status(user): # type: ignore[attr-defined]
return False
except StoreError:
# No such user in the users table. No need to do this when calling
@@ -433,20 +433,20 @@ async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
(EventTypes.RoomHistoryVisibility, ""),
)

current_state_ids = await self.get_filtered_current_state_ids(
current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined]
room_id, StateFilter.from_types(types_to_filter)
)

join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
if join_rules_id:
join_rule_ev = await self.get_event(join_rules_id, allow_none=True)
join_rule_ev = await self.get_event(join_rules_id, allow_none=True) # type: ignore[attr-defined]
if join_rule_ev:
if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
return True

hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
if hist_vis_id:
hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) # type: ignore[attr-defined]
if hist_vis_ev:
if (
hist_vis_ev.content.get("history_visibility")
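
The # type: ignore[attr-defined] comments above appear because the user-directory store is written as a mixin: at runtime it is combined with the other store classes into the final DataStore, but mypy checks it in isolation and cannot see methods such as get_profileinfo or get_event that live on sibling classes. A minimal illustration of why the error arises and where the ignore goes; the class and method names are illustrative:

class ProfileStore:
    def get_profileinfo(self, localpart: str) -> str:
        return f"@{localpart}'s profile"

class UserDirectoryLikeStore:
    # Written as a mixin: at runtime it is always combined with ProfileStore
    # (see the combined class below), but mypy checks this class on its own
    # and cannot see get_profileinfo, hence the attr-defined ignore.
    def update_directory_entry(self, localpart: str) -> str:
        return self.get_profileinfo(localpart)  # type: ignore[attr-defined]

class CombinedStore(UserDirectoryLikeStore, ProfileStore):
    pass

print(CombinedStore().update_directory_entry("alice"))  # @alice's profile

Alternatives would be declaring the needed methods on a Protocol or as abstract stubs on the mixin; the targeted ignores keep the diff small at the cost of slightly weaker checking at those call sites.
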
6 changes: 3 additions & 3 deletions synapse/types.py
@@ -51,7 +51,7 @@

if TYPE_CHECKING:
from synapse.appservice.api import ApplicationService
from synapse.storage.databases.main import DataStore
from synapse.storage.databases.main import DataStore, PurgeEventsStore

# Define a state map type from type/state_key to T (usually an event ID or
# event)
@@ -485,7 +485,7 @@ def __attrs_post_init__(self) -> None:
)

@classmethod
async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken":
async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
try:
if string[0] == "s":
return cls(topological=None, stream=int(string[1:]))
@@ -502,7 +502,7 @@ async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken":
instance_id = int(key)
pos = int(value)

instance_name = await store.get_name_from_instance_id(instance_id)
instance_name = await store.get_name_from_instance_id(instance_id) # type: ignore[attr-defined]
instance_map[instance_name] = pos

return cls(
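
The last hunk narrows RoomStreamToken.parse from the full DataStore to the smaller PurgeEventsStore, then silences the one method (get_name_from_instance_id) that lives on a different store class. A related technique worth comparing, sketched here under illustrative names rather than Synapse's real classes, is a typing.Protocol that names exactly the capability the parser needs and so avoids the ignore entirely:

import asyncio
from typing import Dict, Protocol

class InstanceNameLookup(Protocol):
    # The only capability this token-parsing code actually needs.
    async def get_name_from_instance_id(self, instance_id: int) -> str: ...

class FakeStore:
    def __init__(self, names: Dict[int, str]) -> None:
        self._names = names

    async def get_name_from_instance_id(self, instance_id: int) -> str:
        return self._names[instance_id]

async def resolve_instance(store: InstanceNameLookup, instance_id: int) -> str:
    # Any object with a matching method satisfies the Protocol structurally,
    # so no type: ignore is needed at the call site.
    return await store.get_name_from_instance_id(instance_id)

print(asyncio.run(resolve_instance(FakeStore({1: "worker1"}), 1)))  # worker1

The concrete-class-plus-ignore route taken in the PR keeps the annotation aligned with the store actually passed in production, while the Protocol route trades that for structural precision; both satisfy mypy.
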