Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add types to auth and auth_blocking. #9876

Merged
merged 4 commits into from
Apr 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9876.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to `synapse.api.auth` and `synapse.api.auth_blocking` modules.
78 changes: 39 additions & 39 deletions synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import pymacaroons
from netaddr import IPAddress

from twisted.web.server import Request

import synapse.types
from synapse import event_auth
from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
Expand All @@ -36,11 +35,14 @@
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID
from synapse.types import Requester, StateMap, UserID, create_requester
from synapse.util.caches.lrucache import LruCache
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.metrics import Measure

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -68,7 +70,7 @@ class Auth:
The latter should be moved to synapse.handlers.event_auth.EventAuthHandler.
"""

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastore()
Expand All @@ -88,13 +90,13 @@ def __init__(self, hs):

async def check_from_context(
self, room_version: str, event, context, do_sig_check=True
):
) -> None:
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
auth_events_by_id = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}

room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
event_auth.check(
Expand Down Expand Up @@ -151,17 +153,11 @@ async def check_user_in_room(

raise AuthError(403, "User %s not in room %s" % (user_id, room_id))

async def check_host_in_room(self, room_id, host):
async def check_host_in_room(self, room_id: str, host: str) -> bool:
with Measure(self.clock, "check_host_in_room"):
latest_event_ids = await self.store.is_host_joined(room_id, host)
return latest_event_ids

def can_federate(self, event, auth_events):
creation_event = auth_events.get((EventTypes.Create, ""))
return await self.store.is_host_joined(room_id, host)

return creation_event.content.get("m.federate", True) is True

def get_public_keys(self, invite_event):
def get_public_keys(self, invite_event: EventBase) -> List[Dict[str, Any]]:
return event_auth.get_public_keys(invite_event)

async def get_user_by_req(
Expand All @@ -170,7 +166,7 @@ async def get_user_by_req(
allow_guest: bool = False,
rights: str = "access",
allow_expired: bool = False,
) -> synapse.types.Requester:
) -> Requester:
"""Get a registered user's ID.

