Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

JWT OIDC secrets for Sign in with Apple #9549

Merged
merged 8 commits into from
Mar 9, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion docs/sample_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1780,7 +1780,26 @@ saml2_config:
#
# client_id: Required. oauth2 client id to use.
#
# client_secret: Required. oauth2 client secret to use.
# client_secret: oauth2 client secret to use. May be omitted if
# client_secret_jwt_key is given, or if client_auth_method is 'none'.
#
# client_secret_jwt_key: Alternative to client_secret: details of a key used
# to create a JSON Web Token to be used as an OAuth2 client secret. If
# given, must be a dictionary with the following properties:
#
# key: a pem-encoded signing key. Must be a suitable key for the
# algorithm specified. Required unless 'key_file' is given.
#
# key_file: the path to file containing a pem-encoded signing key file.
# Required unless 'key' is given.
#
# jwt_header: a dictionary giving properties to include in the JWT
# header. Must include the key 'alg', giving the algorithm used to
# sign the JWT, such as "ES256", using the JWA identifiers in
# RFC7518.
#
# jwt_payload: an optional dictionary giving properties to include in
# the JWT payload. Normally this should include an 'iss' key.
#
# client_auth_method: auth method to use when exchanging the token. Valid
# values are 'client_secret_basic' (default), 'client_secret_post' and
Expand Down
38 changes: 34 additions & 4 deletions synapse/config/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,8 @@ def ensure_directory(cls, dir_path):

@classmethod
def read_file(cls, file_path, config_name):
cls.check_file(file_path, config_name)
with open(file_path) as file_stream:
return file_stream.read()
"""Deprecated: call read_file directly"""
return read_file(file_path, (config_name,))

def read_template(self, filename: str) -> jinja2.Template:
"""Load a template file from disk.
Expand Down Expand Up @@ -894,4 +893,35 @@ def get_instance(self, key: str) -> str:
return self._get_instance(key)


__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
def read_file(file_path: Any, config_path: Iterable[str]) -> str:
"""Check the given file exists, and read it into a string

If it does not, emit an error indicating the problem

