Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add Keyring.verify_events_for_server and reduce memory usage #10018

Merged
merged 5 commits into from
May 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/10018.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Reduce memory usage when verifying signatures on large numbers of events at once.
98 changes: 88 additions & 10 deletions synapse/crypto/keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import logging
import urllib
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Set, Tuple

import attr
from signedjson.key import (
Expand All @@ -42,6 +42,8 @@
SynapseError,
)
from synapse.config.key import TrustedKeyServer
from synapse.events import EventBase
from synapse.events.utils import prune_event_dict
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
Expand Down Expand Up @@ -69,7 +71,11 @@ class VerifyJsonRequest:
Attributes:
server_name: The name of the server to verify against.

json_object: The JSON object to verify.
get_json_object: A callback to fetch the JSON object to verify.
A callback is used to allow deferring the creation of the JSON
object to verify until needed, e.g. for events we can defer
creating the redacted copy. This reduces the memory usage when
there are large numbers of in flight requests.

minimum_valid_until_ts: time at which we require the signing key to
be valid. (0 implies we don't care)
Expand All @@ -88,14 +94,50 @@ class VerifyJsonRequest:
"""

server_name = attr.ib(type=str)
json_object = attr.ib(type=JsonDict)
get_json_object = attr.ib(type=Callable[[], JsonDict])
minimum_valid_until_ts = attr.ib(type=int)
request_name = attr.ib(type=str)
key_ids = attr.ib(init=False, type=List[str])
key_ids = attr.ib(type=List[str])
key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)

def __attrs_post_init__(self):
self.key_ids = signature_ids(self.json_object, self.server_name)
@staticmethod
def from_json_object(
server_name: str,
json_object: JsonDict,
minimum_valid_until_ms: int,
request_name: str,
):
"""Create a VerifyJsonRequest to verify all signatures on a signed JSON
object for the given server.
"""
key_ids = signature_ids(json_object, server_name)
return VerifyJsonRequest(
server_name,
lambda: json_object,
minimum_valid_until_ms,
request_name=request_name,
key_ids=key_ids,
)

@staticmethod
def from_event(
server_name: str,
event: EventBase,
minimum_valid_until_ms: int,
):
"""Create a VerifyJsonRequest to verify all signatures on an event
object for the given server.
"""
key_ids = list(event.signatures.get(server_name, []))
return VerifyJsonRequest(
server_name,
# We defer creating the redacted json object, as it uses a lot more
# memory than the Event object itself.
lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
minimum_valid_until_ms,
request_name=event.event_id,
key_ids=key_ids,
)


class KeyLookupError(ValueError):
Expand Down Expand Up @@ -147,8 +189,13 @@ def verify_json_for_server(
Deferred[None]: completes if the the object was correctly signed, otherwise
errbacks with an error
"""
req = VerifyJsonRequest(server_name, json_object, validity_time, request_name)
requests = (req,)
request = VerifyJsonRequest.from_json_object(
server_name,
json_object,
validity_time,
request_name,
)
requests = (request,)
return make_deferred_yieldable(self._verify_objects(requests)[0])

def verify_json_objects_for_server(
Expand All @@ -175,10 +222,41 @@ def verify_json_objects_for_server(
logcontext.
"""
return self._verify_objects(
VerifyJsonRequest(server_name, json_object, validity_time, request_name)
VerifyJsonRequest.from_json_object(
server_name, json_object, validity_time, request_name
)
for server_name, json_object, validity_time, request_name in server_and_json
)

def verify_events_for_server(
self, server_and_events: Iterable[Tuple[str, EventBase, int]]
) -> List[defer.Deferred]:
"""Bulk verification of signatures on events.

Args:
server_and_events:
Iterable of `(server_name, event, validity_time)` tuples.

`server_name` is which server we are verifying the signature for
on the event.

`event` is the event that we'll verify the signatures of for
the given `server_name`.

`validity_time` is a timestamp at which the signing key must be
valid.

Returns:
List<Deferred[None]>: for each input triplet, a deferred indicating success
or failure to verify each event's signature for the given
server_name. The deferreds run their callbacks in the sentinel
logcontext.
"""
return self._verify_objects(
VerifyJsonRequest.from_event(server_name, event, validity_time)
for server_name, event, validity_time in server_and_events
)

def _verify_objects(
self, verify_requests: Iterable[VerifyJsonRequest]
) -> List[defer.Deferred]:
Expand Down Expand Up @@ -892,7 +970,7 @@ async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
with PreserveLoggingContext():
_, key_id, verify_key = await verify_request.key_ready

json_object = verify_request.json_object
json_object = verify_request.get_json_object()

try:
verify_signed_json(json_object, server_name, verify_key)
Expand Down
17 changes: 5 additions & 12 deletions synapse/federation/federation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,7 @@ def errback(failure: Failure, pdu: EventBase):
return deferreds


class PduToCheckSig(
namedtuple(
"PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"]
)
):
class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
pass


Expand Down Expand Up @@ -184,7 +180,6 @@ def _check_sigs_on_pdus(
pdus_to_check = [
PduToCheckSig(
pdu=p,
redacted_pdu_json=prune_event(p).get_pdu_json(),
sender_domain=get_domain_from_id(p.sender),
deferreds=[],
)
Expand All @@ -195,13 +190,12 @@ def _check_sigs_on_pdus(
# (except if its a 3pid invite, in which case it may be sent by any server)
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]

more_deferreds = keyring.verify_json_objects_for_server(
more_deferreds = keyring.verify_events_for_server(
[
(
p.sender_domain,
p.redacted_pdu_json,
p.pdu,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
p.pdu.event_id,
)
for p in pdus_to_check_sender
]
Expand Down Expand Up @@ -230,13 +224,12 @@ def sender_err(e, pdu_to_check):
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
]

more_deferreds = keyring.verify_json_objects_for_server(
more_deferreds = keyring.verify_events_for_server(
[
(
get_domain_from_id(p.pdu.event_id),
p.redacted_pdu_json,
p.pdu,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
p.pdu.event_id,
)
for p in pdus_to_check_event_id
]
Expand Down