From be16ee59a87723c2da164f56dc2274ae3ac3e438 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 3 Sep 2020 22:02:29 +0100 Subject: [PATCH] Add type hints to more handlers (#8244) --- changelog.d/8244.misc | 1 + mypy.ini | 3 ++ synapse/handlers/events.py | 49 ++++++++++--------- synapse/handlers/initial_sync.py | 80 +++++++++++++++++++------------- synapse/handlers/pagination.py | 56 ++++++++++++---------- 5 files changed, 110 insertions(+), 79 deletions(-) create mode 100644 changelog.d/8244.misc diff --git a/changelog.d/8244.misc b/changelog.d/8244.misc new file mode 100644 index 000000000000..e650072223d5 --- /dev/null +++ b/changelog.d/8244.misc @@ -0,0 +1 @@ +Add type hints to pagination, initial sync and events handlers. diff --git a/mypy.ini b/mypy.ini index 8a351eabfebb..7764f178569d 100644 --- a/mypy.ini +++ b/mypy.ini @@ -17,10 +17,13 @@ files = synapse/handlers/auth.py, synapse/handlers/cas_handler.py, synapse/handlers/directory.py, + synapse/handlers/events.py, synapse/handlers/federation.py, synapse/handlers/identity.py, + synapse/handlers/initial_sync.py, synapse/handlers/message.py, synapse/handlers/oidc_handler.py, + synapse/handlers/pagination.py, synapse/handlers/presence.py, synapse/handlers/room.py, synapse/handlers/room_member.py, diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 1924636c4d70..b05e32f45771 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -15,29 +15,30 @@ import logging import random +from typing import TYPE_CHECKING, Iterable, List, Optional from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, SynapseError from synapse.events import EventBase from synapse.handlers.presence import format_user_presence_state from synapse.logging.utils import log_function -from synapse.types import UserID +from synapse.streams.config import PaginationConfig +from synapse.types import JsonDict, UserID from synapse.visibility import filter_events_for_client from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) class EventStreamHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super(EventStreamHandler, self).__init__(hs) - # Count of active streams per user - self._streams_per_user = {} - # Grace timers per user to delay the "stopped" signal - self._stop_timer_per_user = {} - self.distributor = hs.get_distributor() self.distributor.declare("started_user_eventstream") self.distributor.declare("stopped_user_eventstream") @@ -52,14 +53,14 @@ def __init__(self, hs): @log_function async def get_stream( self, - auth_user_id, - pagin_config, - timeout=0, - as_client_event=True, - affect_presence=True, - room_id=None, - is_guest=False, - ): + auth_user_id: str, + pagin_config: PaginationConfig, + timeout: int = 0, + as_client_event: bool = True, + affect_presence: bool = True, + room_id: Optional[str] = None, + is_guest: bool = False, + ) -> JsonDict: """Fetches the events stream for a given user. """ @@ -98,7 +99,7 @@ async def get_stream( # When the user joins a new room, or another user joins a currently # joined room, we need to send down presence for those users. - to_add = [] + to_add = [] # type: List[JsonDict] for event in events: if not isinstance(event, EventBase): continue @@ -110,7 +111,7 @@ async def get_stream( # Send down presence for everyone in the room. users = await self.state.get_current_users_in_room( event.room_id - ) + ) # type: Iterable[str] else: users = [event.state_key] @@ -144,20 +145,22 @@ async def get_stream( class EventHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super(EventHandler, self).__init__(hs) self.storage = hs.get_storage() - async def get_event(self, user, room_id, event_id): + async def get_event( + self, user: UserID, room_id: Optional[str], event_id: str + ) -> Optional[EventBase]: """Retrieve a single specified event. Args: - user (synapse.types.UserID): The user requesting the event - room_id (str|None): The expected room id. We'll return None if the + user: The user requesting the event + room_id: The expected room id. We'll return None if the event's room does not match. - event_id (str): The event ID to obtain. + event_id: The event ID to obtain. Returns: - dict: An event, or None if there is no event matching this ID. + An event, or None if there is no event matching this ID. Raises: SynapseError if there was a problem retrieving this event, or AuthError if the user does not have the rights to inspect this diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index ae6bd1d35271..d5ddc583ad69 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from twisted.internet import defer @@ -22,8 +23,9 @@ from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.storage.roommember import RoomsForUser from synapse.streams.config import PaginationConfig -from synapse.types import StreamToken, UserID +from synapse.types import JsonDict, Requester, StreamToken, UserID from synapse.util import unwrapFirstError from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.response_cache import ResponseCache @@ -31,11 +33,15 @@ from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) class InitialSyncHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super(InitialSyncHandler, self).__init__(hs) self.hs = hs self.state = hs.get_state_handler() @@ -48,27 +54,25 @@ def __init__(self, hs): def snapshot_all_rooms( self, - user_id=None, - pagin_config=None, - as_client_event=True, - include_archived=False, - ): + user_id: str, + pagin_config: PaginationConfig, + as_client_event: bool = True, + include_archived: bool = False, + ) -> JsonDict: """Retrieve a snapshot of all rooms the user is invited or has joined. This snapshot may include messages for all rooms where the user is joined, depending on the pagination config. Args: - user_id (str): The ID of the user making the request. - pagin_config (synapse.api.streams.PaginationConfig): The pagination - config used to determine how many messages *PER ROOM* to return. - as_client_event (bool): True to get events in client-server format. - include_archived (bool): True to get rooms that the user has left + user_id: The ID of the user making the request. + pagin_config: The pagination config used to determine how many + messages *PER ROOM* to return. + as_client_event: True to get events in client-server format. + include_archived: True to get rooms that the user has left Returns: - A list of dicts with "room_id" and "membership" keys for all rooms - the user is currently invited or joined in on. Rooms where the user - is joined on, may return a "messages" key with messages, depending - on the specified PaginationConfig. + A JsonDict with the same format as the response to `/intialSync` + API """ key = ( user_id, @@ -91,11 +95,11 @@ def snapshot_all_rooms( async def _snapshot_all_rooms( self, - user_id=None, - pagin_config=None, - as_client_event=True, - include_archived=False, - ): + user_id: str, + pagin_config: PaginationConfig, + as_client_event: bool = True, + include_archived: bool = False, + ) -> JsonDict: memberships = [Membership.INVITE, Membership.JOIN] if include_archived: @@ -134,7 +138,7 @@ async def _snapshot_all_rooms( if limit is None: limit = 10 - async def handle_room(event): + async def handle_room(event: RoomsForUser): d = { "room_id": event.room_id, "membership": event.membership, @@ -251,17 +255,18 @@ async def handle_room(event): return ret - async def room_initial_sync(self, requester, room_id, pagin_config=None): + async def room_initial_sync( + self, requester: Requester, room_id: str, pagin_config: PaginationConfig + ) -> JsonDict: """Capture the a snapshot of a room. If user is currently a member of the room this will be what is currently in the room. If the user left the room this will be what was in the room when they left. Args: - requester(Requester): The user to get a snapshot for. - room_id(str): The room to get a snapshot of. - pagin_config(synapse.streams.config.PaginationConfig): - The pagination config used to determine how many messages to - return. + requester: The user to get a snapshot for. + room_id: The room to get a snapshot of. + pagin_config: The pagination config used to determine how many + messages to return. Raises: AuthError if the user wasn't in the room. Returns: @@ -305,8 +310,14 @@ async def room_initial_sync(self, requester, room_id, pagin_config=None): return result async def _room_initial_sync_parted( - self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking - ): + self, + user_id: str, + room_id: str, + pagin_config: PaginationConfig, + membership: Membership, + member_event_id: str, + is_peeking: bool, + ) -> JsonDict: room_state = await self.state_store.get_state_for_events([member_event_id]) room_state = room_state[member_event_id] @@ -350,8 +361,13 @@ async def _room_initial_sync_parted( } async def _room_initial_sync_joined( - self, user_id, room_id, pagin_config, membership, is_peeking - ): + self, + user_id: str, + room_id: str, + pagin_config: PaginationConfig, + membership: Membership, + is_peeking: bool, + ) -> JsonDict: current_state = await self.state.get_current_state(room_id=room_id) # TODO: These concurrently diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 5a1aa7d83086..63d7edff87a1 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional, Set from twisted.python.failure import Failure @@ -30,6 +30,10 @@ from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) @@ -68,7 +72,7 @@ class PaginationHandler(object): paginating during a purge. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -78,9 +82,9 @@ def __init__(self, hs): self._server_name = hs.hostname self.pagination_lock = ReadWriteLock() - self._purges_in_progress_by_room = set() + self._purges_in_progress_by_room = set() # type: Set[str] # map from purge id to PurgeStatus - self._purges_by_id = {} + self._purges_by_id = {} # type: Dict[str, PurgeStatus] self._event_serializer = hs.get_event_client_serializer() self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime @@ -102,7 +106,9 @@ def __init__(self, hs): job["longest_max_lifetime"], ) - async def purge_history_for_rooms_in_range(self, min_ms, max_ms): + async def purge_history_for_rooms_in_range( + self, min_ms: Optional[int], max_ms: Optional[int] + ): """Purge outdated events from rooms within the given retention range. If a default retention policy is defined in the server's configuration and its @@ -110,10 +116,10 @@ async def purge_history_for_rooms_in_range(self, min_ms, max_ms): retention policy. Args: - min_ms (int|None): Duration in milliseconds that define the lower limit of + min_ms: Duration in milliseconds that define the lower limit of the range to handle (exclusive). If None, it means that the range has no lower limit. - max_ms (int|None): Duration in milliseconds that define the upper limit of + max_ms: Duration in milliseconds that define the upper limit of the range to handle (inclusive). If None, it means that the range has no upper limit. """ @@ -220,18 +226,19 @@ async def purge_history_for_rooms_in_range(self, min_ms, max_ms): "_purge_history", self._purge_history, purge_id, room_id, token, True, ) - def start_purge_history(self, room_id, token, delete_local_events=False): + def start_purge_history( + self, room_id: str, token: str, delete_local_events: bool = False + ) -> str: """Start off a history purge on a room. Args: - room_id (str): The room to purge from - - token (str): topological token to delete events before - delete_local_events (bool): True to delete local events as well as + room_id: The room to purge from + token: topological token to delete events before + delete_local_events: True to delete local events as well as remote ones Returns: - str: unique ID for this purge transaction. + unique ID for this purge transaction. """ if room_id in self._purges_in_progress_by_room: raise SynapseError( @@ -284,14 +291,11 @@ def clear_purge(): self.hs.get_reactor().callLater(24 * 3600, clear_purge) - def get_purge_status(self, purge_id): + def get_purge_status(self, purge_id: str) -> Optional[PurgeStatus]: """Get the current status of an active purge Args: - purge_id (str): purge_id returned by start_purge_history - - Returns: - PurgeStatus|None + purge_id: purge_id returned by start_purge_history """ return self._purges_by_id.get(purge_id) @@ -312,8 +316,8 @@ async def purge_room(self, room_id: str) -> None: async def get_messages( self, requester: Requester, - room_id: Optional[str] = None, - pagin_config: Optional[PaginationConfig] = None, + room_id: str, + pagin_config: PaginationConfig, as_client_event: bool = True, event_filter: Optional[Filter] = None, ) -> Dict[str, Any]: @@ -368,11 +372,15 @@ async def get_messages( # If they have left the room then clamp the token to be before # they left the room, to save the effort of loading from the # database. + + # This is only None if the room is world_readable, in which + # case "JOIN" would have been returned. + assert member_event_id + leave_token = await self.store.get_topological_token_for_event( member_event_id ) - leave_token = RoomStreamToken.parse(leave_token) - if leave_token.topological < max_topo: + if RoomStreamToken.parse(leave_token).topological < max_topo: source_config.from_key = str(leave_token) await self.hs.get_handlers().federation_handler.maybe_backfill( @@ -419,8 +427,8 @@ async def get_messages( ) if state_ids: - state = await self.store.get_events(list(state_ids.values())) - state = state.values() + state_dict = await self.store.get_events(list(state_ids.values())) + state = state_dict.values() time_now = self.clock.time_msec()