Skip to content

Commit

Permalink
Feat: Enable flytekit to authenticate with proxy in front of FlyteA…
Browse files Browse the repository at this point in the history
…dmin (#1787)

* Introduce authenticator engine and make proxy auth work

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Use proxy authed session for client credentials flow

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Don't use authenticator engine but do proxy authentication via existing external command authenticator

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Add docstring to AuthenticationHTTPAdapter

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Address todo in docstring

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Create blank session if none provided

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Create blank session if none provided in get_token

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Refresh proxy creds in session when not existing without triggering 401

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Add test for get_session

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Move auth helper test into existing module

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Move auth helper test into existing module

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Add test for upgrade_channel_to_proxy_authenticated

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Auth helper tests without use of responses package

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Feat: Add plugin for generating GCP IAP ID tokens via external command (#1795)

* Add external command plugin to generate id tokens for identity aware proxy

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Retrieve desktop app client secret from gcp secret manager

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Remove comments

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Introduce a command group that allows adding a command to generate service account id tokens later

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Document how to use plugin and deploy Flyte with IAP

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Minor corrections README.md

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

---------

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>
Co-authored-by: Fabio Grätz <fabiogratz@googlemail.com>
Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Use proxy auth'ed session for device code auth flow

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Fix token client tests

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Make poll token endpoint test more specific

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Make test_client_creds_authenticator test work and more specific

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Make test_client_creds_authenticator_with_custom_scopes test work and more specific

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Implement subcommand to generate id tokens for service accounts

Signed-off-by: Fabio Graetz <fabiograetz@googlemail.com>

* Test id token generation from service accounts

Signed-off-by: Fabio Graetz <fabiograetz@googlemail.com>

* Fix plugin requirements

Signed-off-by: Fabio Graetz <fabiograetz@googlemail.com>

* Document usage of generate-service-account-id-token subcommand

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Document alternative ways to obtain service account id tokens

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

---------

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>
Signed-off-by: Fabio Graetz <fabiograetz@googlemail.com>
Co-authored-by: Fabio Grätz <fabiogratz@googlemail.com>
  • Loading branch information
fg91 and Fabio Grätz committed Sep 20, 2023
1 parent cf165f7 commit cdcba2f
Show file tree
Hide file tree
Showing 18 changed files with 1,155 additions and 65 deletions.
79 changes: 55 additions & 24 deletions flytekit/clients/auth/auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ def __init__(
redirect_uri: typing.Optional[str] = None,
endpoint_metadata: typing.Optional[EndpointMetadata] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
session: typing.Optional[_requests.Session] = None,
request_auth_code_params: typing.Optional[typing.Dict[str, str]] = None,
request_access_token_params: typing.Optional[typing.Dict[str, str]] = None,
refresh_access_token_params: typing.Optional[typing.Dict[str, str]] = None,
add_request_auth_code_params_to_request_access_token_params: typing.Optional[bool] = False,
):
"""
Create new AuthorizationClient
Expand All @@ -192,7 +197,9 @@ def __init__(
:param auth_endpoint: str endpoint where auth metadata can be found
:param token_endpoint: str endpoint to retrieve token from
:param scopes: list[str] oauth2 scopes
:param client_id
:param client_id: oauth2 client id
:param redirect_uri: oauth2 redirect uri
:param endpoint_metadata: EndpointMetadata object to control the rendering of the page on login successful or failure
:param verify: (optional) Either a boolean, in which case it controls whether we verify
the server's TLS certificate, or a string, in which case it must be a path
to a CA bundle to use. Defaults to ``True``. When set to
Expand All @@ -201,6 +208,15 @@ def __init__(
certificates, which will make your application vulnerable to
man-in-the-middle (MitM) attacks. Setting verify to ``False``
may be useful during local development or testing.
:param session: (optional) A custom requests.Session object to use for making HTTP requests.
If not provided, a new Session object will be created.
:param request_auth_code_params: (optional) dict of parameters to add to login uri opened in the browser
:param request_access_token_params: (optional) dict of parameters to add when exchanging the auth code for the access token
:param refresh_access_token_params: (optional) dict of parameters to add when refreshing the access token
:param add_request_auth_code_params_to_request_access_token_params: Whether to add the `request_auth_code_params` to
the parameters sent when exchanging the auth code for the access token. Defaults to False.
Required e.g. for the PKCE flow with flyteadmin.
Not required for e.g. the standard OAuth2 flow on GCP.
"""
self._endpoint = endpoint
self._auth_endpoint = auth_endpoint
Expand All @@ -213,15 +229,13 @@ def __init__(
self._client_id = client_id
self._scopes = scopes or []
self._redirect_uri = redirect_uri
self._code_verifier = _generate_code_verifier()
code_challenge = _create_code_challenge(self._code_verifier)
self._code_challenge = code_challenge
state = _generate_state_parameter()
self._state = state
self._verify = verify
self._headers = {"content-type": "application/x-www-form-urlencoded"}
self._session = session or _requests.Session()

self._params = {
self._request_auth_code_params = {
"client_id": client_id, # This must match the Client ID of the OAuth application.
"response_type": "code", # Indicates the authorization code grant
"scope": " ".join(s.strip("' ") for s in self._scopes).strip(
Expand All @@ -230,10 +244,18 @@ def __init__(
# callback location where the user-agent will be directed to.
"redirect_uri": self._redirect_uri,
"state": state,
"code_challenge": code_challenge,
"code_challenge_method": "S256",
}

if request_auth_code_params:
# Allow adding additional parameters to the request_auth_code_params
self._request_auth_code_params.update(request_auth_code_params)

self._request_access_token_params = request_access_token_params or {}
self._refresh_access_token_params = refresh_access_token_params or {}

if add_request_auth_code_params_to_request_access_token_params:
self._request_access_token_params.update(self._request_auth_code_params)

def __repr__(self):
return f"AuthorizationClient({self._auth_endpoint}, {self._token_endpoint}, {self._client_id}, {self._scopes}, {self._redirect_uri})"

Expand All @@ -249,7 +271,7 @@ def _create_callback_server(self):

def _request_authorization_code(self):
scheme, netloc, path, _, _, _ = _urlparse.urlparse(self._auth_endpoint)
query = _urlencode(self._params)
query = _urlencode(self._request_auth_code_params)
endpoint = _urlparse.urlunparse((scheme, netloc, path, None, query, None))
logging.debug(f"Requesting authorization code through {endpoint}")
_webbrowser.open_new_tab(endpoint)
Expand All @@ -262,33 +284,38 @@ def _credentials_from_response(self, auth_token_resp) -> Credentials:
"refresh_token": "bar",
"token_type": "Bearer"
}
Can additionally contain "expires_in" and "id_token" fields.
"""
response_body = auth_token_resp.json()
refresh_token = None
id_token = None
if "access_token" not in response_body:
raise ValueError('Expected "access_token" in response from oauth server')
if "refresh_token" in response_body:
refresh_token = response_body["refresh_token"]
if "expires_in" in response_body:
expires_in = response_body["expires_in"]
access_token = response_body["access_token"]
if "id_token" in response_body:
id_token = response_body["id_token"]

return Credentials(access_token, refresh_token, self._endpoint, expires_in=expires_in)
return Credentials(access_token, refresh_token, self._endpoint, expires_in=expires_in, id_token=id_token)

def _request_access_token(self, auth_code) -> Credentials:
if self._state != auth_code.state:
raise ValueError(f"Unexpected state parameter [{auth_code.state}] passed")
self._params.update(
{
"code": auth_code.code,
"code_verifier": self._code_verifier,
"grant_type": "authorization_code",
}
)

resp = _requests.post(
params = {
"code": auth_code.code,
"grant_type": "authorization_code",
}

params.update(self._request_access_token_params)

resp = self._session.post(
url=self._token_endpoint,
data=self._params,
data=params,
headers=self._headers,
allow_redirects=False,
verify=self._verify,
Expand Down Expand Up @@ -332,13 +359,17 @@ def refresh_access_token(self, credentials: Credentials) -> Credentials:
if credentials.refresh_token is None:
raise ValueError("no refresh token available with which to refresh authorization credentials")

resp = _requests.post(
data = {
"refresh_token": credentials.refresh_token,
"grant_type": "refresh_token",
"client_id": self._client_id,
}

data.update(self._refresh_access_token_params)

resp = self._session.post(
url=self._token_endpoint,
data={
"grant_type": "refresh_token",
"client_id": self._client_id,
"refresh_token": credentials.refresh_token,
},
data=data,
headers=self._headers,
allow_redirects=False,
verify=self._verify,
Expand Down
32 changes: 31 additions & 1 deletion flytekit/clients/auth/authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import dataclass

import click
import requests

from . import token_client
from .auth_client import AuthorizationClient
Expand Down Expand Up @@ -95,16 +96,24 @@ def __init__(
cfg_store: ClientConfigStore,
header_key: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
session: typing.Optional[requests.Session] = None,
):
"""
Initialize with default creds from KeyStore using the endpoint name
"""
super().__init__(endpoint, header_key, KeyringStore.retrieve(endpoint), verify=verify)
self._cfg_store = cfg_store
self._auth_client = None
self._session = session or requests.Session()

def _initialize_auth_client(self):
if not self._auth_client:

from .auth_client import _create_code_challenge, _generate_code_verifier

code_verifier = _generate_code_verifier()
code_challenge = _create_code_challenge(code_verifier)

cfg = self._cfg_store.get_client_config()
self._set_header_key(cfg.header_key)
self._auth_client = AuthorizationClient(
Expand All @@ -115,6 +124,16 @@ def _initialize_auth_client(self):
auth_endpoint=cfg.authorization_endpoint,
token_endpoint=cfg.token_endpoint,
verify=self._verify,
session=self._session,
request_auth_code_params={
"code_challenge": code_challenge,
"code_challenge_method": "S256",
},
request_access_token_params={
"code_verifier": code_verifier,
},
refresh_access_token_params={},
add_request_auth_code_params_to_request_access_token_params=True,
)

def refresh_credentials(self):
Expand Down Expand Up @@ -176,6 +195,7 @@ def __init__(
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
audience: typing.Optional[str] = None,
session: typing.Optional[requests.Session] = None,
):
if not client_id or not client_secret:
raise ValueError("Client ID and Client SECRET both are required.")
Expand All @@ -186,6 +206,7 @@ def __init__(
self._client_id = client_id
self._client_secret = client_secret
self._audience = audience or cfg.audience
self._session = session or requests.Session()
super().__init__(endpoint, cfg.header_key or header_key, http_proxy_url=http_proxy_url, verify=verify)

def refresh_credentials(self):
Expand All @@ -211,6 +232,7 @@ def refresh_credentials(self):
verify=self._verify,
scopes=scopes,
audience=audience,
session=self._session,
)

logging.info("Retrieved new token, expires in {}".format(expires_in))
Expand All @@ -234,6 +256,7 @@ def __init__(
audience: typing.Optional[str] = None,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
session: typing.Optional[requests.Session] = None,
):
self._audience = audience
cfg = cfg_store.get_client_config()
Expand All @@ -245,6 +268,7 @@ def __init__(
raise AuthenticationError(
"Device Authentication is not available on the Flyte backend / authentication server"
)
self._session = session or requests.Session()
super().__init__(
endpoint=endpoint,
header_key=header_key or cfg.header_key,
Expand All @@ -255,7 +279,13 @@ def __init__(

def refresh_credentials(self):
resp = token_client.get_device_code(
self._device_auth_endpoint, self._client_id, self._audience, self._scope, self._http_proxy_url, self._verify
self._device_auth_endpoint,
self._client_id,
self._audience,
self._scope,
self._http_proxy_url,
self._verify,
self._session,
)
text = f"To Authenticate, navigate in a browser to the following URL: {click.style(resp.verification_uri, fg='blue', underline=True)} and enter code: {click.style(resp.user_code, fg='blue')}"
click.secho(text)
Expand Down
30 changes: 22 additions & 8 deletions flytekit/clients/auth/keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass

import keyring as _keyring
from keyring.errors import NoKeyringError
from keyring.errors import NoKeyringError, PasswordDeleteError


@dataclass
Expand All @@ -16,6 +16,7 @@ class Credentials(object):
refresh_token: str = "na"
for_endpoint: str = "flyte-default"
expires_in: typing.Optional[int] = None
id_token: typing.Optional[str] = None


class KeyringStore:
Expand All @@ -25,20 +26,28 @@ class KeyringStore:

_access_token_key = "access_token"
_refresh_token_key = "refresh_token"
_id_token_key = "id_token"

@staticmethod
def store(credentials: Credentials) -> Credentials:
try:
_keyring.set_password(
credentials.for_endpoint,
KeyringStore._refresh_token_key,
credentials.refresh_token,
)
if credentials.refresh_token:
_keyring.set_password(
credentials.for_endpoint,
KeyringStore._refresh_token_key,
credentials.refresh_token,
)
_keyring.set_password(
credentials.for_endpoint,
KeyringStore._access_token_key,
credentials.access_token,
)
if credentials.id_token:
_keyring.set_password(
credentials.for_endpoint,
KeyringStore._id_token_key,
credentials.id_token,
)
except NoKeyringError as e:
logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}")
return credentials
Expand All @@ -48,18 +57,23 @@ def retrieve(for_endpoint: str) -> typing.Optional[Credentials]:
try:
refresh_token = _keyring.get_password(for_endpoint, KeyringStore._refresh_token_key)
access_token = _keyring.get_password(for_endpoint, KeyringStore._access_token_key)
id_token = _keyring.get_password(for_endpoint, KeyringStore._id_token_key)
except NoKeyringError as e:
logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}")
return None

if not access_token:
if not access_token and not id_token:
return None
return Credentials(access_token, refresh_token, for_endpoint)
return Credentials(access_token, refresh_token, for_endpoint, id_token=id_token)

@staticmethod
def delete(for_endpoint: str):
try:
_keyring.delete_password(for_endpoint, KeyringStore._access_token_key)
_keyring.delete_password(for_endpoint, KeyringStore._refresh_token_key)
try:
_keyring.delete_password(for_endpoint, KeyringStore._id_token_key)
except PasswordDeleteError as e:
logging.debug(f"Id token not found in key store, not deleting. Error: {e}")
except NoKeyringError as e:
logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}")
11 changes: 9 additions & 2 deletions flytekit/clients/auth/token_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def get_token(
grant_type: GrantType = GrantType.CLIENT_CREDS,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
session: typing.Optional[requests.Session] = None,
) -> typing.Tuple[str, int]:
"""
:rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration
Expand All @@ -103,7 +104,10 @@ def get_token(
body["audience"] = audience

proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None
response = requests.post(token_endpoint, data=body, headers=headers, proxies=proxies, verify=verify)

if not session:
session = requests.Session()
response = session.post(token_endpoint, data=body, headers=headers, proxies=proxies, verify=verify)

if not response.ok:
j = response.json()
Expand All @@ -125,6 +129,7 @@ def get_device_code(
scope: typing.Optional[typing.List[str]] = None,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
session: typing.Optional[requests.Session] = None,
) -> DeviceCodeResponse:
"""
Retrieves the device Authentication code that can be done to authenticate the request using a browser on a
Expand All @@ -133,7 +138,9 @@ def get_device_code(
_scope = " ".join(scope) if scope is not None else ""
payload = {"client_id": client_id, "scope": _scope, "audience": audience}
proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None
resp = requests.post(device_auth_endpoint, payload, proxies=proxies, verify=verify)
if not session:
session = requests.Session()
resp = session.post(device_auth_endpoint, payload, proxies=proxies, verify=verify)
if not resp.ok:
raise AuthenticationError(f"Unable to retrieve Device Authentication Code for {payload}, Reason {resp.reason}")
return DeviceCodeResponse.from_json_response(resp.json())
Expand Down
Loading

0 comments on commit cdcba2f

Please sign in to comment.