diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index 0705a21fb9..8e5ccf2fe2 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -418,19 +418,31 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None kwargs, "client_credentials_secret", _internal.Credentials.CLIENT_CREDENTIALS_SECRET.read(config_file) ) + is_client_secret = False client_credentials_secret = read_file_if_exists( _internal.Credentials.CLIENT_CREDENTIALS_SECRET_LOCATION.read(config_file) ) - if client_credentials_secret and client_credentials_secret.endswith("\n"): - logger.info("Newline stripped from client secret") - client_credentials_secret = client_credentials_secret.strip() + if client_credentials_secret: + is_client_secret = True + if client_credentials_secret.endswith("\n"): + logger.info("Newline stripped from client secret") + client_credentials_secret = client_credentials_secret.strip() kwargs = set_if_exists( kwargs, "client_credentials_secret", client_credentials_secret, ) + + client_credentials_secret_env_var = _internal.Credentials.CLIENT_CREDENTIALS_SECRET_ENV_VAR.read(config_file) + if client_credentials_secret_env_var: + client_credentials_secret = os.getenv(client_credentials_secret_env_var) + if client_credentials_secret: + is_client_secret = True + kwargs = set_if_exists(kwargs, "client_credentials_secret", client_credentials_secret) kwargs = set_if_exists(kwargs, "scopes", _internal.Credentials.SCOPES.read(config_file)) kwargs = set_if_exists(kwargs, "auth_mode", _internal.Credentials.AUTH_MODE.read(config_file)) + if is_client_secret: + kwargs = set_if_exists(kwargs, "auth_mode", AuthType.CLIENTSECRET.value) kwargs = set_if_exists(kwargs, "endpoint", _internal.Platform.URL.read(config_file)) kwargs = set_if_exists(kwargs, "console_endpoint", _internal.Platform.CONSOLE_ENDPOINT.read(config_file)) diff --git a/flytekit/configuration/internal.py b/flytekit/configuration/internal.py index 4f993b4e11..f34321f57b 100644 --- a/flytekit/configuration/internal.py +++ b/flytekit/configuration/internal.py @@ -85,6 +85,14 @@ class Credentials(object): password from a mounted file. """ + CLIENT_CREDENTIALS_SECRET_ENV_VAR = ConfigEntry( + LegacyConfigEntry(SECTION, "client_secret_env_var"), YamlConfigEntry("admin.clientSecretEnvVar") + ) + """ + Used for basic auth, which is automatically called during pyflyte. This will allow the Flyte engine to read the + password from a mounted environment variable. + """ + SCOPES = ConfigEntry(LegacyConfigEntry(SECTION, "scopes", list), YamlConfigEntry("admin.scopes", list)) AUTH_MODE = ConfigEntry(LegacyConfigEntry(SECTION, "auth_mode"), YamlConfigEntry("admin.authType")) diff --git a/tests/flytekit/unit/configuration/configs/creds_secret_env_var.yaml b/tests/flytekit/unit/configuration/configs/creds_secret_env_var.yaml new file mode 100644 index 0000000000..e0d4748460 --- /dev/null +++ b/tests/flytekit/unit/configuration/configs/creds_secret_env_var.yaml @@ -0,0 +1,13 @@ +admin: + # For GRPC endpoints you might want to use dns:///flyte.myexample.com + endpoint: dns:///flyte.mycorp.io + clientSecretEnvVar: FAKE_SECRET_NAME + insecure: true + clientId: propeller + scopes: + - all +storage: + connection: + access-key: minio + endpoint: http://localhost:30084 + secret-key: miniostorage diff --git a/tests/flytekit/unit/configuration/test_internal.py b/tests/flytekit/unit/configuration/test_internal.py index 97e30b5612..5c1da14a5b 100644 --- a/tests/flytekit/unit/configuration/test_internal.py +++ b/tests/flytekit/unit/configuration/test_internal.py @@ -2,7 +2,7 @@ import mock -from flytekit.configuration import PlatformConfig, get_config_file, read_file_if_exists +from flytekit.configuration import AuthType, PlatformConfig, get_config_file, read_file_if_exists from flytekit.configuration.internal import AWS, Credentials, Images @@ -45,6 +45,25 @@ def test_client_secret_location(): # Assert that secret in platform config does not contain a newline platform_cfg = PlatformConfig.auto(cfg) assert platform_cfg.client_credentials_secret == "hello" + assert platform_cfg.auth_mode == AuthType.CLIENTSECRET.value + + +@mock.patch.dict("os.environ") +def test_client_secret_env_var(): + cfg = get_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/sample.yaml")) + secret_env_var = Credentials.CLIENT_CREDENTIALS_SECRET_ENV_VAR.read(cfg) + assert secret_env_var is None + + cfg = get_config_file( + os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/creds_secret_env_var.yaml") + ) + secret_env_var = Credentials.CLIENT_CREDENTIALS_SECRET_ENV_VAR.read(cfg) + assert secret_env_var == "FAKE_SECRET_NAME" + + os.environ["FAKE_SECRET_NAME"] = "fake_secret_value" + platform_cfg = PlatformConfig.auto(cfg) + assert platform_cfg.client_credentials_secret == "fake_secret_value" + assert platform_cfg.auth_mode == AuthType.CLIENTSECRET.value def test_read_file_if_exists():