Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make optional oauth configurable #486

Merged
merged 2 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mkdocs/docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ catalog:
| credential | t-1234:secret | Credential to use for OAuth2 credential flow when initializing the catalog |
| token | FEW23.DFSDF.FSDF | Bearer token value to use for `Authorization` header |
| scope | openid offline corpds:ds:profile | Desired scope of the requested security token (default : catalog) |
| resource | rest_catalog.iceberg.com | URI for the target resource or service |
| audience | rest_catalog | Logical name of target resource or service |
| rest.sigv4-enabled | true | Sign requests to the REST Server using AWS SigV4 protocol |
| rest.signing-region | us-east-1 | The region to use when SigV4 signing a request |
| rest.signing-name | execute-api | The service signing name to use when SigV4 signing a request |
Expand Down
18 changes: 15 additions & 3 deletions pyiceberg/catalog/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ class Endpoints:
CREDENTIAL = "credential"
GRANT_TYPE = "grant_type"
SCOPE = "scope"
AUDIENCE = "audience"
RESOURCE = "resource"
Fokko marked this conversation as resolved.
Show resolved Hide resolved
TOKEN_EXCHANGE = "urn:ietf:params:oauth:grant-type:token-exchange"
SEMICOLON = ":"
KEY = "key"
Expand Down Expand Up @@ -289,16 +291,26 @@ def auth_url(self) -> str:
else:
return self.url(Endpoints.get_token, prefixed=False)

def _extract_optional_oauth_params(self) -> Dict[str, str]:
optional_oauth_param = {SCOPE: self.properties.get(SCOPE) or CATALOG_SCOPE}
set_of_optional_params = {AUDIENCE, RESOURCE}
Fokko marked this conversation as resolved.
Show resolved Hide resolved
for param in set_of_optional_params:
if param_value := self.properties.get(param):
optional_oauth_param[param] = param_value

return optional_oauth_param

def _fetch_access_token(self, session: Session, credential: str) -> str:
if SEMICOLON in credential:
client_id, client_secret = credential.split(SEMICOLON)
else:
client_id, client_secret = None, credential

# take scope from properties or use default CATALOG_SCOPE
scope = self.properties.get(SCOPE) or CATALOG_SCOPE
data = {GRANT_TYPE: CLIENT_CREDENTIALS, CLIENT_ID: client_id, CLIENT_SECRET: client_secret}

optional_oauth_params = self._extract_optional_oauth_params()
data.update(optional_oauth_params)

data = {GRANT_TYPE: CLIENT_CREDENTIALS, CLIENT_ID: client_id, CLIENT_SECRET: client_secret, SCOPE: scope}
response = session.post(
url=self.auth_url, data=data, headers={**session.headers, "Content-type": "application/x-www-form-urlencoded"}
)
Expand Down
45 changes: 45 additions & 0 deletions tests/catalog/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
TEST_AUTH_URL = "https://auth-endpoint/"
TEST_TOKEN = "some_jwt_token"
TEST_SCOPE = "openid_offline_corpds_ds_profile"
TEST_AUDIENCE = "test_audience"
TEST_RESOURCE = "test_resource"

TEST_HEADERS = {
"Content-type": "application/json",
"X-Client-Version": "0.14.1",
Expand Down Expand Up @@ -137,6 +140,48 @@ def test_token_200_without_optional_fields(rest_mock: Mocker) -> None:
)


def test_token_with_optional_oauth_params(rest_mock: Mocker) -> None:
mock_request = rest_mock.post(
f"{TEST_URI}v1/oauth/tokens",
json={
"access_token": TEST_TOKEN,
"token_type": "Bearer",
"expires_in": 86400,
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
},
status_code=200,
request_headers=OAUTH_TEST_HEADERS,
)
assert (
RestCatalog(
"rest", uri=TEST_URI, credential=TEST_CREDENTIALS, audience=TEST_AUDIENCE, resource=TEST_RESOURCE
)._session.headers["Authorization"]
== f"Bearer {TEST_TOKEN}"
)
assert TEST_AUDIENCE in mock_request.last_request.text
assert TEST_RESOURCE in mock_request.last_request.text


def test_token_with_optional_oauth_params_as_empty(rest_mock: Mocker) -> None:
mock_request = rest_mock.post(
f"{TEST_URI}v1/oauth/tokens",
json={
"access_token": TEST_TOKEN,
"token_type": "Bearer",
"expires_in": 86400,
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
},
status_code=200,
request_headers=OAUTH_TEST_HEADERS,
)
assert (
RestCatalog("rest", uri=TEST_URI, credential=TEST_CREDENTIALS, audience="", resource="")._session.headers["Authorization"]
== f"Bearer {TEST_TOKEN}"
)
assert TEST_AUDIENCE not in mock_request.last_request.text
assert TEST_RESOURCE not in mock_request.last_request.text


def test_token_with_default_scope(rest_mock: Mocker) -> None:
mock_request = rest_mock.post(
f"{TEST_URI}v1/oauth/tokens",
Expand Down