Args:
file_path: the file to be read
config_path: where in the configuration file_path came from, so that a useful
error can be emitted if it does not exist.
Returns:
content of the file.
Raises:
ConfigError if there is a problem reading the file.
"""
if not isinstance(file_path, str):
raise ConfigError("%r is not a string", config_path)

try:
os.stat(file_path)
with open(file_path) as file_stream:
return file_stream.read()
except OSError as e:
raise ConfigError("Error accessing file %r" % (file_path,), config_path) from e


__all__ = [
"Config",
"RootConfig",
"ShardedWorkerHandlingConfig",
"RoutableShardedWorkerHandlingConfig",
"read_file",
]
2 changes: 2 additions & 0 deletions synapse/config/_base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,5 @@ class ShardedWorkerHandlingConfig:

class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
def get_instance(self, key: str) -> str: ...

def read_file(file_path: Any, config_path: Iterable[str]) -> str: ...
89 changes: 82 additions & 7 deletions synapse/config/oidc_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.

from collections import Counter
from typing import Iterable, Optional, Tuple, Type
from typing import Iterable, Mapping, Optional, Tuple, Type

import attr

Expand All @@ -25,7 +25,7 @@
from synapse.util.module_loader import load_module
from synapse.util.stringutils import parse_and_validate_mxc_uri

from ._base import Config, ConfigError
from ._base import Config, ConfigError, read_file

DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider"

Expand Down Expand Up @@ -97,7 +97,26 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs):
#
# client_id: Required. oauth2 client id to use.
#
# client_secret: Required. oauth2 client secret to use.
# client_secret: oauth2 client secret to use. May be omitted if
# client_secret_jwt_key is given, or if client_auth_method is 'none'.
#
# client_secret_jwt_key: Alternative to client_secret: details of a key used
# to create a JSON Web Token to be used as an OAuth2 client secret. If
# given, must be a dictionary with the following properties:
#
# key: a pem-encoded signing key. Must be a suitable key for the
# algorithm specified. Required unless 'key_file' is given.
#
# key_file: the path to file containing a pem-encoded signing key file.
# Required unless 'key' is given.
#
# jwt_header: a dictionary giving properties to include in the JWT
# header. Must include the key 'alg', giving the algorithm used to
# sign the JWT, such as "ES256", using the JWA identifiers in
# RFC7518.
#
# jwt_payload: an optional dictionary giving properties to include in
# the JWT payload. Normally this should include an 'iss' key.
#
# client_auth_method: auth method to use when exchanging the token. Valid
# values are 'client_secret_basic' (default), 'client_secret_post' and
Expand Down Expand Up @@ -240,7 +259,7 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs):
# jsonschema definition of the configuration settings for an oidc identity provider
OIDC_PROVIDER_CONFIG_SCHEMA = {
"type": "object",
"required": ["issuer", "client_id", "client_secret"],
"required": ["issuer", "client_id"],
"properties": {
"idp_id": {
"type": "string",
Expand All @@ -262,6 +281,30 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs):
"issuer": {"type": "string"},
"client_id": {"type": "string"},
"client_secret": {"type": "string"},
"client_secret_jwt_key": {
"type": "object",
"required": ["jwt_header"],
"oneOf": [
{"required": ["key"]},
{"required": ["key_file"]},
],
"properties": {
"key": {"type": "string"},
"key_file": {"type": "string"},
"jwt_header": {
"type": "object",
"required": ["alg"],
"properties": {
"alg": {"type": "string"},
},
"additionalProperties": {"type": "string"},
},
"jwt_payload": {
"type": "object",
"additionalProperties": {"type": "string"},
},
},
},
"client_auth_method": {
"type": "string",
# the following list is the same as the keys of
Expand Down Expand Up @@ -404,6 +447,20 @@ def _parse_oidc_config_dict(
"idp_icon must be a valid MXC URI", config_path + ("idp_icon",)
) from e

client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key")
client_secret_jwt_key = None # type: Optional[OidcProviderClientSecretJwtKey]
if client_secret_jwt_key_config is not None:
keyfile = client_secret_jwt_key_config.get("key_file")
if keyfile:
key = read_file(keyfile, config_path + ("client_secret_jwt_key",))
else:
key = client_secret_jwt_key_config["key"]
client_secret_jwt_key = OidcProviderClientSecretJwtKey(
key=key,
jwt_header=client_secret_jwt_key_config["jwt_header"],
jwt_payload=client_secret_jwt_key_config.get("jwt_payload", {}),
)

return OidcProviderConfig(
idp_id=idp_id,
idp_name=oidc_config.get("idp_name", "OIDC"),
Expand All @@ -412,7 +469,8 @@ def _parse_oidc_config_dict(
discover=oidc_config.get("discover", True),
issuer=oidc_config["issuer"],
client_id=oidc_config["client_id"],
client_secret=oidc_config["client_secret"],
client_secret=oidc_config.get("client_secret"),
client_secret_jwt_key=client_secret_jwt_key,
client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
scopes=oidc_config.get("scopes", ["openid"]),
authorization_endpoint=oidc_config.get("authorization_endpoint"),
Expand All @@ -427,6 +485,18 @@ def _parse_oidc_config_dict(
)


@attr.s(slots=True, frozen=True)
class OidcProviderClientSecretJwtKey:
# a pem-encoded signing key
key = attr.ib(type=str)

# properties to include in the JWT header
jwt_header = attr.ib(type=Mapping[str, str])

# properties to include in the JWT payload.
jwt_payload = attr.ib(type=Mapping[str, str])


@attr.s(slots=True, frozen=True)
class OidcProviderConfig:
# a unique identifier for this identity provider. Used in the 'user_external_ids'
Expand All @@ -452,8 +522,13 @@ class OidcProviderConfig:
# oauth2 client id to use
client_id = attr.ib(type=str)

# oauth2 client secret to use
client_secret = attr.ib(type=str)
# oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate
# a secret.
client_secret = attr.ib(type=Optional[str])

# key to use to construct a JWT to use as a client secret. May be `None` if
# `client_secret` is set.
client_secret_jwt_key = attr.ib(type=Optional[OidcProviderClientSecretJwtKey])

# auth method to use when exchanging the token.
# Valid values are 'client_secret_basic', 'client_secret_post' and
Expand Down
101 changes: 96 additions & 5 deletions synapse/handlers/oidc_handler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,13 +15,13 @@
# limitations under the License.
import inspect
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
from urllib.parse import urlencode

import attr
import pymacaroons
from authlib.common.security import generate_token
from authlib.jose import JsonWebToken
from authlib.jose import JsonWebToken, jwt
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
Expand All @@ -35,12 +36,15 @@
from twisted.web.client import readBody

from synapse.config import ConfigError
from synapse.config.oidc_config import OidcProviderConfig
from synapse.config.oidc_config import (
OidcProviderClientSecretJwtKey,
OidcProviderConfig,
)
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import json_decoder
from synapse.util import Clock, json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry

Expand Down Expand Up @@ -276,9 +280,21 @@ def __init__(

self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method

client_secret = None # type: Union[None, str, JwtClientSecret]
if provider.client_secret:
client_secret = provider.client_secret
elif provider.client_secret_jwt_key:
client_secret = JwtClientSecret(
provider.client_secret_jwt_key,
provider.client_id,
provider.issuer,
hs.get_clock(),
)

self._client_auth = ClientAuth(
provider.client_id,
provider.client_secret,
client_secret,
provider.client_auth_method,
) # type: ClientAuth
self._client_auth_method = provider.client_auth_method
Expand Down Expand Up @@ -977,6 +993,81 @@ def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
return str(remote_user_id)


# number of seconds a newly-generated client secret should be valid for
CLIENT_SECRET_VALIDITY_SECONDS = 3600

# minimum remaining validity on a client secret before we should generate a new one
CLIENT_SECRET_MIN_VALIDITY_SECONDS = 600


class JwtClientSecret:
"""A class which generates a new client secret on demand, based on a JWK

