Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert synapse.api to async/await #8031

Merged
merged 6 commits into from
Aug 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8031.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
123 changes: 56 additions & 67 deletions synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
from typing import List, Optional, Tuple

import pymacaroons
from netaddr import IPAddress

from twisted.internet import defer
from twisted.web.server import Request

import synapse.types
Expand Down Expand Up @@ -80,28 +79,28 @@ def __init__(self, hs):
self._track_appservice_user_ips = hs.config.track_appservice_user_ips
self._macaroon_secret_key = hs.config.macaroon_secret_key

@defer.inlineCallbacks
def check_from_context(self, room_version: str, event, context, do_sig_check=True):
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
auth_events_ids = yield self.compute_auth_events(
async def check_from_context(
self, room_version: str, event, context, do_sig_check=True
):
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}

room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
event_auth.check(
room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
)

@defer.inlineCallbacks
def check_user_in_room(
async def check_user_in_room(
self,
room_id: str,
user_id: str,
current_state: Optional[StateMap[EventBase]] = None,
allow_departed_users: bool = False,
):
) -> EventBase:
"""Check if the user is in the room, or was at some point.
Args:
room_id: The room to check.
Expand All @@ -119,37 +118,35 @@ def check_user_in_room(
Raises:
AuthError if the user is/was not in the room.
Returns:
Deferred[Optional[EventBase]]:
Membership event for the user if the user was in the
room. This will be the join event if they are currently joined to
the room. This will be the leave event if they have left the room.
Membership event for the user if the user was in the
room. This will be the join event if they are currently joined to
the room. This will be the leave event if they have left the room.
"""
if current_state:
member = current_state.get((EventTypes.Member, user_id), None)
else:
member = yield defer.ensureDeferred(
self.state.get_current_state(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
member = await self.state.get_current_state(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
membership = member.membership if member else None

if membership == Membership.JOIN:
return member
if member:
membership = member.membership

# XXX this looks totally bogus. Why do we not allow users who have been banned,
# or those who were members previously and have been re-invited?
if allow_departed_users and membership == Membership.LEAVE:
forgot = yield self.store.did_forget(user_id, room_id)
if not forgot:
if membership == Membership.JOIN:
return member

# XXX this looks totally bogus. Why do we not allow users who have been banned,
# or those who were members previously and have been re-invited?
if allow_departed_users and membership == Membership.LEAVE:
forgot = await self.store.did_forget(user_id, room_id)
if not forgot:
return member

raise AuthError(403, "User %s not in room %s" % (user_id, room_id))

@defer.inlineCallbacks
def check_host_in_room(self, room_id, host):
async def check_host_in_room(self, room_id, host):
with Measure(self.clock, "check_host_in_room"):
latest_event_ids = yield self.store.is_host_joined(room_id, host)
latest_event_ids = await self.store.is_host_joined(room_id, host)
return latest_event_ids

def can_federate(self, event, auth_events):
Expand All @@ -160,14 +157,13 @@ def can_federate(self, event, auth_events):
def get_public_keys(self, invite_event):
return event_auth.get_public_keys(invite_event)

@defer.inlineCallbacks
def get_user_by_req(
async def get_user_by_req(
self,
request: Request,
allow_guest: bool = False,
rights: str = "access",
allow_expired: bool = False,
):
) -> synapse.types.Requester:
""" Get a registered user's ID.

Args:
Expand All @@ -180,7 +176,7 @@ def get_user_by_req(
/login will deliver access tokens regardless of expiration.

Returns:
defer.Deferred: resolves to a `synapse.types.Requester` object
Resolves to the requester
Raises:
InvalidClientCredentialsError if no user by that token exists or the token
is invalid.
Expand All @@ -194,14 +190,14 @@ def get_user_by_req(

access_token = self.get_access_token_from_request(request)

user_id, app_service = yield self._get_appservice_user_id(request)
user_id, app_service = await self._get_appservice_user_id(request)
if user_id:
request.authenticated_entity = user_id
opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("appservice_id", app_service.id)

if ip_addr and self._track_appservice_user_ips:
yield self.store.insert_client_ip(
await self.store.insert_client_ip(
user_id=user_id,
access_token=access_token,
ip=ip_addr,
Expand All @@ -211,7 +207,7 @@ def get_user_by_req(

return synapse.types.create_requester(user_id, app_service=app_service)

user_info = yield self.get_user_by_access_token(
user_info = await self.get_user_by_access_token(
access_token, rights, allow_expired=allow_expired
)
user = user_info["user"]
Expand All @@ -221,7 +217,7 @@ def get_user_by_req(
# Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired:
user_id = user.to_string()
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
if (
expiration_ts is not None
and self.clock.time_msec() >= expiration_ts
Expand All @@ -235,7 +231,7 @@ def get_user_by_req(
device_id = user_info.get("device_id")

if user and access_token and ip_addr:
yield self.store.insert_client_ip(
await self.store.insert_client_ip(
user_id=user.to_string(),
access_token=access_token,
ip=ip_addr,
Expand All @@ -261,8 +257,7 @@ def get_user_by_req(
except KeyError:
raise MissingClientTokenError()

@defer.inlineCallbacks
def _get_appservice_user_id(self, request):
async def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request(request)
)
Expand All @@ -283,14 +278,13 @@ def _get_appservice_user_id(self, request):

if not app_service.is_interested_in_user(user_id):
raise AuthError(403, "Application service cannot masquerade as this user.")
if not (yield self.store.get_user_by_id(user_id)):
if not (await self.store.get_user_by_id(user_id)):
raise AuthError(403, "Application service has not registered this user")
return user_id, app_service

@defer.inlineCallbacks
def get_user_by_access_token(
async def get_user_by_access_token(
self, token: str, rights: str = "access", allow_expired: bool = False,
):
) -> dict:
""" Validate access token and get user_id from it

Args:
Expand All @@ -300,7 +294,7 @@ def get_user_by_access_token(
allow_expired: If False, raises an InvalidClientTokenError
if the token is expired
Returns:
Deferred[dict]: dict that includes:
dict that includes:
`user` (UserID)
`is_guest` (bool)
`token_id` (int|None): access token id. May be None if guest
Expand All @@ -314,7 +308,7 @@ def get_user_by_access_token(

if rights == "access":
# first look in the database
r = yield self._look_up_user_by_access_token(token)
r = await self._look_up_user_by_access_token(token)
if r:
valid_until_ms = r["valid_until_ms"]
if (
Expand Down Expand Up @@ -352,7 +346,7 @@ def get_user_by_access_token(
# It would of course be much easier to store guest access
# tokens in the database as well, but that would break existing
# guest tokens.
stored_user = yield self.store.get_user_by_id(user_id)
stored_user = await self.store.get_user_by_id(user_id)
if not stored_user:
raise InvalidClientTokenError("Unknown user_id %s" % user_id)
if not stored_user["is_guest"]:
Expand Down Expand Up @@ -482,9 +476,8 @@ def _verify_expiry(self, caveat):
now = self.hs.get_clock().time_msec()
return now < expiry

@defer.inlineCallbacks
def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token)
async def _look_up_user_by_access_token(self, token):
ret = await self.store.get_user_by_access_token(token)
if not ret:
return None

Expand All @@ -507,7 +500,7 @@ def get_appservice_by_req(self, request):
logger.warning("Unrecognised appservice access token.")
raise InvalidClientTokenError()
request.authenticated_entity = service.sender
return defer.succeed(service)
return service

async def is_server_admin(self, user: UserID) -> bool:
""" Check if the given user is a local server admin.
Expand All @@ -522,19 +515,19 @@ async def is_server_admin(self, user: UserID) -> bool:

def compute_auth_events(
self, event, current_state_ids: StateMap[str], for_verification: bool = False,
):
) -> List[str]:
"""Given an event and current state return the list of event IDs used
to auth an event.

If `for_verification` is False then only return auth events that
should be added to the event's `auth_events`.

Returns:
defer.Deferred(list[str]): List of event IDs.
List of event IDs.
"""

if event.type == EventTypes.Create:
return defer.succeed([])
return []

# Currently we ignore the `for_verification` flag even though there are
# some situations where we can drop particular auth events when adding
Expand All @@ -553,7 +546,7 @@ def compute_auth_events(
if auth_ev_id:
auth_ids.append(auth_ev_id)

return defer.succeed(auth_ids)
return auth_ids

async def check_can_change_room_list(self, room_id: str, user: UserID):
"""Determine whether the user is allowed to edit the room's entry in the
Expand Down Expand Up @@ -636,10 +629,9 @@ def get_access_token_from_request(request: Request):

return query_params[0].decode("ascii")

@defer.inlineCallbacks
def check_user_in_room_or_world_readable(
async def check_user_in_room_or_world_readable(
self, room_id: str, user_id: str, allow_departed_users: bool = False
):
) -> Tuple[str, Optional[str]]:
"""Checks that the user is or was in the room or the room is world
readable. If it isn't then an exception is raised.

Expand All @@ -650,10 +642,9 @@ def check_user_in_room_or_world_readable(
members but have now departed

Returns:
Deferred[tuple[str, str|None]]: Resolves to the current membership of
the user in the room and the membership event ID of the user. If
the user is not in the room and never has been, then
`(Membership.JOIN, None)` is returned.
Resolves to the current membership of the user in the room and the
membership event ID of the user. If the user is not in the room and
never has been, then `(Membership.JOIN, None)` is returned.
"""

try:
Expand All @@ -662,15 +653,13 @@ def check_user_in_room_or_world_readable(
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
member_event = yield self.check_user_in_room(
member_event = await self.check_user_in_room(
room_id, user_id, allow_departed_users=allow_departed_users
)
return member_event.membership, member_event.event_id
except AuthError:
visibility = yield defer.ensureDeferred(
self.state.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
visibility = await self.state.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
visibility
Expand Down
13 changes: 5 additions & 8 deletions synapse/api/auth_blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

import logging

from twisted.internet import defer

from synapse.api.constants import LimitBlockingTypes, UserTypes
from synapse.api.errors import Codes, ResourceLimitError
from synapse.config.server import is_threepid_reserved
Expand All @@ -36,8 +34,7 @@ def __init__(self, hs):
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids

@defer.inlineCallbacks
def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
async def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag

Expand All @@ -60,7 +57,7 @@ def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
if user_id is not None:
if user_id == self._server_notices_mxid:
return
if (yield self.store.is_support_user(user_id)):
if await self.store.is_support_user(user_id):
return

if self._hs_disabled:
Expand All @@ -76,11 +73,11 @@ def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):

# If the user is already part of the MAU cohort or a trial user
if user_id:
timestamp = yield self.store.user_last_seen_monthly_active(user_id)
timestamp = await self.store.user_last_seen_monthly_active(user_id)
if timestamp:
return

is_trial = yield self.store.is_trial_user(user_id)
is_trial = await self.store.is_trial_user(user_id)
if is_trial:
return
elif threepid:
Expand All @@ -93,7 +90,7 @@ def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
# allow registration. Support users are excluded from MAU checks.
return
# Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count()
current_mau = await self.store.get_monthly_active_count()
if current_mau >= self._max_mau_value:
raise ResourceLimitError(
403,
Expand Down
Loading