Skip to content

Commit

Permalink
Add external command plugin to generate id tokens for identity aware …
Browse files Browse the repository at this point in the history
…proxy
  • Loading branch information
Fabio Grätz committed Aug 15, 2023
1 parent 0b13338 commit 77b3f8c
Show file tree
Hide file tree
Showing 8 changed files with 267 additions and 29 deletions.
71 changes: 49 additions & 22 deletions flytekit/clients/auth/auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@ def __init__(
endpoint_metadata: typing.Optional[EndpointMetadata] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
session: typing.Optional[_requests.Session] = None,
request_auth_code_params: typing.Optional[typing.Dict[str, str]] = None,
request_access_token_params: typing.Optional[typing.Dict[str, str]] = None,
refresh_access_token_params: typing.Optional[typing.Dict[str, str]] = None,
add_request_auth_code_params_to_request_access_token_params: typing.Optional[bool] = False,
):
"""
Create new AuthorizationClient
Expand All @@ -193,7 +197,9 @@ def __init__(
:param auth_endpoint: str endpoint where auth metadata can be found
:param token_endpoint: str endpoint to retrieve token from
:param scopes: list[str] oauth2 scopes
:param client_id
:param client_id: oauth2 client id
:param redirect_uri: oauth2 redirect uri
:param endpoint_metadata: EndpointMetadata object to control the rendering of the page on login successful or failure
:param verify: (optional) Either a boolean, in which case it controls whether we verify
the server's TLS certificate, or a string, in which case it must be a path
to a CA bundle to use. Defaults to ``True``. When set to
Expand All @@ -204,6 +210,13 @@ def __init__(
may be useful during local development or testing.
:param session: (optional) A custom requests.Session object to use for making HTTP requests.
If not provided, a new Session object will be created.
:param request_auth_code_params: (optional) dict of parameters to add to login uri opened in the browser
:param request_access_token_params: (optional) dict of parameters to add when exchanging the auth code for the access token
:param refresh_access_token_params: (optional) dict of parameters to add when refreshing the access token
:param add_request_auth_code_params_to_request_access_token_params: Whether to add the `request_auth_code_params` to
the parameters sent when exchanging the auth code for the access token. Defaults to False.
Required e.g. for the PKCE flow with flyteadmin.
Not required for e.g. the standard OAuth2 flow on GCP.
"""
self._endpoint = endpoint
self._auth_endpoint = auth_endpoint
Expand All @@ -216,16 +229,13 @@ def __init__(
self._client_id = client_id
self._scopes = scopes or []
self._redirect_uri = redirect_uri
self._code_verifier = _generate_code_verifier()
code_challenge = _create_code_challenge(self._code_verifier)
self._code_challenge = code_challenge
state = _generate_state_parameter()
self._state = state
self._verify = verify
self._headers = {"content-type": "application/x-www-form-urlencoded"}
self._session = session or _requests.Session()

self._params = {
self._request_auth_code_params = {
"client_id": client_id, # This must match the Client ID of the OAuth application.
"response_type": "code", # Indicates the authorization code grant
"scope": " ".join(s.strip("' ") for s in self._scopes).strip(
Expand All @@ -234,10 +244,18 @@ def __init__(
# callback location where the user-agent will be directed to.
"redirect_uri": self._redirect_uri,
"state": state,
"code_challenge": code_challenge,
"code_challenge_method": "S256",
}

if request_auth_code_params:
# Allow adding additional parameters to the request_auth_code_params
self._request_auth_code_params.update(request_auth_code_params)

self._request_access_token_params = request_access_token_params or {}
self._refresh_access_token_params = refresh_access_token_params or {}

if add_request_auth_code_params_to_request_access_token_params:
self._request_access_token_params.update(self._request_auth_code_params)

def __repr__(self):
return f"AuthorizationClient({self._auth_endpoint}, {self._token_endpoint}, {self._client_id}, {self._scopes}, {self._redirect_uri})"

Expand All @@ -253,7 +271,7 @@ def _create_callback_server(self):

def _request_authorization_code(self):
scheme, netloc, path, _, _, _ = _urlparse.urlparse(self._auth_endpoint)
query = _urlencode(self._params)
query = _urlencode(self._request_auth_code_params)
endpoint = _urlparse.urlunparse((scheme, netloc, path, None, query, None))
logging.debug(f"Requesting authorization code through {endpoint}")
_webbrowser.open_new_tab(endpoint)
Expand All @@ -266,33 +284,38 @@ def _credentials_from_response(self, auth_token_resp) -> Credentials:
"refresh_token": "bar",
"token_type": "Bearer"
}
Can additionally contain "expires_in" and "id_token" fields.
"""
response_body = auth_token_resp.json()
refresh_token = None
id_token = None
if "access_token" not in response_body:
raise ValueError('Expected "access_token" in response from oauth server')
if "refresh_token" in response_body:
refresh_token = response_body["refresh_token"]
if "expires_in" in response_body:
expires_in = response_body["expires_in"]
access_token = response_body["access_token"]
if "id_token" in response_body:
id_token = response_body["id_token"]

return Credentials(access_token, refresh_token, self._endpoint, expires_in=expires_in)
return Credentials(access_token, refresh_token, self._endpoint, expires_in=expires_in, id_token=id_token)

def _request_access_token(self, auth_code) -> Credentials:
if self._state != auth_code.state:
raise ValueError(f"Unexpected state parameter [{auth_code.state}] passed")
self._params.update(
{
"code": auth_code.code,
"code_verifier": self._code_verifier,
"grant_type": "authorization_code",
}
)

params = {
"code": auth_code.code,
"grant_type": "authorization_code",
}

params.update(self._request_access_token_params)

resp = self._session.post(
url=self._token_endpoint,
data=self._params,
data=params,
headers=self._headers,
allow_redirects=False,
verify=self._verify,
Expand Down Expand Up @@ -336,13 +359,17 @@ def refresh_access_token(self, credentials: Credentials) -> Credentials:
if credentials.refresh_token is None:
raise ValueError("no refresh token available with which to refresh authorization credentials")

data = {
"refresh_token": credentials.refresh_token,
"grant_type": "refresh_token",
"client_id": self._client_id,
}

data.update(self._refresh_access_token_params)

resp = self._session.post(
url=self._token_endpoint,
data={
"grant_type": "refresh_token",
"client_id": self._client_id,
"refresh_token": credentials.refresh_token,
},
data=data,
headers=self._headers,
allow_redirects=False,
verify=self._verify,
Expand Down
15 changes: 15 additions & 0 deletions flytekit/clients/auth/authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ def __init__(

def _initialize_auth_client(self):
if not self._auth_client:

from .auth_client import _create_code_challenge, _generate_code_verifier

code_verifier = _generate_code_verifier()
code_challenge = _create_code_challenge(code_verifier)

cfg = self._cfg_store.get_client_config()
self._set_header_key(cfg.header_key)
self._auth_client = AuthorizationClient(
Expand All @@ -119,6 +125,15 @@ def _initialize_auth_client(self):
token_endpoint=cfg.token_endpoint,
verify=self._verify,
session=self._session,
request_auth_code_params={
"code_challenge": code_challenge,
"code_challenge_method": "S256",
},
request_access_token_params={
"code_verifier": code_verifier,
},
refresh_access_token_params={},
add_request_auth_code_params_to_request_access_token_params=True,
)

def refresh_credentials(self):
Expand Down
28 changes: 21 additions & 7 deletions flytekit/clients/auth/keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass

import keyring as _keyring
from keyring.errors import NoKeyringError
from keyring.errors import NoKeyringError, PasswordDeleteError


@dataclass
Expand All @@ -16,6 +16,7 @@ class Credentials(object):
refresh_token: str = "na"
for_endpoint: str = "flyte-default"
expires_in: typing.Optional[int] = None
id_token: typing.Optional[str] = None


class KeyringStore:
Expand All @@ -25,20 +26,28 @@ class KeyringStore:

_access_token_key = "access_token"
_refresh_token_key = "refresh_token"
_id_token_key = "id_token"

@staticmethod
def store(credentials: Credentials) -> Credentials:
try:
_keyring.set_password(
credentials.for_endpoint,
KeyringStore._refresh_token_key,
credentials.refresh_token,
)
if credentials.refresh_token:
_keyring.set_password(
credentials.for_endpoint,
KeyringStore._refresh_token_key,
credentials.refresh_token,
)
_keyring.set_password(
credentials.for_endpoint,
KeyringStore._access_token_key,
credentials.access_token,
)
if credentials.id_token:
_keyring.set_password(
credentials.for_endpoint,
KeyringStore._id_token_key,
credentials.id_token,
)
except NoKeyringError as e:
logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}")
return credentials
Expand All @@ -48,18 +57,23 @@ def retrieve(for_endpoint: str) -> typing.Optional[Credentials]:
try:
refresh_token = _keyring.get_password(for_endpoint, KeyringStore._refresh_token_key)
access_token = _keyring.get_password(for_endpoint, KeyringStore._access_token_key)
id_token = _keyring.get_password(for_endpoint, KeyringStore._id_token_key)
except NoKeyringError as e:
logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}")
return None

if not access_token:
return None
return Credentials(access_token, refresh_token, for_endpoint)
return Credentials(access_token, refresh_token, for_endpoint, id_token=id_token)

@staticmethod
def delete(for_endpoint: str):
try:
_keyring.delete_password(for_endpoint, KeyringStore._access_token_key)
_keyring.delete_password(for_endpoint, KeyringStore._refresh_token_key)
try:
_keyring.delete_password(for_endpoint, KeyringStore._id_token_key)
except PasswordDeleteError as e:
logging.debug(f"Id token not found in key store, not deleting. Error: {e}")
except NoKeyringError as e:
logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}")
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import logging
import typing

import click

from flytekit.clients.auth.auth_client import AuthorizationClient
from flytekit.clients.auth.authenticator import Authenticator
from flytekit.clients.auth.exceptions import AccessTokenNotFoundError
from flytekit.clients.auth.keyring import KeyringStore


class GCPIdentityAwareProxyAuthenticator(Authenticator):
"""
This Authenticator encapsulates the entire OAauth 2.0 flow with GCP Identity Aware Proxy.
The auth flow is described in https://cloud.google.com/iap/docs/authentication-howto#signing_in_to_the_application
Automatically opens a browser window for login.
"""

def __init__(
self,
audience: str,
client_id: str,
client_secret: str,
verify: typing.Optional[typing.Union[bool, str]] = None,
):
"""
Initialize with default creds from KeyStore using the audience name.
"""
super().__init__(audience, "proxy-authorization", KeyringStore.retrieve(audience), verify=verify)
self._auth_client = None

self.audience = audience
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = "http://localhost:4444"

def _initialize_auth_client(self):
if not self._auth_client:
self._auth_client = AuthorizationClient(
endpoint=self.audience,
# See step 3 in https://cloud.google.com/iap/docs/authentication-howto#signing_in_to_the_application
auth_endpoint="https://accounts.google.com/o/oauth2/v2/auth",
token_endpoint="https://oauth2.googleapis.com/token",
# See step 3 in https://cloud.google.com/iap/docs/authentication-howto#signing_in_to_the_application
scopes=["openid", "email"],
client_id=self.client_id,
redirect_uri=self.redirect_uri,
verify=self._verify,
# See step 3 in https://cloud.google.com/iap/docs/authentication-howto#signing_in_to_the_application
request_auth_code_params={
"cred_ref": "true",
"access_type": "offline",
},
# See step 4 in https://cloud.google.com/iap/docs/authentication-howto#signing_in_to_the_application
request_access_token_params={
"client_id": self.client_id,
"client_secret": self.client_secret,
"audience": self.audience,
"redirect_uri": self.redirect_uri,
},
# See https://cloud.google.com/iap/docs/authentication-howto#refresh_token
refresh_access_token_params={
"client_secret": self.client_secret,
"audience": self.audience,
},
)

def refresh_credentials(self):
"""Refresh the IAP credentials. If no credentials are found, it will kick off a full OAuth 2.0 authorization flow."""
self._initialize_auth_client()
if self._creds:
"""We have an id token so lets try to refresh it"""
try:
self._creds = self._auth_client.refresh_access_token(self._creds)
if self._creds:
KeyringStore.store(self._creds)
return
except AccessTokenNotFoundError:
logging.warning("Failed to refresh token. Kicking off a full authorization flow.")
KeyringStore.delete(self._endpoint)

self._creds = self._auth_client.get_creds_from_remote()
KeyringStore.store(self._creds)


@click.command()
@click.option(
"--desktop_client_id",
type=str,
default=None,
required=True,
help=(
"Desktop type OAuth 2.0 client ID. Typically in the form of `<xyz>.apps.googleusercontent.com`. "
"Create by following https://cloud.google.com/iap/docs/authentication-howto#setting_up_the_client_id"
),
)
@click.option(
"--desktop_client_secret_gcp_secret_name",
type=str,
default=None,
required=True,
help=(
"Name of a GCP secret manager secret containing the desktop type OAuth 2.0 client secret "
"obtained together with desktop type OAuth 2.0 client ID."
),
)
@click.option(
"--webapp_client_id",
type=str,
default=None,
required=True,
help=(
"Webapp type OAuth 2.0 client ID. Typically in the form of `<xyz>.apps.googleusercontent.com`. "
"Created when activating IAP for the Flyte deployment. "
"https://cloud.google.com/iap/docs/enabling-kubernetes-howto#oauth-credentials"
),
)
def flyte_iap_token(desktop_client_id: str, desktop_client_secret_gcp_secret_name: str, webapp_client_id: str):
"""Generate an ID token for proxy-authentication/authorization with GCP Identity Aware Proxy."""
desktop_client_secret = desktop_client_secret_gcp_secret_name # TODO

iap_authenticator = GCPIdentityAwareProxyAuthenticator(
audience=webapp_client_id,
client_id=desktop_client_id,
client_secret=desktop_client_secret,
)
try:
iap_authenticator.refresh_credentials()
except Exception as e:
raise click.ClickException(f"Failed to obtain credentials for GCP Identity Aware Proxy (IAP): {e}")

click.echo(iap_authenticator.get_credentials().id_token)


if __name__ == "__main__":
flyte_iap_token()
Loading

0 comments on commit 77b3f8c

Please sign in to comment.