Args:
Expand All @@ -196,7 +192,7 @@ async def get_user_by_req(
access_token = self.get_access_token_from_request(request)

user_id, app_service = await self._get_appservice_user_id(request)
if user_id:
if user_id and app_service:
if ip_addr and self._track_appservice_user_ips:
await self.store.insert_client_ip(
user_id=user_id,
Expand All @@ -206,9 +202,7 @@ async def get_user_by_req(
device_id="dummy-device", # stubbed
)

requester = synapse.types.create_requester(
user_id, app_service=app_service
)
requester = create_requester(user_id, app_service=app_service)

request.requester = user_id
opentracing.set_tag("authenticated_entity", user_id)
Expand Down Expand Up @@ -251,7 +245,7 @@ async def get_user_by_req(
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
)

requester = synapse.types.create_requester(
requester = create_requester(
user_info.user_id,
token_id,
is_guest,
Expand All @@ -271,7 +265,9 @@ async def get_user_by_req(
except KeyError:
raise MissingClientTokenError()

async def _get_appservice_user_id(self, request):
async def _get_appservice_user_id(
self, request: Request
) -> Tuple[Optional[str], Optional[ApplicationService]]:
app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request(request)
)
Expand All @@ -283,6 +279,9 @@ async def _get_appservice_user_id(self, request):
if ip_address not in app_service.ip_range_whitelist:
return None, None

# This will always be set by the time Twisted calls us.
assert request.args is not None

if b"user_id" not in request.args:
return app_service.sender, app_service

Expand Down Expand Up @@ -387,7 +386,9 @@ async def get_user_by_access_token(
logger.warning("Invalid macaroon in auth: %s %s", type(e), e)
raise InvalidClientTokenError("Invalid macaroon passed.")

def _parse_and_validate_macaroon(self, token, rights="access"):
def _parse_and_validate_macaroon(
self, token: str, rights: str = "access"
) -> Tuple[str, bool]:
"""Takes a macaroon and tries to parse and validate it. This is cached
if and only if rights == access and there isn't an expiry.

Expand Down Expand Up @@ -432,15 +433,16 @@ def _parse_and_validate_macaroon(self, token, rights="access"):

return user_id, guest

def validate_macaroon(self, macaroon, type_string, user_id):
def validate_macaroon(
self, macaroon: pymacaroons.Macaroon, type_string: str, user_id: str
) -> None:
"""
validate that a Macaroon is understood by and was signed by this server.

Args:
macaroon(pymacaroons.Macaroon): The macaroon to validate
type_string(str): The kind of token required (e.g. "access",
"delete_pusher")
user_id (str): The user_id required
macaroon: The macaroon to validate
type_string: The kind of token required (e.g. "access", "delete_pusher")
user_id: The user_id required
"""
v = pymacaroons.Verifier()

Expand All @@ -465,9 +467,7 @@ def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
if not service:
logger.warning("Unrecognised appservice access token.")
raise InvalidClientTokenError()
request.requester = synapse.types.create_requester(
service.sender, app_service=service
)
request.requester = create_requester(service.sender, app_service=service)
return service

async def is_server_admin(self, user: UserID) -> bool:
Expand Down Expand Up @@ -519,7 +519,7 @@ def compute_auth_events(

return auth_ids

async def check_can_change_room_list(self, room_id: str, user: UserID):
async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool:
"""Determine whether the user is allowed to edit the room's entry in the
published room list.

Expand Down Expand Up @@ -554,11 +554,11 @@ async def check_can_change_room_list(self, room_id: str, user: UserID):
return user_level >= send_level

@staticmethod
def has_access_token(request: Request):
def has_access_token(request: Request) -> bool:
"""Checks if the request has an access_token.

Returns:
bool: False if no access_token was given, True otherwise.
False if no access_token was given, True otherwise.
"""
# This will always be set by the time Twisted calls us.
assert request.args is not None
Expand All @@ -568,13 +568,13 @@ def has_access_token(request: Request):
return bool(query_params) or bool(auth_headers)

@staticmethod
def get_access_token_from_request(request: Request):
def get_access_token_from_request(request: Request) -> str:
"""Extracts the access_token from the request.

Args:
request: The http request.
Returns:
unicode: The access_token
The access_token
Raises:
MissingClientTokenError: If there isn't a single access_token in the
request
Expand Down Expand Up @@ -649,5 +649,5 @@ async def check_user_in_room_or_world_readable(
% (user_id, room_id),
)

def check_auth_blocking(self, *args, **kwargs):
return self._auth_blocking.check_auth_blocking(*args, **kwargs)
async def check_auth_blocking(self, *args, **kwargs) -> None:
await self._auth_blocking.check_auth_blocking(*args, **kwargs)
9 changes: 6 additions & 3 deletions synapse/api/auth_blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,21 @@
# limitations under the License.

import logging
from typing import Optional
from typing import TYPE_CHECKING, Optional

from synapse.api.constants import LimitBlockingTypes, UserTypes
from synapse.api.errors import Codes, ResourceLimitError
from synapse.config.server import is_threepid_reserved
from synapse.types import Requester

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class AuthBlocking:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()

self._server_notices_mxid = hs.config.server_notices_mxid
Expand All @@ -43,7 +46,7 @@ async def check_auth_blocking(
threepid: Optional[dict] = None,
user_type: Optional[str] = None,
requester: Optional[Requester] = None,
):
) -> None:
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag

Expand Down
4 changes: 2 additions & 2 deletions synapse/event_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

import logging
from typing import List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple

from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
Expand Down Expand Up @@ -688,7 +688,7 @@ def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase
return False


def get_public_keys(invite_event):
def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
public_keys = []
if "public_key" in invite_event.content:
o = {"public_key": invite_event.content["public_key"]}
Expand Down