diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index eea64c1c9f84..27c8939c1480 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -15,11 +15,13 @@ # limitations under the License. import logging from collections import namedtuple +from typing import Collection, List import six from twisted.internet import defer -from twisted.internet.defer import DeferredList +from twisted.internet.defer import Deferred, DeferredList +from twisted.python.failure import Failure from synapse.api.constants import MAX_DEPTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError @@ -29,6 +31,7 @@ RoomVersion, ) from synapse.crypto.event_signing import check_event_content_hash +from synapse.crypto.keyring import Keyring from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import prune_event from synapse.http.servlet import assert_params_in_dict @@ -56,7 +59,12 @@ def __init__(self, hs): @defer.inlineCallbacks def _check_sigs_and_hash_and_fetch( - self, origin, pdus, room_version, outlier=False, include_none=False + self, + origin: str, + pdus: List[EventBase], + room_version: str, + outlier: bool = False, + include_none: bool = False, ): """Takes a list of PDUs and checks the signatures and hashs of each one. If a PDU fails its signature check then we check if we have it in @@ -69,11 +77,11 @@ def _check_sigs_and_hash_and_fetch( a new list. Args: - origin (str) - pdu (list) - room_version (str) - outlier (bool): Whether the events are outliers or not - include_none (str): Whether to include None in the returned list + origin + pdu + room_version + outlier: Whether the events are outliers or not + include_none: Whether to include None in the returned list for events that have failed their checks Returns: @@ -82,7 +90,7 @@ def _check_sigs_and_hash_and_fetch( deferreds = self._check_sigs_and_hashes(room_version, pdus) @defer.inlineCallbacks - def handle_check_result(pdu, deferred): + def handle_check_result(pdu: EventBase, deferred: Deferred): try: res = yield make_deferred_yieldable(deferred) except SynapseError: @@ -125,21 +133,23 @@ def handle_check_result(pdu, deferred): else: return [p for p in valid_pdus if p] - def _check_sigs_and_hash(self, room_version, pdu): + def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred: return make_deferred_yieldable( self._check_sigs_and_hashes(room_version, [pdu])[0] ) - def _check_sigs_and_hashes(self, room_version, pdus): + def _check_sigs_and_hashes( + self, room_version: str, pdus: List[EventBase] + ) -> List[Deferred]: """Checks that each of the received events is correctly signed by the sending server. Args: - room_version (str): The room version of the PDUs - pdus (list[FrozenEvent]): the events to be checked + room_version: The room version of the PDUs + pdus: the events to be checked Returns: - list[Deferred]: for each input event, a deferred which: + For each input event, a deferred which: * returns the original event if the checks pass * returns a redacted version of the event (if the signature matched but the hash did not) @@ -150,7 +160,7 @@ def _check_sigs_and_hashes(self, room_version, pdus): ctx = LoggingContext.current_context() - def callback(_, pdu): + def callback(_, pdu: EventBase): with PreserveLoggingContext(ctx): if not check_event_content_hash(pdu): # let's try to distinguish between failures because the event was @@ -187,7 +197,7 @@ def callback(_, pdu): return pdu - def errback(failure, pdu): + def errback(failure: Failure, pdu: EventBase): failure.trap(SynapseError) with PreserveLoggingContext(ctx): logger.warning( @@ -213,16 +223,18 @@ class PduToCheckSig( pass -def _check_sigs_on_pdus(keyring, room_version, pdus): +def _check_sigs_on_pdus( + keyring: Keyring, room_version: str, pdus: Collection[EventBase] +) -> List[Deferred]: """Check that the given events are correctly signed Args: - keyring (synapse.crypto.Keyring): keyring object to do the checks - room_version (str): the room version of the PDUs - pdus (Collection[EventBase]): the events to be checked + keyring: keyring object to do the checks + room_version: the room version of the PDUs + pdus: the events to be checked Returns: - List[Deferred]: a Deferred for each event in pdus, which will either succeed if + A Deferred for each event in pdus, which will either succeed if the signatures are valid, or fail (with a SynapseError) if not. """ @@ -327,7 +339,7 @@ def event_err(e, pdu_to_check): return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check] -def _flatten_deferred_list(deferreds): +def _flatten_deferred_list(deferreds: List[Deferred]) -> Deferred: """Given a list of deferreds, either return the single deferred, combine into a DeferredList, or return an already resolved deferred. """ @@ -339,7 +351,7 @@ def _flatten_deferred_list(deferreds): return defer.succeed(None) -def _is_invite_via_3pid(event): +def _is_invite_via_3pid(event: EventBase) -> bool: return ( event.type == EventTypes.Member and event.membership == Membership.INVITE diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 4870e3965248..b5538bc07a56 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -187,7 +187,7 @@ def claim_client_keys(self, destination, content, timeout): async def backfill( self, dest: str, room_id: str, limit: int, extremities: Iterable[str] - ) -> List[EventBase]: + ) -> Optional[List[EventBase]]: """Requests some more historic PDUs for the given room from the given destination server. @@ -199,9 +199,9 @@ async def backfill( """ logger.debug("backfill extrem=%s", extremities) - # If there are no extremeties then we've (probably) reached the start. + # If there are no extremities then we've (probably) reached the start. if not extremities: - return + return None transaction_data = await self.transport_layer.backfill( dest, room_id, extremities, limit @@ -284,7 +284,7 @@ async def get_pdu( pdu_list = [ event_from_pdu_json(p, room_version, outlier=outlier) for p in transaction_data["pdus"] - ] + ] # type: List[EventBase] if pdu_list and pdu_list[0]: pdu = pdu_list[0] @@ -615,7 +615,7 @@ async def send_request(destination) -> Dict[str, Any]: ] if auth_chain_create_events != [create_event.event_id]: raise InvalidResponseError( - "Unexpected create event(s) in auth chain" + "Unexpected create event(s) in auth chain: %s" % (auth_chain_create_events,) ) diff --git a/tox.ini b/tox.ini index b715ea0bff40..21aec24eb4e8 100644 --- a/tox.ini +++ b/tox.ini @@ -181,6 +181,8 @@ commands = mypy \ synapse/api \ synapse/config/ \ synapse/events/spamcheck.py \ + synapse/federation/federation_base.py \ + synapse/federation/federation_client.py \ synapse/federation/sender \ synapse/federation/transport \ synapse/handlers/sync.py \