Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Merge remote-tracking branch 'origin/develop' into rav/remove_unused_…
Browse files Browse the repository at this point in the history
…mocks
  • Loading branch information
richvdh committed Dec 2, 2020
2 parents 92ce4a5 + ed51728 commit 269ba1b
Show file tree
Hide file tree
Showing 15 changed files with 367 additions and 111 deletions.
1 change: 1 addition & 0 deletions changelog.d/8858.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a long-standing bug on Synapse instances supporting Single-Sign-On, where users would be prompted to enter their password to confirm certain actions, even though they have not set a password.
1 change: 1 addition & 0 deletions changelog.d/8864.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Remove unused `FakeResponse` class from unit tests.
2 changes: 1 addition & 1 deletion synapse/federation/transport/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,7 +1462,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N
Args:
hs (synapse.server.HomeServer): homeserver
resource (TransportLayerServer): resource class to register to
resource (JsonResource): resource class to register to
authenticator (Authenticator): authenticator to use
ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
servlet_groups (list[str], optional): List of servlet groups to register.
Expand Down
58 changes: 43 additions & 15 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,7 @@ def __init__(self, hs: "HomeServer"):
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled
self._sso_enabled = (
hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
)
self._password_localdb_enabled = hs.config.password_localdb_enabled

# we keep this as a list despite the O(N^2) implication so that we can
# keep PASSWORD first and avoid confusing clients which pick the first
Expand All @@ -205,7 +203,7 @@ def __init__(self, hs: "HomeServer"):

# start out by assuming PASSWORD is enabled; we will remove it later if not.
login_types = []
if hs.config.password_localdb_enabled:
if self._password_localdb_enabled:
login_types.append(LoginType.PASSWORD)

for provider in self.password_providers:
Expand All @@ -219,14 +217,6 @@ def __init__(self, hs: "HomeServer"):

self._supported_login_types = login_types

# Login types and UI Auth types have a heavy overlap, but are not
# necessarily identical. Login types have SSO (and other login types)
# added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
ui_auth_types = login_types.copy()
if self._sso_enabled:
ui_auth_types.append(LoginType.SSO)
self._supported_ui_auth_types = ui_auth_types

# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
self._failed_uia_attempts_ratelimiter = Ratelimiter(
Expand Down Expand Up @@ -339,7 +329,10 @@ async def validate_user_via_ui_auth(
self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)

# build a list of supported flows
flows = [[login_type] for login_type in self._supported_ui_auth_types]
supported_ui_auth_types = await self._get_available_ui_auth_types(
requester.user
)
flows = [[login_type] for login_type in supported_ui_auth_types]

try:
result, params, session_id = await self.check_ui_auth(
Expand All @@ -351,7 +344,7 @@ async def validate_user_via_ui_auth(
raise

# find the completed login type
for login_type in self._supported_ui_auth_types:
for login_type in supported_ui_auth_types:
if login_type not in result:
continue

Expand All @@ -367,6 +360,41 @@ async def validate_user_via_ui_auth(

return params, session_id

async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
"""Get a list of the authentication types this user can use
"""

ui_auth_types = set()

# if the HS supports password auth, and the user has a non-null password, we
# support password auth
if self._password_localdb_enabled and self._password_enabled:
lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
if lookupres:
_, password_hash = lookupres
if password_hash:
ui_auth_types.add(LoginType.PASSWORD)

# also allow auth from password providers
for provider in self.password_providers:
for t in provider.get_supported_login_types().keys():
if t == LoginType.PASSWORD and not self._password_enabled:
continue
ui_auth_types.add(t)

# if sso is enabled, allow the user to log in via SSO iff they have a mapping
# from sso to mxid.
if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
if await self.store.get_external_ids_by_user(user.to_string()):
ui_auth_types.add(LoginType.SSO)

# Our CAS impl does not (yet) correctly register users in user_external_ids,
# so always offer that if it's available.
if self.hs.config.cas.cas_enabled:
ui_auth_types.add(LoginType.SSO)

return ui_auth_types

def get_enabled_auth_types(self):
"""Return the enabled user-interactive authentication types
Expand Down Expand Up @@ -1029,7 +1057,7 @@ async def _validate_userid_login(
if result:
return result

if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
known_login_type = True

# we've already checked that there is a (valid) password field
Expand Down
25 changes: 25 additions & 0 deletions synapse/storage/databases/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,23 @@ async def get_user_by_external_id(
desc="get_user_by_external_id",
)

async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]:
"""Look up external ids for the given user
Args:
mxid: the MXID to be looked up
Returns:
Tuples of (auth_provider, external_id)
"""
res = await self.db_pool.simple_select_list(
table="user_external_ids",
keyvalues={"user_id": mxid},
retcols=("auth_provider", "external_id"),
desc="get_external_ids_by_user",
)
return [(r["auth_provider"], r["external_id"]) for r in res]

async def count_all_users(self):
"""Counts all users registered on the homeserver."""

Expand Down Expand Up @@ -963,6 +980,14 @@ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
)

self.db_pool.updates.register_background_index_update(
"user_external_ids_user_id_idx",
index_name="user_external_ids_user_id_idx",
table="user_external_ids",
columns=["user_id"],
unique=False,
)

async def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(5825, 'user_external_ids_user_id_idx', '{}');
17 changes: 1 addition & 16 deletions tests/handlers/test_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,15 @@

from mock import Mock, patch

import attr
import pymacaroons

from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone

from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
from synapse.handlers.sso import MappingException
from synapse.types import UserID

from tests.test_utils import FakeResponse
from tests.unittest import HomeserverTestCase, override_config


@attr.s
class FakeResponse:
code = attr.ib()
body = attr.ib()
phrase = attr.ib()

def deliverBody(self, protocol):
protocol.dataReceived(self.body)
protocol.connectionLost(Failure(ResponseDone()))


# These are a few constants that are used as config parameters in the tests.
ISSUER = "https://issuer/"
CLIENT_ID = "test-client-id"
Expand Down
23 changes: 8 additions & 15 deletions tests/handlers/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@


import json
from typing import Dict

from mock import ANY, Mock, call

from twisted.internet import defer
from twisted.web.resource import Resource

from synapse.api.errors import AuthError
from synapse.federation.transport import server as federation_server
from synapse.federation.transport.server import TransportLayerServer
from synapse.types import UserID, create_requester
from synapse.util.ratelimitutils import FederationRateLimiter

from tests import unittest
from tests.test_utils import make_awaitable
Expand Down Expand Up @@ -53,20 +54,7 @@ def _make_edu_transaction_json(edu_type, content):
return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8")


def register_federation_servlets(hs, resource):
federation_server.register_servlets(
hs,
resource=resource,
authenticator=federation_server.Authenticator(hs),
ratelimiter=FederationRateLimiter(
hs.get_clock(), config=hs.config.rc_federation
),
)


class TypingNotificationsTestCase(unittest.HomeserverTestCase):
servlets = [register_federation_servlets]

def make_homeserver(self, reactor, clock):
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
Expand All @@ -89,6 +77,11 @@ def make_homeserver(self, reactor, clock):

return hs

def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
d["/_matrix/federation"] = TransportLayerServer(self.hs)
return d

def prepare(self, reactor, clock, hs):
mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event
Expand Down
14 changes: 8 additions & 6 deletions tests/replication/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple

import attr

from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
from twisted.web.resource import Resource

from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
GenericWorkerServer,
)
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource, streams
from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
Expand All @@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis:
skip = "Requires hiredis"

servlets = [
streams.register_servlets,
]

def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
Expand Down Expand Up @@ -88,6 +85,11 @@ def prepare(self, reactor, clock, hs):
self._client_transport = None
self._server_transport = None

def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
d["/_synapse/replication"] = ReplicationRestResource(self.hs)
return d

def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.generic_worker"
Expand Down
Loading

0 comments on commit 269ba1b

Please sign in to comment.