Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Factor out a wrapper for password auth providers
Browse files Browse the repository at this point in the history
  • Loading branch information
richvdh committed Dec 1, 2020
1 parent 89f7930 commit 22fe5b7
Showing 1 changed file with 127 additions and 48 deletions.
175 changes: 127 additions & 48 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
Union,
Expand Down Expand Up @@ -181,17 +182,12 @@ def __init__(self, hs: "HomeServer"):
# better way to break the loop
account_handler = ModuleApi(hs, self)

self.password_providers = []
for module, config in hs.config.password_providers:
try:
self.password_providers.append(
module(config=config, account_handler=account_handler)
)
except Exception as e:
logger.error("Error while initializing %r: %s", module, e)
raise
self.password_providers = [
PasswordProvider.load(module, config, account_handler)
for module, config in hs.config.password_providers
]

logger.info("Extra password_providers: %r", self.password_providers)
logger.info("Extra password_providers: %s", self.password_providers)

self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
Expand Down Expand Up @@ -858,6 +854,9 @@ async def validate_login(
qualified_user_id = UserID(username, self.hs.hostname).to_string()

login_type = login_submission.get("type")
if not isinstance(login_type, str):
raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM)

known_login_type = False

# special case to check for "password" for the check_password interface
Expand All @@ -871,18 +870,17 @@ async def validate_login(
raise SynapseError(400, "Missing parameter: password")

for provider in self.password_providers:
if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
if (
hasattr(provider._pp, "check_password")
and login_type == LoginType.PASSWORD
):
known_login_type = True
is_valid = await provider.check_password(qualified_user_id, password)
is_valid = await provider._pp.check_password(
qualified_user_id, password
)
if is_valid:
return qualified_user_id, None

if not hasattr(provider, "get_supported_login_types") or not hasattr(
provider, "check_auth"
):
# this password provider doesn't understand custom login types
continue

supported_login_types = provider.get_supported_login_types()
if login_type not in supported_login_types:
# this password provider doesn't understand this login type
Expand All @@ -907,8 +905,6 @@ async def validate_login(

result = await provider.check_auth(username, login_type, login_dict)
if result:
if isinstance(result, str):
result = (result, None)
return result

if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
Expand Down Expand Up @@ -946,19 +942,9 @@ async def check_password_provider_3pid(
unsuccessful, `user_id` and `callback` are both `None`.
"""
for provider in self.password_providers:
if hasattr(provider, "check_3pid_auth"):
# This function is able to return a deferred that either
# resolves None, meaning authentication failure, or upon
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
result = await provider.check_3pid_auth(medium, address, password)
if result:
# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
result = (result, None)
return result
result = await provider.check_3pid_auth(medium, address, password)
if result:
return result

return None, None

Expand Down Expand Up @@ -1016,16 +1002,11 @@ async def delete_access_token(self, access_token: str):

# see if any of our auth providers want to know about this
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
# This might return an awaitable, if it does block the log out
# until it completes.
result = provider.on_logged_out(
user_id=user_info.user_id,
device_id=user_info.device_id,
access_token=access_token,
)
if inspect.isawaitable(result):
await result
await provider.on_logged_out(
user_id=user_info.user_id,
device_id=user_info.device_id,
access_token=access_token,
)

# delete pushers associated with this access token
if user_info.token_id is not None:
Expand Down Expand Up @@ -1054,11 +1035,10 @@ async def delete_access_tokens_for_user(

# see if any of our auth providers want to know about this
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
for token, token_id, device_id in tokens_and_devices:
await provider.on_logged_out(
user_id=user_id, device_id=device_id, access_token=token
)
for token, token_id, device_id in tokens_and_devices:
await provider.on_logged_out(
user_id=user_id, device_id=device_id, access_token=token
)

# delete pushers associated with the access tokens
await self.hs.get_pusherpool().remove_pushers_by_access_token(
Expand Down Expand Up @@ -1382,3 +1362,102 @@ def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon:
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon


class PasswordProvider:
"""Wrapper for a password auth provider module
This class abstracts out all of the backwards-compatibility hacks for
password providers, to provide a consistent interface.
"""

@classmethod
def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
try:
pp = module(config=config, account_handler=module_api)
except Exception as e:
logger.error("Error while initializing %r: %s", module, e)
raise
return cls(pp)

def __init__(self, pp):
self._pp = pp

def __str__(self):
return str(self._pp)

def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
"""Get the login types supported by this password provider
Returns a map from a login type identifier (such as m.login.password) to an
iterable giving the fields which must be provided by the user in the submission
to the /login API.
"""
g = getattr(self._pp, "get_supported_login_types", None)
if not g:
return {}
return g()

async def check_auth(
self, username: str, login_type: str, login_dict: JsonDict
) -> Optional[Tuple[str, Optional[Callable]]]:
"""Check if the user has presented valid login credentials
Args:
username: user id presented by the client. Either an MXID or an unqualified
username.
login_type: the login type being attempted - one of the types returned by
get_supported_login_types()
login_dict: the dictionary of login secrets passed by the client.
Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the
user, and `callback` is an optional callback which will be called with the
result from the /login call (including access_token, device_id, etc.)
"""
g = getattr(self._pp, "check_auth", None)
if not g:
return None
result = await g(username, login_type, login_dict)

# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
return result, None

return result

async def check_3pid_auth(
self, medium: str, address: str, password: str
) -> Optional[Tuple[str, Optional[Callable]]]:
g = getattr(self._pp, "check_3pid_auth", None)
if not g:
return None

# This function is able to return a deferred that either
# resolves None, meaning authentication failure, or upon
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
result = await g(medium, address, password)

# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
return result, None

return result

async def on_logged_out(
self, user_id: str, device_id: Optional[str], access_token: str
) -> None:
g = getattr(self._pp, "on_logged_out", None)
if not g:
return

# This might return an awaitable, if it does block the log out
# until it completes.
result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
if inspect.isawaitable(result):
await result

0 comments on commit 22fe5b7

Please sign in to comment.