Skip to content

Commit

Permalink
Core: Allow customizing OAuth scope
Browse files Browse the repository at this point in the history
  • Loading branch information
nastra committed Jan 18, 2023
1 parent 35151fe commit 64c1da5
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,12 @@ public void initialize(String name, Map<String, String> unresolved) {
ConfigResponse config;
OAuthTokenResponse authResponse;
String credential = props.get(OAuth2Properties.CREDENTIAL);
// TODO: if scope can be overridden, it should be done consistently
String scope = props.getOrDefault(OAuth2Properties.SCOPE, OAuth2Properties.CATALOG_SCOPE);
try (RESTClient initClient = clientBuilder.apply(props)) {
Map<String, String> initHeaders =
RESTUtil.merge(configHeaders(props), OAuth2Util.authHeaders(initToken));
if (credential != null && !credential.isEmpty()) {
authResponse =
OAuth2Util.fetchToken(
initClient, initHeaders, credential, OAuth2Properties.CATALOG_SCOPE);
authResponse = OAuth2Util.fetchToken(initClient, initHeaders, credential, scope);
Map<String, String> authHeaders =
RESTUtil.merge(initHeaders, OAuth2Util.authHeaders(authResponse.token()));
config = fetchConfig(initClient, authHeaders, props);
Expand All @@ -166,7 +164,7 @@ public void initialize(String name, Map<String, String> unresolved) {
this.client = clientBuilder.apply(mergedProps);
this.paths = ResourcePaths.forCatalogProperties(mergedProps);

this.catalogAuth = new AuthSession(baseHeaders, null, null, credential);
this.catalogAuth = new AuthSession(baseHeaders, null, null, credential, scope);
if (authResponse != null) {
this.catalogAuth =
AuthSession.fromTokenResponse(
Expand Down
32 changes: 25 additions & 7 deletions core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -357,19 +357,25 @@ public static class AuthSession {
private String tokenType;
private Long expiresAtMillis;
private final String credential;
private final String scope;
private volatile boolean keepRefreshed = true;

public AuthSession(Map<String, String> baseHeaders, String token, String tokenType) {
this(baseHeaders, token, tokenType, null);
this(baseHeaders, token, tokenType, null, OAuth2Properties.CATALOG_SCOPE);
}

public AuthSession(
Map<String, String> baseHeaders, String token, String tokenType, String credential) {
Map<String, String> baseHeaders,
String token,
String tokenType,
String credential,
String scope) {
this.headers = RESTUtil.merge(baseHeaders, authHeaders(token));
this.token = token;
this.tokenType = tokenType;
this.expiresAtMillis = OAuth2Util.expiresAtMillis(token);
this.credential = credential;
this.scope = scope;
}

public Map<String, String> headers() {
Expand All @@ -388,6 +394,10 @@ public Long expiresAtMillis() {
return expiresAtMillis;
}

public String scope() {
return scope;
}

public void stopRefreshing() {
this.keepRefreshed = false;
}
Expand All @@ -402,7 +412,7 @@ public String credential() {
* @return A new {@link AuthSession} with empty headers.
*/
public static AuthSession empty() {
return new AuthSession(ImmutableMap.of(), null, null, null);
return new AuthSession(ImmutableMap.of(), null, null, null, OAuth2Properties.CATALOG_SCOPE);
}

/**
Expand Down Expand Up @@ -457,7 +467,7 @@ private OAuthTokenResponse refreshCurrentToken(RESTClient client) {
return refreshExpiredToken(client);
} else {
// attempt a normal refresh
return refreshToken(client, headers(), token, tokenType, OAuth2Properties.CATALOG_SCOPE);
return refreshToken(client, headers(), token, tokenType, scope);
}
}

Expand All @@ -468,7 +478,7 @@ static boolean isExpired(Long expiresAtMillis, long now) {
private OAuthTokenResponse refreshExpiredToken(RESTClient client) {
if (credential != null) {
Map<String, String> basicHeaders = RESTUtil.merge(headers(), basicAuthHeaders(credential));
return refreshToken(client, basicHeaders, token, tokenType, OAuth2Properties.CATALOG_SCOPE);
return refreshToken(client, basicHeaders, token, tokenType, scope);
}

return null;
Expand Down Expand Up @@ -554,7 +564,11 @@ public static AuthSession fromAccessToken(
AuthSession parent) {
AuthSession session =
new AuthSession(
parent.headers(), token, OAuth2Properties.ACCESS_TOKEN_TYPE, parent.credential());
parent.headers(),
token,
OAuth2Properties.ACCESS_TOKEN_TYPE,
parent.credential(),
parent.scope());

long startTimeMillis = System.currentTimeMillis();
Long expiresAtMillis = session.expiresAtMillis();
Expand Down Expand Up @@ -610,7 +624,11 @@ private static AuthSession fromTokenResponse(
String credential) {
AuthSession session =
new AuthSession(
parent.headers(), response.token(), response.issuedTokenType(), credential);
parent.headers(),
response.token(),
response.issuedTokenType(),
credential,
parent.scope());
if (response.expiresInSeconds() != null) {
scheduleTokenRefresh(
client,
Expand Down
15 changes: 12 additions & 3 deletions core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import org.apache.iceberg.metrics.MetricsReporter;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.rest.RESTCatalogAdapter.HTTPMethod;
import org.apache.iceberg.rest.auth.OAuth2Properties;
import org.apache.iceberg.rest.auth.OAuth2Util;
import org.apache.iceberg.rest.responses.ConfigResponse;
import org.apache.iceberg.rest.responses.ErrorResponse;
Expand Down Expand Up @@ -918,8 +919,16 @@ public void testCatalogTokenRefresh() throws Exception {
UUID.randomUUID().toString(), "user", contextCredentials, ImmutableMap.of());

RESTCatalog catalog = new RESTCatalog(context, (config) -> adapter);
String scope = "custom_catalog_scope";
catalog.initialize(
"prod", ImmutableMap.of(CatalogProperties.URI, "ignored", "credential", "catalog:secret"));
"prod",
ImmutableMap.of(
CatalogProperties.URI,
"ignored",
"credential",
"catalog:secret",
OAuth2Properties.SCOPE,
scope));

Thread.sleep(3_000); // sleep until after 2 refresh calls

Expand Down Expand Up @@ -951,7 +960,7 @@ public void testCatalogTokenRefresh() throws Exception {
"grant_type", "urn:ietf:params:oauth:grant-type:token-exchange",
"subject_token", "client-credentials-token:sub=catalog",
"subject_token_type", "urn:ietf:params:oauth:token-type:access_token",
"scope", "catalog");
"scope", scope);
Mockito.verify(adapter)
.execute(
eq(HTTPMethod.POST),
Expand All @@ -968,7 +977,7 @@ public void testCatalogTokenRefresh() throws Exception {
"grant_type", "urn:ietf:params:oauth:grant-type:token-exchange",
"subject_token", "token-exchange-token:sub=client-credentials-token:sub=catalog",
"subject_token_type", "urn:ietf:params:oauth:token-type:access_token",
"scope", "catalog");
"scope", scope);
Map<String, String> secondRefreshHeaders =
ImmutableMap.of(
"Authorization",
Expand Down

0 comments on commit 64c1da5

Please sign in to comment.