Skip to content

Commit

Permalink
Pass module API to OIDC mapping provider (#16974)
Browse files Browse the repository at this point in the history
As done for SAML mapping provider, let's pass the module API to the OIDC
one so the mapper can do more logic in its code.
  • Loading branch information
MatMaul authored Mar 19, 2024
1 parent 05489d8 commit 74ab329
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
1 change: 1 addition & 0 deletions changelog.d/16974.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
As done for SAML mapping provider, let's pass the module API to the OIDC one so the mapper can do more logic in its code.
4 changes: 3 additions & 1 deletion docs/sso_mapping_providers.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@ comment these options out and use those specified by the module instead.

A custom mapping provider must specify the following methods:

* `def __init__(self, parsed_config)`
* `def __init__(self, parsed_config, module_api)`
- Arguments:
- `parsed_config` - A configuration object that is the return value of the
`parse_config` method. You should set any configuration options needed by
the module here.
- `module_api` - a `synapse.module_api.ModuleApi` object which provides the
stable API available for extension modules.
* `def parse_config(config)`
- This method should have the `@staticmethod` decoration.
- Arguments:
Expand Down
17 changes: 14 additions & 3 deletions synapse/handlers/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.module_api import ModuleApi
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import Clock, json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
Expand Down Expand Up @@ -421,9 +422,19 @@ def __init__(
# from the IdP's jwks_uri, if required.
self._jwks = RetryOnExceptionCachedCall(self._load_jwks)

self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config
user_mapping_provider_init_method = (
provider.user_mapping_provider_class.__init__
)
if len(inspect.signature(user_mapping_provider_init_method).parameters) == 3:
self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config,
ModuleApi(hs, hs.get_auth_handler()),
)
else:
self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config,
)

self._skip_verification = provider.skip_verification
self._allow_existing_users = provider.allow_existing_users

Expand Down Expand Up @@ -1583,7 +1594,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
This is the default mapping provider.
"""

def __init__(self, config: JinjaOidcMappingConfig):
def __init__(self, config: JinjaOidcMappingConfig, module_api: ModuleApi):
self._config = config

@staticmethod
Expand Down

0 comments on commit 74ab329

Please sign in to comment.