Skip to content

Commit

Permalink
Core: Allow customizing OAuth scope
Browse files Browse the repository at this point in the history
  • Loading branch information
nastra committed Jan 23, 2023
1 parent 4491e7d commit 909d3bb
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 30 deletions.
21 changes: 5 additions & 16 deletions core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,12 @@ public void initialize(String name, Map<String, String> unresolved) {
ConfigResponse config;
OAuthTokenResponse authResponse;
String credential = props.get(OAuth2Properties.CREDENTIAL);
// TODO: if scope can be overridden, it should be done consistently
String scope = props.getOrDefault(OAuth2Properties.SCOPE, OAuth2Properties.CATALOG_SCOPE);
try (RESTClient initClient = clientBuilder.apply(props)) {
Map<String, String> initHeaders =
RESTUtil.merge(configHeaders(props), OAuth2Util.authHeaders(initToken));
if (credential != null && !credential.isEmpty()) {
authResponse =
OAuth2Util.fetchToken(
initClient, initHeaders, credential, OAuth2Properties.CATALOG_SCOPE);
authResponse = OAuth2Util.fetchToken(initClient, initHeaders, credential, scope);
Map<String, String> authHeaders =
RESTUtil.merge(initHeaders, OAuth2Util.authHeaders(authResponse.token()));
config = fetchConfig(initClient, authHeaders, props);
Expand All @@ -166,7 +164,7 @@ public void initialize(String name, Map<String, String> unresolved) {
this.client = clientBuilder.apply(mergedProps);
this.paths = ResourcePaths.forCatalogProperties(mergedProps);

this.catalogAuth = new AuthSession(baseHeaders, null, null, credential);
this.catalogAuth = new AuthSession(baseHeaders, null, null, credential, scope);
if (authResponse != null) {
this.catalogAuth =
AuthSession.fromTokenResponse(
Expand Down Expand Up @@ -774,23 +772,14 @@ private AuthSession newSession(
if (credentials.containsKey(OAuth2Properties.CREDENTIAL)) {
// fetch a token using the client credentials flow
return AuthSession.fromCredential(
client,
tokenRefreshExecutor(),
credentials.get(OAuth2Properties.CREDENTIAL),
parent,
OAuth2Properties.CATALOG_SCOPE);
client, tokenRefreshExecutor(), credentials.get(OAuth2Properties.CREDENTIAL), parent);
}

for (String tokenType : TOKEN_PREFERENCE_ORDER) {
if (credentials.containsKey(tokenType)) {
// exchange the token for an access token using the token exchange flow
return AuthSession.fromTokenExchange(
client,
tokenRefreshExecutor(),
credentials.get(tokenType),
tokenType,
parent,
OAuth2Properties.CATALOG_SCOPE);
client, tokenRefreshExecutor(), credentials.get(tokenType), tokenType, parent);
}
}
}
Expand Down
45 changes: 31 additions & 14 deletions core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -359,24 +359,30 @@ public static class AuthSession {
private String tokenType;
private Long expiresAtMillis;
private final String credential;
private final String scope;
private volatile boolean keepRefreshed = true;

/**
* @deprecated will be removed in 1.3.0; use {@link AuthSession#AuthSession(Map, String, String,
* String)} instead.
* String, String)} instead.
*/
@Deprecated
public AuthSession(Map<String, String> baseHeaders, String token, String tokenType) {
this(baseHeaders, token, tokenType, null);
this(baseHeaders, token, tokenType, null, OAuth2Properties.CATALOG_SCOPE);
}

public AuthSession(
Map<String, String> baseHeaders, String token, String tokenType, String credential) {
Map<String, String> baseHeaders,
String token,
String tokenType,
String credential,
String scope) {
this.headers = RESTUtil.merge(baseHeaders, authHeaders(token));
this.token = token;
this.tokenType = tokenType;
this.expiresAtMillis = OAuth2Util.expiresAtMillis(token);
this.credential = credential;
this.scope = scope;
}

public Map<String, String> headers() {
Expand All @@ -395,6 +401,10 @@ public Long expiresAtMillis() {
return expiresAtMillis;
}

public String scope() {
return scope;
}

public void stopRefreshing() {
this.keepRefreshed = false;
}
Expand All @@ -414,7 +424,7 @@ static void setTokenRefreshNumRetries(int retries) {
* @return A new {@link AuthSession} with empty headers.
*/
public static AuthSession empty() {
return new AuthSession(ImmutableMap.of(), null, null, null);
return new AuthSession(ImmutableMap.of(), null, null, null, OAuth2Properties.CATALOG_SCOPE);
}

/**
Expand Down Expand Up @@ -469,7 +479,7 @@ private OAuthTokenResponse refreshCurrentToken(RESTClient client) {
return refreshExpiredToken(client);
} else {
// attempt a normal refresh
return refreshToken(client, headers(), token, tokenType, OAuth2Properties.CATALOG_SCOPE);
return refreshToken(client, headers(), token, tokenType, scope);
}
}

Expand All @@ -480,7 +490,7 @@ static boolean isExpired(Long expiresAtMillis, long now) {
private OAuthTokenResponse refreshExpiredToken(RESTClient client) {
if (credential != null) {
Map<String, String> basicHeaders = RESTUtil.merge(headers(), basicAuthHeaders(credential));
return refreshToken(client, basicHeaders, token, tokenType, OAuth2Properties.CATALOG_SCOPE);
return refreshToken(client, basicHeaders, token, tokenType, scope);
}

return null;
Expand Down Expand Up @@ -566,7 +576,11 @@ public static AuthSession fromAccessToken(
AuthSession parent) {
AuthSession session =
new AuthSession(
parent.headers(), token, OAuth2Properties.ACCESS_TOKEN_TYPE, parent.credential());
parent.headers(),
token,
OAuth2Properties.ACCESS_TOKEN_TYPE,
parent.credential(),
parent.scope());

long startTimeMillis = System.currentTimeMillis();
Long expiresAtMillis = session.expiresAtMillis();
Expand Down Expand Up @@ -596,10 +610,10 @@ public static AuthSession fromCredential(
RESTClient client,
ScheduledExecutorService executor,
String credential,
AuthSession parent,
String scope) {
AuthSession parent) {
long startTimeMillis = System.currentTimeMillis();
OAuthTokenResponse response = fetchToken(client, parent.headers(), credential, scope);
OAuthTokenResponse response =
fetchToken(client, parent.headers(), credential, parent.scope());
return fromTokenResponse(client, executor, response, startTimeMillis, parent, credential);
}

Expand All @@ -622,7 +636,11 @@ private static AuthSession fromTokenResponse(
String credential) {
AuthSession session =
new AuthSession(
parent.headers(), response.token(), response.issuedTokenType(), credential);
parent.headers(),
response.token(),
response.issuedTokenType(),
credential,
parent.scope());
if (response.expiresInSeconds() != null) {
scheduleTokenRefresh(
client,
Expand All @@ -644,8 +662,7 @@ public static AuthSession fromTokenExchange(
ScheduledExecutorService executor,
String token,
String tokenType,
AuthSession parent,
String scope) {
AuthSession parent) {
long startTimeMillis = System.currentTimeMillis();
OAuthTokenResponse response =
exchangeToken(
Expand All @@ -655,7 +672,7 @@ public static AuthSession fromTokenExchange(
tokenType,
parent.token(),
parent.tokenType(),
scope);
parent.scope());
return fromTokenResponse(client, executor, response, startTimeMillis, parent);
}
}
Expand Down
91 changes: 91 additions & 0 deletions core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.rest.RESTCatalogAdapter.HTTPMethod;
import org.apache.iceberg.rest.auth.AuthSessionUtil;
import org.apache.iceberg.rest.auth.OAuth2Properties;
import org.apache.iceberg.rest.auth.OAuth2Util;
import org.apache.iceberg.rest.responses.ConfigResponse;
import org.apache.iceberg.rest.responses.ErrorResponse;
Expand Down Expand Up @@ -1414,4 +1415,94 @@ public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh() throws Exc
eq(refreshedCatalogHeader),
any());
}

@Test
public void testCatalogWithCustomTokenScope() throws Exception {
Map<String, String> emptyHeaders = ImmutableMap.of();
Map<String, String> catalogHeaders =
ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog");

RESTCatalogAdapter adapter = Mockito.spy(new RESTCatalogAdapter(backendCatalog));

Answer<OAuthTokenResponse> addOneSecondExpiration =
invocation -> {
OAuthTokenResponse response = (OAuthTokenResponse) invocation.callRealMethod();
return OAuthTokenResponse.builder()
.withToken(response.token())
.withTokenType(response.tokenType())
.withIssuedTokenType(response.issuedTokenType())
.addScopes(response.scopes())
.setExpirationInSeconds(1)
.build();
};

Mockito.doAnswer(addOneSecondExpiration)
.when(adapter)
.execute(
eq(HTTPMethod.POST),
eq("v1/oauth/tokens"),
any(),
any(),
eq(OAuthTokenResponse.class),
any(),
any());

Map<String, String> contextCredentials = ImmutableMap.of();
SessionCatalog.SessionContext context =
new SessionCatalog.SessionContext(
UUID.randomUUID().toString(), "user", contextCredentials, ImmutableMap.of());

RESTCatalog catalog = new RESTCatalog(context, (config) -> adapter);
String scope = "custom_catalog_scope";
catalog.initialize(
"prod",
ImmutableMap.of(
CatalogProperties.URI,
"ignored",
"credential",
"catalog:secret",
OAuth2Properties.SCOPE,
scope));

Thread.sleep(1_100);

// call client credentials with no initial auth
Mockito.verify(adapter)
.execute(
eq(HTTPMethod.POST),
eq("v1/oauth/tokens"),
any(),
any(),
eq(OAuthTokenResponse.class),
eq(emptyHeaders),
any());

// use the client credential token for config
Mockito.verify(adapter)
.execute(
eq(HTTPMethod.GET),
eq("v1/config"),
any(),
any(),
eq(ConfigResponse.class),
eq(catalogHeaders),
any());

// verify the token exchange uses the right scope
Map<String, String> firstRefreshRequest =
ImmutableMap.of(
"grant_type", "urn:ietf:params:oauth:grant-type:token-exchange",
"subject_token", "client-credentials-token:sub=catalog",
"subject_token_type", "urn:ietf:params:oauth:token-type:access_token",
"scope", scope);
Mockito.verify(adapter)
.execute(
eq(HTTPMethod.POST),
eq("v1/oauth/tokens"),
any(),
Mockito.argThat(firstRefreshRequest::equals),
eq(OAuthTokenResponse.class),
eq(catalogHeaders),
any());
}
}

0 comments on commit 909d3bb

Please sign in to comment.