Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix stack overflow in Keyring #5724

Merged
merged 3 commits into from
Jul 22, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 35 additions & 51 deletions synapse/crypto/keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,68 +238,65 @@ def _start_key_lookups(self, verify_requests):
"""

try:
# create a deferred for each server we're going to look up the keys
# for; we'll resolve them once we have completed our lookups.
# These will be passed into wait_for_previous_lookups to block
# any other lookups until we have finished.
# The deferreds are called with no logcontext.
server_to_deferred = {
rq.server_name: defer.Deferred() for rq in verify_requests
}

# We want to wait for any previous lookups to complete before
# proceeding.
yield self.wait_for_previous_lookups(server_to_deferred)
ctx = LoggingContext.current_context()

# Actually start fetching keys.
self._get_server_verify_keys(verify_requests)

# When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
#
# map from server name to a set of request ids
# map from server name to a set of outstanding request ids
server_to_request_ids = {}

for verify_request in verify_requests:
server_name = verify_request.server_name
request_id = id(verify_request)
server_to_request_ids.setdefault(server_name, set()).add(request_id)

def remove_deferreds(res, verify_request):
# Wait for any previous lookups to complete before proceeding.
yield self.wait_for_previous_lookups(server_to_request_ids.keys())

# take out a lock on each of the servers by sticking a Deferred in
# key_downloads
for server_name in server_to_request_ids.keys():
self.key_downloads[server_name] = defer.Deferred()
logger.debug("Got key lookup lock on %s", server_name)

# When we've finished fetching all the keys for a given server_name,
# drop the lock by resolving the deferred in key_downloads.
def lookup_done(res, verify_request):
server_name = verify_request.server_name
request_id = id(verify_request)
server_to_request_ids[server_name].discard(request_id)
if not server_to_request_ids[server_name]:
d = server_to_deferred.pop(server_name, None)
if d:
d.callback(None)
server_requests = server_to_request_ids[server_name]
server_requests.remove(id(verify_request))

# if there are no more requests for this server, we can drop the lock.
if not server_requests:
with PreserveLoggingContext(ctx):
logger.debug("Releasing key lookup lock on %s", server_name)

d = self.key_downloads.pop(server_name)
d.callback(None)
return res

for verify_request in verify_requests:
verify_request.key_ready.addBoth(remove_deferreds, verify_request)
verify_request.key_ready.addBoth(lookup_done, verify_request)

# Actually start fetching keys.
self._get_server_verify_keys(verify_requests)
except Exception:
logger.exception("Error starting key lookups")

@defer.inlineCallbacks
def wait_for_previous_lookups(self, server_to_deferred):
def wait_for_previous_lookups(self, server_names):
"""Waits for any previous key lookups for the given servers to finish.

Args:
server_to_deferred (dict[str, Deferred]): server_name to deferred which gets
resolved once we've finished looking up keys for that server.
The Deferreds should be regular twisted ones which call their
callbacks with no logcontext.

Returns: a Deferred which resolves once all key lookups for the given
servers have completed. Follows the synapse rules of logcontext
preservation.
server_names (Iterable[str]): list of servers which we want to look up

Returns:
Deferred[None]: resolves once all key lookups for the given servers have
completed. Follows the synapse rules of logcontext preservation.
"""
loop_count = 1
while True:
wait_on = [
(server_name, self.key_downloads[server_name])
for server_name in server_to_deferred.keys()
for server_name in server_names
if server_name in self.key_downloads
]
if not wait_on:
Expand All @@ -314,19 +311,6 @@ def wait_for_previous_lookups(self, server_to_deferred):

loop_count += 1

ctx = LoggingContext.current_context()

def rm(r, server_name_):
with PreserveLoggingContext(ctx):
logger.debug("Releasing key lookup lock on %s", server_name_)
self.key_downloads.pop(server_name_, None)
return r

for server_name, deferred in server_to_deferred.items():
logger.debug("Got key lookup lock on %s", server_name)
self.key_downloads[server_name] = deferred
deferred.addBoth(rm, server_name)

def _get_server_verify_keys(self, verify_requests):
"""Tries to find at least one key for each verify request

Expand Down
29 changes: 0 additions & 29 deletions tests/crypto/test_keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,35 +86,6 @@ def check_context(self, _, expected):
getattr(LoggingContext.current_context(), "request", None), expected
)

def test_wait_for_previous_lookups(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should there not be a new test case to ensure that we do actually take out a lock? Or is there already such a test?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_verify_json_objects_for_server_awaits_previous_requests checks that we don't end up with two lookups at once.

kr = keyring.Keyring(self.hs)

lookup_1_deferred = defer.Deferred()
lookup_2_deferred = defer.Deferred()

# we run the lookup in a logcontext so that the patched inlineCallbacks can check
# it is doing the right thing with logcontexts.
wait_1_deferred = run_in_context(
kr.wait_for_previous_lookups, {"server1": lookup_1_deferred}
)

# there were no previous lookups, so the deferred should be ready
self.successResultOf(wait_1_deferred)

# set off another wait. It should block because the first lookup
# hasn't yet completed.
wait_2_deferred = run_in_context(
kr.wait_for_previous_lookups, {"server1": lookup_2_deferred}
)

self.assertFalse(wait_2_deferred.called)

# let the first lookup complete (in the sentinel context)
lookup_1_deferred.callback(None)

# now the second wait should complete.
self.successResultOf(wait_2_deferred)

def test_verify_json_objects_for_server_awaits_previous_requests(self):
key1 = signedjson.key.generate_signing_key(1)

Expand Down