Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add type hints to synapse/storage/databases/main (#11984)
Browse files Browse the repository at this point in the history
  • Loading branch information
dklimpel authored Feb 21, 2022
1 parent 99f6d79 commit 7c82da2
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 53 deletions.
1 change: 1 addition & 0 deletions changelog.d/11984.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to storage classes.
3 changes: 0 additions & 3 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,11 @@ exclude = (?x)
|synapse/storage/databases/main/group_server.py
|synapse/storage/databases/main/metrics.py
|synapse/storage/databases/main/monthly_active_users.py
|synapse/storage/databases/main/presence.py
|synapse/storage/databases/main/purge_events.py
|synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/receipts.py
|synapse/storage/databases/main/roommember.py
|synapse/storage/databases/main/search.py
|synapse/storage/databases/main/state.py
|synapse/storage/databases/main/user_directory.py
|synapse/storage/schema/

|tests/api/test_auth.py
Expand Down
26 changes: 14 additions & 12 deletions synapse/handlers/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,25 +204,27 @@ async def current_state_for_users(
Returns:
dict: `user_id` -> `UserPresenceState`
"""
states = {
user_id: self.user_to_current_state.get(user_id, None)
for user_id in user_ids
}
states = {}
missing = []
for user_id in user_ids:
state = self.user_to_current_state.get(user_id, None)
if state:
states[user_id] = state
else:
missing.append(user_id)

missing = [user_id for user_id, state in states.items() if not state]
if missing:
# There are things not in our in memory cache. Lets pull them out of
# the database.
res = await self.store.get_presence_for_users(missing)
states.update(res)

missing = [user_id for user_id, state in states.items() if not state]
if missing:
new = {
user_id: UserPresenceState.default(user_id) for user_id in missing
}
states.update(new)
self.user_to_current_state.update(new)
for user_id in missing:
# if user has no state in database, create the state
if not res.get(user_id, None):
new_state = UserPresenceState.default(user_id)
states[user_id] = new_state
self.user_to_current_state[user_id] = new_state

return states

Expand Down
61 changes: 41 additions & 20 deletions synapse/storage/databases/main/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast

from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.iterutils import batch_iter
Expand All @@ -35,7 +43,7 @@ def __init__(
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
) -> None:
super().__init__(database, db_conn, hs)

# Used by `PresenceStore._get_active_presence()`
Expand All @@ -54,11 +62,14 @@ def __init__(
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
) -> None:
super().__init__(database, db_conn, hs)

self._instance_name = hs.get_instance_name()
self._presence_id_gen: AbstractStreamIdGenerator

self._can_persist_presence = (
hs.get_instance_name() in hs.config.worker.writers.presence
self._instance_name in hs.config.worker.writers.presence
)

if isinstance(database.engine, PostgresEngine):
Expand Down Expand Up @@ -109,7 +120,9 @@ async def update_presence(self, presence_states) -> Tuple[int, int]:

return stream_orderings[-1], self._presence_id_gen.get_current_token()

def _update_presence_txn(self, txn, stream_orderings, presence_states):
def _update_presence_txn(
self, txn: LoggingTransaction, stream_orderings, presence_states
) -> None:
for stream_id, state in zip(stream_orderings, presence_states):
txn.call_after(
self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
Expand Down Expand Up @@ -183,19 +196,23 @@ async def get_all_presence_updates(
if last_id == current_id:
return [], current_id, False

def get_all_presence_updates_txn(txn):
def get_all_presence_updates_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, list]], int, bool]:
sql = """
SELECT stream_id, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts,
status_msg,
currently_active
status_msg, currently_active
FROM presence_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
updates = [(row[0], row[1:]) for row in txn]
updates = cast(
List[Tuple[int, list]],
[(row[0], row[1:]) for row in txn],
)

upper_bound = current_id
limited = False
Expand All @@ -210,15 +227,17 @@ def get_all_presence_updates_txn(txn):
)

@cached()
def _get_presence_for_user(self, user_id):
def _get_presence_for_user(self, user_id: str) -> None:
raise NotImplementedError()

@cachedList(
cached_method_name="_get_presence_for_user",
list_name="user_ids",
num_args=1,
)
async def get_presence_for_users(self, user_ids):
async def get_presence_for_users(
self, user_ids: Iterable[str]
) -> Dict[str, UserPresenceState]:
rows = await self.db_pool.simple_select_many_batch(
table="presence_stream",
column="user_id",
Expand Down Expand Up @@ -257,7 +276,9 @@ async def should_user_receive_full_presence_with_token(
True if the user should have full presence sent to them, False otherwise.
"""

def _should_user_receive_full_presence_with_token_txn(txn):
def _should_user_receive_full_presence_with_token_txn(
txn: LoggingTransaction,
) -> bool:
sql = """
SELECT 1 FROM users_to_send_full_presence_to
WHERE user_id = ?
Expand All @@ -271,7 +292,7 @@ def _should_user_receive_full_presence_with_token_txn(txn):
_should_user_receive_full_presence_with_token_txn,
)

async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None:
"""Adds to the list of users who should receive a full snapshot of presence
upon their next sync.
Expand Down Expand Up @@ -353,10 +374,10 @@ async def get_presence_for_all_users(

return users_to_state

def get_current_presence_token(self):
def get_current_presence_token(self) -> int:
return self._presence_id_gen.get_current_token()

def _get_active_presence(self, db_conn: Connection):
def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
"""Fetch non-offline presence from the database so that we can register
the appropriate time outs.
"""
Expand All @@ -379,12 +400,12 @@ def _get_active_presence(self, db_conn: Connection):

return [UserPresenceState(**row) for row in rows]

def take_presence_startup_info(self):
def take_presence_startup_info(self) -> List[UserPresenceState]:
active_on_startup = self._presence_on_startup
self._presence_on_startup = None
self._presence_on_startup = []
return active_on_startup

def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows) -> None:
if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(instance_name, token)
for row in rows:
Expand Down
13 changes: 9 additions & 4 deletions synapse/storage/databases/main/purge_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# limitations under the License.

import logging
from typing import Any, List, Set, Tuple
from typing import Any, List, Set, Tuple, cast

from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.types import RoomStreamToken
Expand Down Expand Up @@ -55,7 +56,11 @@ async def purge_history(
)

def _purge_history_txn(
self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool
self,
txn: LoggingTransaction,
room_id: str,
token: RoomStreamToken,
delete_local_events: bool,
) -> Set[int]:
# Tables that should be pruned:
# event_auth
Expand Down Expand Up @@ -273,7 +278,7 @@ def _purge_history_txn(
""",
(room_id,),
)
(min_depth,) = txn.fetchone()
(min_depth,) = cast(Tuple[int], txn.fetchone())

logger.info("[purge] updating room_depth to %d", min_depth)

Expand Down Expand Up @@ -318,7 +323,7 @@ async def purge_room(self, room_id: str) -> List[int]:
"purge_room", self._purge_room_txn, room_id
)

