Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Core: Allow customizing OAuth scope #6616

Merged
merged 1 commit into from
Jan 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 5 additions & 16 deletions core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,12 @@ public void initialize(String name, Map<String, String> unresolved) {
ConfigResponse config;
OAuthTokenResponse authResponse;
String credential = props.get(OAuth2Properties.CREDENTIAL);
// TODO: if scope can be overridden, it should be done consistently
String scope = props.getOrDefault(OAuth2Properties.SCOPE, OAuth2Properties.CATALOG_SCOPE);
nastra marked this conversation as resolved.
Show resolved Hide resolved
try (RESTClient initClient = clientBuilder.apply(props)) {
Map<String, String> initHeaders =
RESTUtil.merge(configHeaders(props), OAuth2Util.authHeaders(initToken));
if (credential != null && !credential.isEmpty()) {
authResponse =
OAuth2Util.fetchToken(
initClient, initHeaders, credential, OAuth2Properties.CATALOG_SCOPE);
authResponse = OAuth2Util.fetchToken(initClient, initHeaders, credential, scope);
Map<String, String> authHeaders =
RESTUtil.merge(initHeaders, OAuth2Util.authHeaders(authResponse.token()));
config = fetchConfig(initClient, authHeaders, props);
Expand All @@ -166,7 +164,7 @@ public void initialize(String name, Map<String, String> unresolved) {
this.client = clientBuilder.apply(mergedProps);
this.paths = ResourcePaths.forCatalogProperties(mergedProps);

this.catalogAuth = new AuthSession(baseHeaders, null, null, credential);
this.catalogAuth = new AuthSession(baseHeaders, null, null, credential, scope);
if (authResponse != null) {
this.catalogAuth =
AuthSession.fromTokenResponse(
Expand Down Expand Up @@ -774,23 +772,14 @@ private AuthSession newSession(
if (credentials.containsKey(OAuth2Properties.CREDENTIAL)) {
// fetch a token using the client credentials flow
return AuthSession.fromCredential(
client,
tokenRefreshExecutor(),
credentials.get(OAuth2Properties.CREDENTIAL),
parent,
OAuth2Properties.CATALOG_SCOPE);
client, tokenRefreshExecutor(), credentials.get(OAuth2Properties.CREDENTIAL), parent);
}

for (String tokenType : TOKEN_PREFERENCE_ORDER) {
if (credentials.containsKey(tokenType)) {
// exchange the token for an access token using the token exchange flow
return AuthSession.fromTokenExchange(
client,
tokenRefreshExecutor(),
credentials.get(tokenType),
tokenType,
parent,
OAuth2Properties.CATALOG_SCOPE);
client, tokenRefreshExecutor(), credentials.get(tokenType), tokenType, parent);
}
}
}
Expand Down
45 changes: 31 additions & 14 deletions core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -359,24 +359,30 @@ public static class AuthSession {
private String tokenType;
private Long expiresAtMillis;
private final String credential;
private final String scope;
private volatile boolean keepRefreshed = true;

/**
* @deprecated will be removed in 1.3.0; use {@link AuthSession#AuthSession(Map, String, String,
* String)} instead.
* String, String)} instead.
*/
@Deprecated
public AuthSession(Map<String, String> baseHeaders, String token, String tokenType) {
this(baseHeaders, token, tokenType, null);
this(baseHeaders, token, tokenType, null, OAuth2Properties.CATALOG_SCOPE);
}

public AuthSession(
Map<String, String> baseHeaders, String token, String tokenType, String credential) {
Map<String, String> baseHeaders,
String token,
String tokenType,
String credential,
String scope) {
this.headers = RESTUtil.merge(baseHeaders, authHeaders(token));
this.token = token;
this.tokenType = tokenType;
this.expiresAtMillis = OAuth2Util.expiresAtMillis(token);
this.credential = credential;
this.scope = scope;
}

public Map<String, String> headers() {
Expand All @@ -395,6 +401,10 @@ public Long expiresAtMillis() {
return expiresAtMillis;
}

public String scope() {
return scope;
}

public void stopRefreshing() {
this.keepRefreshed = false;
}
Expand All @@ -414,7 +424,7 @@ static void setTokenRefreshNumRetries(int retries) {
* @return A new {@link AuthSession} with empty headers.
*/
public static AuthSession empty() {
return new AuthSession(ImmutableMap.of(), null, null, null);
return new AuthSession(ImmutableMap.of(), null, null, null, OAuth2Properties.CATALOG_SCOPE);
}

/**
Expand Down Expand Up @@ -469,7 +479,7 @@ private OAuthTokenResponse refreshCurrentToken(RESTClient client) {
return refreshExpiredToken(client);
} else {
// attempt a normal refresh
return refreshToken(client, headers(), token, tokenType, OAuth2Properties.CATALOG_SCOPE);
return refreshToken(client, headers(), token, tokenType, scope);
}
}

Expand All @@ -480,7 +490,7 @@ static boolean isExpired(Long expiresAtMillis, long now) {
private OAuthTokenResponse refreshExpiredToken(RESTClient client) {
if (credential != null) {
Map<String, String> basicHeaders = RESTUtil.merge(headers(), basicAuthHeaders(credential));
return refreshToken(client, basicHeaders, token, tokenType, OAuth2Properties.CATALOG_SCOPE);
return refreshToken(client, basicHeaders, token, tokenType, scope);
}

return null;
Expand Down Expand Up @@ -566,7 +576,11 @@ public static AuthSession fromAccessToken(
AuthSession parent) {
AuthSession session =
new AuthSession(
parent.headers(), token, OAuth2Properties.ACCESS_TOKEN_TYPE, parent.credential());
parent.headers(),
token,
OAuth2Properties.ACCESS_TOKEN_TYPE,
parent.credential(),
parent.scope());

long startTimeMillis = System.currentTimeMillis();
Long expiresAtMillis = session.expiresAtMillis();
Expand Down Expand Up @@ -596,10 +610,10 @@ public static AuthSession fromCredential(
RESTClient client,
ScheduledExecutorService executor,
String credential,
AuthSession parent,
String scope) {
AuthSession parent) {
long startTimeMillis = System.currentTimeMillis();
OAuthTokenResponse response = fetchToken(client, parent.headers(), credential, scope);
OAuthTokenResponse response =
fetchToken(client, parent.headers(), credential, parent.scope());
return fromTokenResponse(client, executor, response, startTimeMillis, parent, credential);
}

Expand All @@ -622,7 +636,11 @@ private static AuthSession fromTokenResponse(
String credential) {
AuthSession session =
new AuthSession(
parent.headers(), response.token(), response.issuedTokenType(), credential);
parent.headers(),
response.token(),
response.issuedTokenType(),
credential,
parent.scope());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When this is called in fromTokenExchange, the scope passed there is not passed in here. I think that this method needs to have a scope argument. Then the fromTokenResponse above can pass parent.scope() and fromTokenExchange can pass the scope it was passed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've simplified this to actually and removed the scope parameter from fromTokenExchange/fromCredential as I think we need to take the scope from the parent in those cases

if (response.expiresInSeconds() != null) {
scheduleTokenRefresh(
client,
Expand All @@ -644,8 +662,7 @@ public static AuthSession fromTokenExchange(
ScheduledExecutorService executor,
String token,
String tokenType,
AuthSession parent,
String scope) {
AuthSession parent) {
long startTimeMillis = System.currentTimeMillis();
OAuthTokenResponse response =
exchangeToken(
Expand All @@ -655,7 +672,7 @@ public static AuthSession fromTokenExchange(
tokenType,
parent.token(),
parent.tokenType(),
scope);
parent.scope());
return fromTokenResponse(client, executor, response, startTimeMillis, parent);
}
}
Expand Down
91 changes: 91 additions & 0 deletions core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.rest.RESTCatalogAdapter.HTTPMethod;
import org.apache.iceberg.rest.auth.AuthSessionUtil;
import org.apache.iceberg.rest.auth.OAuth2Properties;
import org.apache.iceberg.rest.auth.OAuth2Util;
import org.apache.iceberg.rest.responses.ConfigResponse;
import org.apache.iceberg.rest.responses.ErrorResponse;
Expand Down Expand Up @@ -1414,4 +1415,94 @@ public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh() throws Exc
eq(refreshedCatalogHeader),
any());
}

@Test
public void testCatalogWithCustomTokenScope() throws Exception {
Map<String, String> emptyHeaders = ImmutableMap.of();
Map<String, String> catalogHeaders =
ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog");

RESTCatalogAdapter adapter = Mockito.spy(new RESTCatalogAdapter(backendCatalog));

Answer<OAuthTokenResponse> addOneSecondExpiration =
invocation -> {
OAuthTokenResponse response = (OAuthTokenResponse) invocation.callRealMethod();
return OAuthTokenResponse.builder()
.withToken(response.token())
.withTokenType(response.tokenType())
.withIssuedTokenType(response.issuedTokenType())
.addScopes(response.scopes())
.setExpirationInSeconds(1)
.build();
};

Mockito.doAnswer(addOneSecondExpiration)
.when(adapter)
.execute(
eq(HTTPMethod.POST),
eq("v1/oauth/tokens"),
any(),
any(),
eq(OAuthTokenResponse.class),
any(),
any());

Map<String, String> contextCredentials = ImmutableMap.of();
SessionCatalog.SessionContext context =
new SessionCatalog.SessionContext(
UUID.randomUUID().toString(), "user", contextCredentials, ImmutableMap.of());

RESTCatalog catalog = new RESTCatalog(context, (config) -> adapter);
String scope = "custom_catalog_scope";
catalog.initialize(
"prod",
ImmutableMap.of(
CatalogProperties.URI,
"ignored",
"credential",
"catalog:secret",
OAuth2Properties.SCOPE,
scope));

Thread.sleep(1_100);

// call client credentials with no initial auth
Mockito.verify(adapter)
.execute(
eq(HTTPMethod.POST),
eq("v1/oauth/tokens"),
any(),
any(),
eq(OAuthTokenResponse.class),
eq(emptyHeaders),
any());

// use the client credential token for config
Mockito.verify(adapter)
.execute(
eq(HTTPMethod.GET),
eq("v1/config"),
any(),
any(),
eq(ConfigResponse.class),
eq(catalogHeaders),
any());

// verify the token exchange uses the right scope
Map<String, String> firstRefreshRequest =
ImmutableMap.of(
"grant_type", "urn:ietf:params:oauth:grant-type:token-exchange",
"subject_token", "client-credentials-token:sub=catalog",
"subject_token_type", "urn:ietf:params:oauth:token-type:access_token",
"scope", scope);
Mockito.verify(adapter)
.execute(
eq(HTTPMethod.POST),
eq("v1/oauth/tokens"),
any(),
Mockito.argThat(firstRefreshRequest::equals),
eq(OAuthTokenResponse.class),
eq(catalogHeaders),
any());
}
}