diff --git a/core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java b/core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java index f79dc9845cb7..b7ed1f864872 100644 --- a/core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java +++ b/core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java @@ -134,14 +134,12 @@ public void initialize(String name, Map unresolved) { ConfigResponse config; OAuthTokenResponse authResponse; String credential = props.get(OAuth2Properties.CREDENTIAL); - // TODO: if scope can be overridden, it should be done consistently + String scope = props.getOrDefault(OAuth2Properties.SCOPE, OAuth2Properties.CATALOG_SCOPE); try (RESTClient initClient = clientBuilder.apply(props)) { Map initHeaders = RESTUtil.merge(configHeaders(props), OAuth2Util.authHeaders(initToken)); if (credential != null && !credential.isEmpty()) { - authResponse = - OAuth2Util.fetchToken( - initClient, initHeaders, credential, OAuth2Properties.CATALOG_SCOPE); + authResponse = OAuth2Util.fetchToken(initClient, initHeaders, credential, scope); Map authHeaders = RESTUtil.merge(initHeaders, OAuth2Util.authHeaders(authResponse.token())); config = fetchConfig(initClient, authHeaders, props); @@ -166,7 +164,7 @@ public void initialize(String name, Map unresolved) { this.client = clientBuilder.apply(mergedProps); this.paths = ResourcePaths.forCatalogProperties(mergedProps); - this.catalogAuth = new AuthSession(baseHeaders, null, null, credential); + this.catalogAuth = new AuthSession(baseHeaders, null, null, credential, scope); if (authResponse != null) { this.catalogAuth = AuthSession.fromTokenResponse( @@ -774,23 +772,14 @@ private AuthSession newSession( if (credentials.containsKey(OAuth2Properties.CREDENTIAL)) { // fetch a token using the client credentials flow return AuthSession.fromCredential( - client, - tokenRefreshExecutor(), - credentials.get(OAuth2Properties.CREDENTIAL), - parent, - OAuth2Properties.CATALOG_SCOPE); + client, tokenRefreshExecutor(), credentials.get(OAuth2Properties.CREDENTIAL), parent); } for (String tokenType : TOKEN_PREFERENCE_ORDER) { if (credentials.containsKey(tokenType)) { // exchange the token for an access token using the token exchange flow return AuthSession.fromTokenExchange( - client, - tokenRefreshExecutor(), - credentials.get(tokenType), - tokenType, - parent, - OAuth2Properties.CATALOG_SCOPE); + client, tokenRefreshExecutor(), credentials.get(tokenType), tokenType, parent); } } } diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java b/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java index f2f07c3c910d..71cbbcb49da0 100644 --- a/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java +++ b/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java @@ -359,24 +359,30 @@ public static class AuthSession { private String tokenType; private Long expiresAtMillis; private final String credential; + private final String scope; private volatile boolean keepRefreshed = true; /** * @deprecated will be removed in 1.3.0; use {@link AuthSession#AuthSession(Map, String, String, - * String)} instead. + * String, String)} instead. */ @Deprecated public AuthSession(Map baseHeaders, String token, String tokenType) { - this(baseHeaders, token, tokenType, null); + this(baseHeaders, token, tokenType, null, OAuth2Properties.CATALOG_SCOPE); } public AuthSession( - Map baseHeaders, String token, String tokenType, String credential) { + Map baseHeaders, + String token, + String tokenType, + String credential, + String scope) { this.headers = RESTUtil.merge(baseHeaders, authHeaders(token)); this.token = token; this.tokenType = tokenType; this.expiresAtMillis = OAuth2Util.expiresAtMillis(token); this.credential = credential; + this.scope = scope; } public Map headers() { @@ -395,6 +401,10 @@ public Long expiresAtMillis() { return expiresAtMillis; } + public String scope() { + return scope; + } + public void stopRefreshing() { this.keepRefreshed = false; } @@ -414,7 +424,7 @@ static void setTokenRefreshNumRetries(int retries) { * @return A new {@link AuthSession} with empty headers. */ public static AuthSession empty() { - return new AuthSession(ImmutableMap.of(), null, null, null); + return new AuthSession(ImmutableMap.of(), null, null, null, OAuth2Properties.CATALOG_SCOPE); } /** @@ -469,7 +479,7 @@ private OAuthTokenResponse refreshCurrentToken(RESTClient client) { return refreshExpiredToken(client); } else { // attempt a normal refresh - return refreshToken(client, headers(), token, tokenType, OAuth2Properties.CATALOG_SCOPE); + return refreshToken(client, headers(), token, tokenType, scope); } } @@ -480,7 +490,7 @@ static boolean isExpired(Long expiresAtMillis, long now) { private OAuthTokenResponse refreshExpiredToken(RESTClient client) { if (credential != null) { Map basicHeaders = RESTUtil.merge(headers(), basicAuthHeaders(credential)); - return refreshToken(client, basicHeaders, token, tokenType, OAuth2Properties.CATALOG_SCOPE); + return refreshToken(client, basicHeaders, token, tokenType, scope); } return null; @@ -566,7 +576,11 @@ public static AuthSession fromAccessToken( AuthSession parent) { AuthSession session = new AuthSession( - parent.headers(), token, OAuth2Properties.ACCESS_TOKEN_TYPE, parent.credential()); + parent.headers(), + token, + OAuth2Properties.ACCESS_TOKEN_TYPE, + parent.credential(), + parent.scope()); long startTimeMillis = System.currentTimeMillis(); Long expiresAtMillis = session.expiresAtMillis(); @@ -596,10 +610,10 @@ public static AuthSession fromCredential( RESTClient client, ScheduledExecutorService executor, String credential, - AuthSession parent, - String scope) { + AuthSession parent) { long startTimeMillis = System.currentTimeMillis(); - OAuthTokenResponse response = fetchToken(client, parent.headers(), credential, scope); + OAuthTokenResponse response = + fetchToken(client, parent.headers(), credential, parent.scope()); return fromTokenResponse(client, executor, response, startTimeMillis, parent, credential); } @@ -622,7 +636,11 @@ private static AuthSession fromTokenResponse( String credential) { AuthSession session = new AuthSession( - parent.headers(), response.token(), response.issuedTokenType(), credential); + parent.headers(), + response.token(), + response.issuedTokenType(), + credential, + parent.scope()); if (response.expiresInSeconds() != null) { scheduleTokenRefresh( client, @@ -644,8 +662,7 @@ public static AuthSession fromTokenExchange( ScheduledExecutorService executor, String token, String tokenType, - AuthSession parent, - String scope) { + AuthSession parent) { long startTimeMillis = System.currentTimeMillis(); OAuthTokenResponse response = exchangeToken( @@ -655,7 +672,7 @@ public static AuthSession fromTokenExchange( tokenType, parent.token(), parent.tokenType(), - scope); + parent.scope()); return fromTokenResponse(client, executor, response, startTimeMillis, parent); } } diff --git a/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java b/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java index 45af20814a2c..0578bff29769 100644 --- a/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java +++ b/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java @@ -51,6 +51,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.rest.RESTCatalogAdapter.HTTPMethod; import org.apache.iceberg.rest.auth.AuthSessionUtil; +import org.apache.iceberg.rest.auth.OAuth2Properties; import org.apache.iceberg.rest.auth.OAuth2Util; import org.apache.iceberg.rest.responses.ConfigResponse; import org.apache.iceberg.rest.responses.ErrorResponse; @@ -1414,4 +1415,94 @@ public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh() throws Exc eq(refreshedCatalogHeader), any()); } + + @Test + public void testCatalogWithCustomTokenScope() throws Exception { + Map emptyHeaders = ImmutableMap.of(); + Map catalogHeaders = + ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); + + RESTCatalogAdapter adapter = Mockito.spy(new RESTCatalogAdapter(backendCatalog)); + + Answer addOneSecondExpiration = + invocation -> { + OAuthTokenResponse response = (OAuthTokenResponse) invocation.callRealMethod(); + return OAuthTokenResponse.builder() + .withToken(response.token()) + .withTokenType(response.tokenType()) + .withIssuedTokenType(response.issuedTokenType()) + .addScopes(response.scopes()) + .setExpirationInSeconds(1) + .build(); + }; + + Mockito.doAnswer(addOneSecondExpiration) + .when(adapter) + .execute( + eq(HTTPMethod.POST), + eq("v1/oauth/tokens"), + any(), + any(), + eq(OAuthTokenResponse.class), + any(), + any()); + + Map contextCredentials = ImmutableMap.of(); + SessionCatalog.SessionContext context = + new SessionCatalog.SessionContext( + UUID.randomUUID().toString(), "user", contextCredentials, ImmutableMap.of()); + + RESTCatalog catalog = new RESTCatalog(context, (config) -> adapter); + String scope = "custom_catalog_scope"; + catalog.initialize( + "prod", + ImmutableMap.of( + CatalogProperties.URI, + "ignored", + "credential", + "catalog:secret", + OAuth2Properties.SCOPE, + scope)); + + Thread.sleep(1_100); + + // call client credentials with no initial auth + Mockito.verify(adapter) + .execute( + eq(HTTPMethod.POST), + eq("v1/oauth/tokens"), + any(), + any(), + eq(OAuthTokenResponse.class), + eq(emptyHeaders), + any()); + + // use the client credential token for config + Mockito.verify(adapter) + .execute( + eq(HTTPMethod.GET), + eq("v1/config"), + any(), + any(), + eq(ConfigResponse.class), + eq(catalogHeaders), + any()); + + // verify the token exchange uses the right scope + Map firstRefreshRequest = + ImmutableMap.of( + "grant_type", "urn:ietf:params:oauth:grant-type:token-exchange", + "subject_token", "client-credentials-token:sub=catalog", + "subject_token_type", "urn:ietf:params:oauth:token-type:access_token", + "scope", scope); + Mockito.verify(adapter) + .execute( + eq(HTTPMethod.POST), + eq("v1/oauth/tokens"), + any(), + Mockito.argThat(firstRefreshRequest::equals), + eq(OAuthTokenResponse.class), + eq(catalogHeaders), + any()); + } }