def _purge_room_txn(self, txn, room_id: str) -> List[int]:
def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
# First we fetch all the state groups that should be deleted, before
# we delete that information.
txn.execute(
Expand Down
22 changes: 11 additions & 11 deletions synapse/storage/databases/main/user_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
) -> None:
super().__init__(database, db_conn, hs)

self.server_name = hs.hostname
Expand Down Expand Up @@ -234,10 +234,10 @@ def _get_next_batch(
processed_event_count = 0

for room_id, event_count in rooms_to_work_on:
is_in_room = await self.is_host_joined(room_id, self.server_name)
is_in_room = await self.is_host_joined(room_id, self.server_name) # type: ignore[attr-defined]

if is_in_room:
users_with_profile = await self.get_users_in_room_with_profiles(room_id)
users_with_profile = await self.get_users_in_room_with_profiles(room_id) # type: ignore[attr-defined]
# Throw away users excluded from the directory.
users_with_profile = {
user_id: profile
Expand Down Expand Up @@ -368,7 +368,7 @@ def _get_next_batch(txn: LoggingTransaction) -> Optional[List[str]]:

for user_id in users_to_work_on:
if await self.should_include_local_user_in_dir(user_id):
profile = await self.get_profileinfo(get_localpart_from_id(user_id))
profile = await self.get_profileinfo(get_localpart_from_id(user_id)) # type: ignore[attr-defined]
await self.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
Expand Down Expand Up @@ -397,25 +397,25 @@ async def should_include_local_user_in_dir(self, user: str) -> bool:
# technically it could be DM-able. In the future, this could potentially
# be configurable per-appservice whether the appservice sender can be
# contacted.
if self.get_app_service_by_user_id(user) is not None:
if self.get_app_service_by_user_id(user) is not None: # type: ignore[attr-defined]
return False

# We're opting to exclude appservice users (anyone matching the user
# namespace regex in the appservice registration) even though technically
# they could be DM-able. In the future, this could potentially
# be configurable per-appservice whether the appservice users can be
# contacted.
if self.get_if_app_services_interested_in_user(user):
if self.get_if_app_services_interested_in_user(user): # type: ignore[attr-defined]
# TODO we might want to make this configurable for each app service
return False

# Support users are for diagnostics and should not appear in the user directory.
if await self.is_support_user(user):
if await self.is_support_user(user): # type: ignore[attr-defined]
return False

# Deactivated users aren't contactable, so should not appear in the user directory.
try:
if await self.get_user_deactivated_status(user):
if await self.get_user_deactivated_status(user): # type: ignore[attr-defined]
return False
except StoreError:
# No such user in the users table. No need to do this when calling
Expand All @@ -433,20 +433,20 @@ async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> boo
(EventTypes.RoomHistoryVisibility, ""),
)

current_state_ids = await self.get_filtered_current_state_ids(
current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined]
room_id, StateFilter.from_types(types_to_filter)
)

join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
if join_rules_id:
join_rule_ev = await self.get_event(join_rules_id, allow_none=True)
join_rule_ev = await self.get_event(join_rules_id, allow_none=True) # type: ignore[attr-defined]
if join_rule_ev:
if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
return True

hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
if hist_vis_id:
hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) # type: ignore[attr-defined]
if hist_vis_ev:
if (
hist_vis_ev.content.get("history_visibility")
Expand Down
6 changes: 3 additions & 3 deletions synapse/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@

if TYPE_CHECKING:
from synapse.appservice.api import ApplicationService
from synapse.storage.databases.main import DataStore
from synapse.storage.databases.main import DataStore, PurgeEventsStore

# Define a state map type from type/state_key to T (usually an event ID or
# event)
Expand Down Expand Up @@ -485,7 +485,7 @@ def __attrs_post_init__(self) -> None:
)

@classmethod
async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken":
async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
try:
if string[0] == "s":
return cls(topological=None, stream=int(string[1:]))
Expand All @@ -502,7 +502,7 @@ async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken":
instance_id = int(key)
pos = int(value)

instance_name = await store.get_name_from_instance_id(instance_id)
instance_name = await store.get_name_from_instance_id(instance_id) # type: ignore[attr-defined]
instance_map[instance_name] = pos

return cls(
Expand Down

0 comments on commit 7c82da2

Please sign in to comment.