This implementation is designed to comply with the requirements for Apple Sign in:
https://developer.apple.com/documentation/sign_in_with_apple/generate_and_validate_tokens#3262048

It looks like those requirements are based on https://tools.ietf.org/html/rfc7523,
but it's worth noting that we still put the generated secret in the "client_secret"
field (or rather, whereever client_auth_method puts it) rather than in a
client_assertion field in the body as that RFC seems to require.
"""

def __init__(
self,
key: OidcProviderClientSecretJwtKey,
oauth_client_id: str,
oauth_issuer: str,
clock: Clock,
):
self._key = key
self._oauth_client_id = oauth_client_id
self._oauth_issuer = oauth_issuer
self._clock = clock
self._cached_secret = b""
self._cached_secret_replacement_time = 0
Comment on lines +1026 to +1027
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not quite sure I understand what would use the cached values here -- does ClientAuth cache the secret internally?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no it does not. _get_secret returns the cached value.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think my question was unclear -- it looked to me like ClientAuth was only used once and then discarded, but I see that a single instance is re-used while exchanging the code.

I guess I'm surprised that caching is necessary -- is it very inefficient to calculate the secret each time? Or are you just not supposed to do that?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I'm surprised that caching is necessary -- is it very inefficient to calculate the secret each time? Or are you just not supposed to do that?

I cached it because calculating a new one involves calculating a cryptographic signature, which generally isn't a trivial thing to do, and a cache seemed easy and low impact. I haven't measured it, tbh.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough! Thanks for the explanation.


def __str__(self):
# if client_auth_method is client_secret_basic, then ClientAuth.prepare calls
# encode_client_secret_basic, which calls "{}".format(secret), which ends up
# here.
return self._get_secret().decode("ascii")

def __bytes__(self):
# if client_auth_method is client_secret_post, then ClientAuth.prepare calls
# encode_client_secret_post, which ends up here.
return self._get_secret()

def _get_secret(self) -> bytes:
now = self._clock.time()

# if we have enough validity on our existing secret, use it
if now < self._cached_secret_replacement_time:
return self._cached_secret

issued_at = int(now)
expires_at = issued_at + CLIENT_SECRET_VALIDITY_SECONDS

# we copy the configured header because jwt.encode modifies it.
header = dict(self._key.jwt_header)

# see https://tools.ietf.org/html/rfc7523#section-3
payload = {
"sub": self._oauth_client_id,
"aud": self._oauth_issuer,
"iat": issued_at,
"exp": expires_at,
**self._key.jwt_payload,
}
logger.info(
"Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
)
self._cached_secret = jwt.encode(header, payload, self._key.key)
self._cached_secret_replacement_time = (
expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
)
return self._cached_secret


class OidcSessionTokenGenerator:
"""Methods for generating and checking OIDC Session cookies."""

Expand Down
Loading