Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Implement cancellation support/protection for module callbacks #12568

Merged
merged 4 commits into from
May 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12568.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Protect module callbacks with read semantics against cancellation.
12 changes: 9 additions & 3 deletions synapse/events/presence_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@

from typing_extensions import ParamSpec

from twisted.internet.defer import CancelledError

from synapse.api.presence import UserPresenceState
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -158,7 +160,9 @@ async def get_users_for_states(
try:
# Note: result is an object here, because we don't trust modules to
# return the types they're supposed to.
result: object = await callback(state_updates)
result: object = await delay_cancellation(callback(state_updates))
except CancelledError:
raise
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
Expand Down Expand Up @@ -210,7 +214,9 @@ async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
# run all the callbacks for get_interested_users and combine the results
for callback in self._get_interested_users_callbacks:
try:
result = await callback(user_id)
result = await delay_cancellation(callback(user_id))
except CancelledError:
raise
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
Expand Down
36 changes: 25 additions & 11 deletions synapse/events/spamcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from synapse.rest.media.v1.media_storage import ReadableFileWrapper
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, UserProfile
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable

if TYPE_CHECKING:
import synapse.events
Expand Down Expand Up @@ -255,7 +255,7 @@ async def check_event_for_spam(
will be used as the error message returned to the user.
"""
for callback in self._check_event_for_spam_callbacks:
res: Union[bool, str] = await callback(event)
res: Union[bool, str] = await delay_cancellation(callback(event))
if res:
return res

Expand All @@ -276,7 +276,10 @@ async def user_may_join_room(
Whether the user may join the room
"""
for callback in self._user_may_join_room_callbacks:
if await callback(user_id, room_id, is_invited) is False:
may_join_room = await delay_cancellation(
callback(user_id, room_id, is_invited)
)
if may_join_room is False:
return False

return True
Expand All @@ -297,7 +300,10 @@ async def user_may_invite(
True if the user may send an invite, otherwise False
"""
for callback in self._user_may_invite_callbacks:
if await callback(inviter_userid, invitee_userid, room_id) is False:
may_invite = await delay_cancellation(
callback(inviter_userid, invitee_userid, room_id)
)
if may_invite is False:
return False

return True
Expand All @@ -322,7 +328,10 @@ async def user_may_send_3pid_invite(
True if the user may send the invite, otherwise False
"""
for callback in self._user_may_send_3pid_invite_callbacks:
if await callback(inviter_userid, medium, address, room_id) is False:
may_send_3pid_invite = await delay_cancellation(
callback(inviter_userid, medium, address, room_id)
)
if may_send_3pid_invite is False:
return False

return True
Expand All @@ -339,7 +348,8 @@ async def user_may_create_room(self, userid: str) -> bool:
True if the user may create a room, otherwise False
"""
for callback in self._user_may_create_room_callbacks:
if await callback(userid) is False:
may_create_room = await delay_cancellation(callback(userid))
if may_create_room is False:
return False

return True
Expand All @@ -359,7 +369,10 @@ async def user_may_create_room_alias(
True if the user may create a room alias, otherwise False
"""
for callback in self._user_may_create_room_alias_callbacks:
if await callback(userid, room_alias) is False:
may_create_room_alias = await delay_cancellation(
callback(userid, room_alias)
)
if may_create_room_alias is False:
return False

return True
Expand All @@ -377,7 +390,8 @@ async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
True if the user may publish the room, otherwise False
"""
for callback in self._user_may_publish_room_callbacks:
if await callback(userid, room_id) is False:
may_publish_room = await delay_cancellation(callback(userid, room_id))
if may_publish_room is False:
return False

return True
Expand All @@ -400,7 +414,7 @@ async def check_username_for_spam(self, user_profile: UserProfile) -> bool:
for callback in self._check_username_for_spam_callbacks:
# Make a copy of the user profile object to ensure the spam checker cannot
# modify it.
if await callback(user_profile.copy()):
if await delay_cancellation(callback(user_profile.copy())):
return True

return False
Expand Down Expand Up @@ -428,7 +442,7 @@ async def check_registration_for_spam(
"""

for callback in self._check_registration_for_spam_callbacks:
behaviour = await (
behaviour = await delay_cancellation(
callback(email_threepid, username, request_info, auth_provider_id)
)
assert isinstance(behaviour, RegistrationBehaviour)
Expand Down Expand Up @@ -472,7 +486,7 @@ async def check_media_file_for_spam(
"""

for callback in self._check_media_file_for_spam_callbacks:
spam = await callback(file_wrapper, file_info)
spam = await delay_cancellation(callback(file_wrapper, file_info))
if spam:
return True

Expand Down
36 changes: 30 additions & 6 deletions synapse/events/third_party_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
import logging
from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple

from twisted.internet.defer import CancelledError

from synapse.api.errors import ModuleFailedException, SynapseError
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.storage.roommember import ProfileInfo
from synapse.types import Requester, StateMap
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -263,7 +265,11 @@ async def check_event_allowed(

for callback in self._check_event_allowed_callbacks:
try:
res, replacement_data = await callback(event, state_events)
res, replacement_data = await delay_cancellation(
callback(event, state_events)
)
except CancelledError:
raise
except SynapseError as e:
# FIXME: Being able to throw SynapseErrors is relied upon by
# some modules. PR #10386 accidentally broke this ability.
Expand Down Expand Up @@ -333,8 +339,13 @@ async def check_threepid_can_be_invited(

for callback in self._check_threepid_can_be_invited_callbacks:
try:
if await callback(medium, address, state_events) is False:
threepid_can_be_invited = await delay_cancellation(
callback(medium, address, state_events)
)
if threepid_can_be_invited is False:
return False
except CancelledError:
raise
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)

Expand All @@ -361,8 +372,13 @@ async def check_visibility_can_be_modified(

for callback in self._check_visibility_can_be_modified_callbacks:
try:
if await callback(room_id, state_events, new_visibility) is False:
visibility_can_be_modified = await delay_cancellation(
callback(room_id, state_events, new_visibility)
)
if visibility_can_be_modified is False:
return False
except CancelledError:
raise
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)

Expand Down Expand Up @@ -400,8 +416,11 @@ async def check_can_shutdown_room(self, user_id: str, room_id: str) -> bool:
"""
for callback in self._check_can_shutdown_room_callbacks:
try:
if await callback(user_id, room_id) is False:
can_shutdown_room = await delay_cancellation(callback(user_id, room_id))
if can_shutdown_room is False:
return False
except CancelledError:
raise
except Exception as e:
logger.exception(
"Failed to run module API callback %s: %s", callback, e
Expand All @@ -422,8 +441,13 @@ async def check_can_deactivate_user(
"""
for callback in self._check_can_deactivate_user_callbacks:
try:
if await callback(user_id, by_admin) is False:
can_deactivate_user = await delay_cancellation(
callback(user_id, by_admin)
)
if can_deactivate_user is False:
return False
except CancelledError:
raise
except Exception as e:
logger.exception(
"Failed to run module API callback %s: %s", callback, e
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/account_validity.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.types import UserID
from synapse.util import stringutils
from synapse.util.async_helpers import delay_cancellation

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -150,7 +151,7 @@ async def is_user_expired(self, user_id: str) -> bool:
Whether the user has expired.
"""
for callback in self._is_user_expired_callbacks:
expired = await callback(user_id)
expired = await delay_cancellation(callback(user_id))
if expired is not None:
return expired

Expand Down
25 changes: 19 additions & 6 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import unpaddedbase64
from pymacaroons.exceptions import MacaroonVerificationFailedException

from twisted.internet.defer import CancelledError
from twisted.web.server import Request

from synapse.api.constants import LoginType
Expand All @@ -67,7 +68,7 @@
from synapse.storage.roommember import ProfileInfo
from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import base62_encode
Expand Down Expand Up @@ -2202,7 +2203,11 @@ async def check_auth(
# other than None (i.e. until a callback returns a success)
for callback in self.auth_checker_callbacks[login_type]:
try:
result = await callback(username, login_type, login_dict)
result = await delay_cancellation(
callback(username, login_type, login_dict)
)
except CancelledError:
raise
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
Expand Down Expand Up @@ -2263,7 +2268,9 @@ async def check_3pid_auth(

for callback in self.check_3pid_auth_callbacks:
try:
result = await callback(medium, address, password)
result = await delay_cancellation(callback(medium, address, password))
except CancelledError:
raise
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
Expand Down Expand Up @@ -2345,7 +2352,7 @@ async def get_username_for_registration(
"""
for callback in self.get_username_for_registration_callbacks:
try:
res = await callback(uia_results, params)
res = await delay_cancellation(callback(uia_results, params))

if isinstance(res, str):
return res
Expand All @@ -2359,6 +2366,8 @@ async def get_username_for_registration(
callback,
res,
)
except CancelledError:
raise
except Exception as e:
logger.error(
"Module raised an exception in get_username_for_registration: %s",
Expand Down Expand Up @@ -2388,7 +2397,7 @@ async def get_displayname_for_registration(
"""
for callback in self.get_displayname_for_registration_callbacks:
try:
res = await callback(uia_results, params)
res = await delay_cancellation(callback(uia_results, params))

if isinstance(res, str):
return res
Expand All @@ -2402,6 +2411,8 @@ async def get_displayname_for_registration(
callback,
res,
)
except CancelledError:
raise
except Exception as e:
logger.error(
"Module raised an exception in get_displayname_for_registration: %s",
Expand Down Expand Up @@ -2429,7 +2440,7 @@ async def is_3pid_allowed(
"""
for callback in self.is_3pid_allowed_callbacks:
try:
res = await callback(medium, address, registration)
res = await delay_cancellation(callback(medium, address, registration))

if res is False:
return res
Expand All @@ -2443,6 +2454,8 @@ async def is_3pid_allowed(
callback,
res,
)
except CancelledError:
raise
except Exception as e:
logger.error("Module raised an exception in is_3pid_allowed: %s", e)
raise SynapseError(code=500, msg="Internal Server Error")
Expand Down