Skip to content

Commit

Permalink
Core: Improve token exchange handling when token is expired (#6489)
Browse files Browse the repository at this point in the history
  • Loading branch information
nastra committed Jan 17, 2023
1 parent 32cfb62 commit ab1a19b
Show file tree
Hide file tree
Showing 4 changed files with 491 additions and 99 deletions.
127 changes: 38 additions & 89 deletions core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@
import org.apache.iceberg.rest.responses.OAuthTokenResponse;
import org.apache.iceberg.rest.responses.UpdateNamespacePropertiesResponse;
import org.apache.iceberg.util.EnvironmentUtil;
import org.apache.iceberg.util.Pair;
import org.apache.iceberg.util.PropertyUtil;
import org.apache.iceberg.util.ThreadPools;
import org.slf4j.Logger;
Expand All @@ -89,8 +88,6 @@ public class RESTSessionCatalog extends BaseSessionCatalog
implements Configurable<Configuration>, Closeable {
private static final Logger LOG = LoggerFactory.getLogger(RESTSessionCatalog.class);
private static final String REST_METRICS_REPORTING_ENABLED = "rest-metrics-reporting-enabled";
private static final long MAX_REFRESH_WINDOW_MILLIS = 300_000; // 5 minutes
private static final long MIN_REFRESH_WAIT_MILLIS = 10;
private static final List<String> TOKEN_PREFERENCE_ORDER =
ImmutableList.of(
OAuth2Properties.ID_TOKEN_TYPE,
Expand Down Expand Up @@ -136,18 +133,18 @@ public void initialize(String name, Map<String, String> unresolved) {
// fetch auth and config to complete initialization
ConfigResponse config;
OAuthTokenResponse authResponse;
String credential = props.get(OAuth2Properties.CREDENTIAL);
// TODO: if scope can be overridden, it should be done consistently
try (RESTClient initClient = clientBuilder.apply(props)) {
Map<String, String> initHeaders =
RESTUtil.merge(configHeaders(props), OAuth2Util.authHeaders(initToken));
String credential = props.get(OAuth2Properties.CREDENTIAL);
if (credential != null && !credential.isEmpty()) {
String scope = props.getOrDefault(OAuth2Properties.SCOPE, OAuth2Properties.CATALOG_SCOPE);
authResponse = OAuth2Util.fetchToken(initClient, initHeaders, credential, scope);
config =
fetchConfig(
initClient,
RESTUtil.merge(initHeaders, OAuth2Util.authHeaders(authResponse.token())),
props);
authResponse =
OAuth2Util.fetchToken(
initClient, initHeaders, credential, OAuth2Properties.CATALOG_SCOPE);
Map<String, String> authHeaders =
RESTUtil.merge(initHeaders, OAuth2Util.authHeaders(authResponse.token()));
config = fetchConfig(initClient, authHeaders, props);
} else {
authResponse = null;
config = fetchConfig(initClient, initHeaders, props);
Expand All @@ -159,12 +156,6 @@ public void initialize(String name, Map<String, String> unresolved) {
// build the final configuration and set up the catalog's auth
Map<String, String> mergedProps = config.merge(props);
Map<String, String> baseHeaders = configHeaders(mergedProps);
this.catalogAuth = new AuthSession(baseHeaders, null, null);
if (authResponse != null) {
this.catalogAuth = newSession(authResponse, startTimeMillis, catalogAuth);
} else if (initToken != null) {
this.catalogAuth = newSession(initToken, expiresInMs(mergedProps), catalogAuth);
}

this.sessions = newSessionCache(mergedProps);
this.refreshAuthByDefault =
Expand All @@ -175,6 +166,17 @@ public void initialize(String name, Map<String, String> unresolved) {
this.client = clientBuilder.apply(mergedProps);
this.paths = ResourcePaths.forCatalogProperties(mergedProps);

this.catalogAuth = new AuthSession(baseHeaders, null, null, credential);
if (authResponse != null) {
this.catalogAuth =
AuthSession.fromTokenResponse(
client, tokenRefreshExecutor(), authResponse, startTimeMillis, catalogAuth);
} else if (initToken != null) {
this.catalogAuth =
AuthSession.fromAccessToken(
client, tokenRefreshExecutor(), initToken, expiresInMs(mergedProps), catalogAuth);
}

String ioImpl = mergedProps.get(CatalogProperties.FILE_IO_IMPL);
this.io =
CatalogUtil.loadFileIO(
Expand Down Expand Up @@ -447,34 +449,6 @@ private ScheduledExecutorService tokenRefreshExecutor() {
return refreshExecutor;
}

@SuppressWarnings("FutureReturnValueIgnored")
private void scheduleTokenRefresh(
AuthSession session, long startTimeMillis, long expiresIn, TimeUnit unit) {
// convert expiration interval to milliseconds
long expiresInMillis = unit.toMillis(expiresIn);
// how much ahead of time to start the request to allow it to complete
long refreshWindowMillis = Math.min(expiresInMillis / 10, MAX_REFRESH_WINDOW_MILLIS);
// how much time to wait before expiration
long waitIntervalMillis = expiresInMillis - refreshWindowMillis;
// how much time has already elapsed since the new token was issued
long elapsedMillis = System.currentTimeMillis() - startTimeMillis;
// how much time to actually wait
long timeToWait = Math.max(waitIntervalMillis - elapsedMillis, MIN_REFRESH_WAIT_MILLIS);

tokenRefreshExecutor()
.schedule(
() -> {
long refreshStartTime = System.currentTimeMillis();
Pair<Integer, TimeUnit> expiration = session.refresh(client);
if (expiration != null) {
scheduleTokenRefresh(
session, refreshStartTime, expiration.first(), expiration.second());
}
},
timeToWait,
TimeUnit.MILLISECONDS);
}

@Override
public void close() throws IOException {
shutdownRefreshExecutor();
Expand Down Expand Up @@ -789,66 +763,41 @@ private AuthSession newSession(
if (credentials != null) {
// use the bearer token without exchanging
if (credentials.containsKey(OAuth2Properties.TOKEN)) {
return newSession(credentials.get(OAuth2Properties.TOKEN), expiresInMs(properties), parent);
return AuthSession.fromAccessToken(
client,
tokenRefreshExecutor(),
credentials.get(OAuth2Properties.TOKEN),
expiresInMs(properties),
parent);
}

if (credentials.containsKey(OAuth2Properties.CREDENTIAL)) {
// fetch a token using the client credentials flow
return newSession(credentials.get(OAuth2Properties.CREDENTIAL), parent);
return AuthSession.fromCredential(
client,
tokenRefreshExecutor(),
credentials.get(OAuth2Properties.CREDENTIAL),
parent,
OAuth2Properties.CATALOG_SCOPE);
}

for (String tokenType : TOKEN_PREFERENCE_ORDER) {
if (credentials.containsKey(tokenType)) {
// exchange the token for an access token using the token exchange flow
return newSession(credentials.get(tokenType), tokenType, parent);
return AuthSession.fromTokenExchange(
client,
tokenRefreshExecutor(),
credentials.get(tokenType),
tokenType,
parent,
OAuth2Properties.CATALOG_SCOPE);
}
}
}

return parent;
}

private AuthSession newSession(String token, Long expirationMs, AuthSession parent) {
AuthSession session =
new AuthSession(parent.headers(), token, OAuth2Properties.ACCESS_TOKEN_TYPE);
if (expirationMs != null) {
scheduleTokenRefresh(
session, System.currentTimeMillis(), expirationMs, TimeUnit.MILLISECONDS);
}
return session;
}

private AuthSession newSession(String token, String tokenType, AuthSession parent) {
long startTimeMillis = System.currentTimeMillis();
OAuthTokenResponse response =
OAuth2Util.exchangeToken(
client,
parent.headers(),
token,
tokenType,
parent.token(),
parent.tokenType(),
OAuth2Properties.CATALOG_SCOPE);
return newSession(response, startTimeMillis, parent);
}

private AuthSession newSession(String credential, AuthSession parent) {
long startTimeMillis = System.currentTimeMillis();
OAuthTokenResponse response =
OAuth2Util.fetchToken(client, parent.headers(), credential, OAuth2Properties.CATALOG_SCOPE);
return newSession(response, startTimeMillis, parent);
}

private AuthSession newSession(
OAuthTokenResponse response, long startTimeMillis, AuthSession parent) {
AuthSession session =
new AuthSession(parent.headers(), response.token(), response.issuedTokenType());
if (response.expiresInSeconds() != null) {
scheduleTokenRefresh(session, startTimeMillis, response.expiresInSeconds(), TimeUnit.SECONDS);
}
return session;
}

private Long expiresInMs(Map<String, String> properties) {
if (refreshAuthByDefault || properties.containsKey(OAuth2Properties.TOKEN_EXPIRES_IN_MS)) {
return PropertyUtil.propertyAsLong(
Expand Down
Loading

0 comments on commit ab1a19b

Please sign in to comment.