Skip to content

Commit

Permalink
Fixed JWT principal from claims (elastic#101333)
Browse files Browse the repository at this point in the history
This changes the format of a JWT's principal before the JWT is actually validated by any JWT realm.
The JWT's principal is a convenient way to refer to a JWT that has not yet been verified by a JWT realm.
The JWT's principal is printed in the audit and regular logs (notably for auditing authn failures),
as well as the smart realm chain reordering optimization.
The JWT principal is NOT required to be identical to the JWT-authenticated user's principal,
but, in general, they should be similar.

Previously, the JWT's principal was built by individual realms in the same way the realms
built the authenticated user's principal. This had the advantage that,
in simpler JWT realms configurations (e.g. a single JWT realm in the chain),
the JWT principal and the authenticated user's principal are very similar.
However, the drawback is that, in general, the JWT principal and the user principal
can be very different (i.e. in the case where one JWT realm builds the JWT principal
and a different one builds the user principal).
Another downside is that the (unauthenticated) JWT principal depended on realm ordering,
which makes identifying the JWT from its principal dependent on the ES authn realm configuration.

This PR implements a consistent fixed logic to build the JWT principal,
which now only depends on the JWT's claims and no ES configuration.

Co-authored-by: Jake Landis jake.landis@elastic.co
  • Loading branch information
albertzaharovits authored Nov 9, 2023
1 parent 2ebd084 commit 6b72def
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 181 deletions.
29 changes: 29 additions & 0 deletions docs/changelog/101333.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
pr: 101333
summary: Fixed JWT principal from claims
area: Authorization
type: breaking
issues: []
breaking:
title: Fixed JWT principal from claims
area: Authorization
details: "This changes the format of a JWT's principal before the JWT is actually\
\ validated by any JWT realm. The JWT's principal is a convenient way to refer\
\ to a JWT that has not yet been verified by a JWT realm. The JWT's principal\
\ is printed in the audit and regular logs (notably for auditing authn failures)\
\ as well as the smart realm chain reordering optimization. The JWT principal\
\ is NOT required to be identical to the JWT-authenticated user's principal, but\
\ in general, they should be similar. Previously, the JWT's principal was built\
\ by individual realms in the same way the realms built the authenticated user's\
\ principal. This had the advantage that, in simpler JWT realms configurations\
\ (e.g. a single JWT realm in the chain), the JWT principal and the authenticated\
\ user's principal are very similar. However the drawback is that, in general,\
\ the JWT principal and the user principal can be very different (i.e. in the\
\ case where one JWT realm builds the JWT principal and a different one builds\
\ the user principal). Another downside is that the (unauthenticated) JWT principal\
\ depended on realm ordering, which makes identifying the JWT from its principal\
\ dependent on the ES authn realm configuration. This PR implements a consistent\
\ fixed logic to build the JWT principal, which only depends on the JWT's claims\
\ and no ES configuration."
impact: "Users will observe changed format and values for the `user.name` attribute\
\ of `authentication_failed` audit log events, in the JWT (failed) authn case."
notable: false
Original file line number Diff line number Diff line change
Expand Up @@ -188,53 +188,56 @@ public void testInvalidJWTDoesNotFallbackToAnonymousAccess() throws Exception {
}

public void testAnyJwtRealmWillExtractTheToken() throws ParseException {
final List<JwtRealm> jwtRealms = getJwtRealms();
final JwtRealm jwtRealm = randomFrom(jwtRealms);

final String sharedSecret = randomBoolean() ? randomAlphaOfLengthBetween(10, 20) : null;
final String iss = randomAlphaOfLengthBetween(5, 18);
final String aud = randomAlphaOfLengthBetween(5, 18);
final String sub = randomAlphaOfLengthBetween(5, 18);

// Realm 1 will extract the token because the JWT has all iss, sub, aud, principal claims.
// Their values do not match what realm 1 expects but that does not matter when extracting the token
final SignedJWT signedJWT1 = getSignedJWT(Map.of("iss", iss, "aud", aud, "sub", sub));
final ThreadContext threadContext1 = prepareThreadContext(signedJWT1, sharedSecret);
final var token1 = (JwtAuthenticationToken) jwtRealm.token(threadContext1);
final String principal1 = Strings.format("%s/%s/%s/%s", iss, aud, sub, sub);
assertJwtToken(token1, principal1, sharedSecret, signedJWT1);

// Realm 2 for extracting the token from the following JWT
// Because it does not have the sub claim but client_id, which is configured as fallback by realm 2
final String appId = randomAlphaOfLengthBetween(5, 18);
final SignedJWT signedJWT2 = getSignedJWT(Map.of("iss", iss, "aud", aud, "client_id", sub, "appid", appId));
final ThreadContext threadContext2 = prepareThreadContext(signedJWT2, sharedSecret);
final var token2 = (JwtAuthenticationToken) jwtRealm.token(threadContext2);
final String principal2 = Strings.format("%s/%s/%s/%s", iss, aud, sub, appId);
assertJwtToken(token2, principal2, sharedSecret, signedJWT2);

// Realm 3 will extract the token from the following JWT
// Because it has the oid claim which is configured as a fallback by realm 3
final String email = randomAlphaOfLengthBetween(5, 18) + "@example.com";
final SignedJWT signedJWT3 = getSignedJWT(Map.of("iss", iss, "aud", aud, "oid", sub, "email", email));
final ThreadContext threadContext3 = prepareThreadContext(signedJWT3, sharedSecret);
final var token3 = (JwtAuthenticationToken) jwtRealm.token(threadContext3);
final String principal3 = Strings.format("%s/%s/%s/%s", iss, aud, sub, email);
assertJwtToken(token3, principal3, sharedSecret, signedJWT3);

// The JWT does not match any realm's configuration, a token with generic token principal will be extracted
final SignedJWT signedJWT4 = getSignedJWT(Map.of("iss", iss, "aud", aud, "azp", sub, "email", email));
final ThreadContext threadContext4 = prepareThreadContext(signedJWT4, sharedSecret);
final var token4 = (JwtAuthenticationToken) jwtRealm.token(threadContext4);
final String principal4 = Strings.format("<unrecognized-jwt> by %s", iss);
assertJwtToken(token4, principal4, sharedSecret, signedJWT4);

// The JWT does not have an issuer, a token with generic token principal will be extracted
final SignedJWT signedJWT5 = getSignedJWT(Map.of("aud", aud, "sub", sub));
final ThreadContext threadContext5 = prepareThreadContext(signedJWT5, sharedSecret);
final var token5 = (JwtAuthenticationToken) jwtRealm.token(threadContext5);
final String principal5 = "<unrecognized-jwt>";
assertJwtToken(token5, principal5, sharedSecret, signedJWT5);
for (JwtRealm jwtRealm : getJwtRealms()) {
final String sharedSecret = randomBoolean() ? randomAlphaOfLengthBetween(10, 20) : null;
final String iss = randomAlphaOfLengthBetween(5, 18);
final List<String> aud = List.of(randomAlphaOfLengthBetween(5, 18), randomAlphaOfLengthBetween(5, 18));
final String sub = randomAlphaOfLengthBetween(5, 18);

// JWT1 has all iss, sub, aud, principal claims.
final SignedJWT signedJWT1 = getSignedJWT(Map.of("iss", iss, "aud", aud, "sub", sub));
final ThreadContext threadContext1 = prepareThreadContext(signedJWT1, sharedSecret);
final var token1 = (JwtAuthenticationToken) jwtRealm.token(threadContext1);
final String principal1 = Strings.format("'aud:%s,%s' 'iss:%s' 'sub:%s'", aud.get(0), aud.get(1), iss, sub);
assertJwtToken(token1, principal1, sharedSecret, signedJWT1);

// JWT2, JWT3, and JWT4 don't have the sub claim.
// Some realms define fallback claims for the sub claim (which themselves might not exist),
// but that is not relevant for token building (it's used for user principal assembling).
final String appId = randomAlphaOfLengthBetween(5, 18);
final SignedJWT signedJWT2 = getSignedJWT(Map.of("iss", iss, "aud", aud, "client_id", sub, "appid", appId));
final ThreadContext threadContext2 = prepareThreadContext(signedJWT2, sharedSecret);
final var token2 = (JwtAuthenticationToken) jwtRealm.token(threadContext2);
final String principal2 = Strings.format(
"'appid:%s' 'aud:%s,%s' 'client_id:%s' 'iss:%s'",
appId,
aud.get(0),
aud.get(1),
sub,
iss
);
assertJwtToken(token2, principal2, sharedSecret, signedJWT2);

final String email = randomAlphaOfLengthBetween(5, 18) + "@example.com";
final SignedJWT signedJWT3 = getSignedJWT(Map.of("iss", iss, "aud", aud, "oid", sub, "email", email));
final ThreadContext threadContext3 = prepareThreadContext(signedJWT3, sharedSecret);
final var token3 = (JwtAuthenticationToken) jwtRealm.token(threadContext3);
final String principal3 = Strings.format("'aud:%s,%s' 'email:%s' 'iss:%s' 'oid:%s'", aud.get(0), aud.get(1), email, iss, sub);
assertJwtToken(token3, principal3, sharedSecret, signedJWT3);

final SignedJWT signedJWT4 = getSignedJWT(Map.of("iss", iss, "aud", aud, "azp", sub, "email", email));
final ThreadContext threadContext4 = prepareThreadContext(signedJWT4, sharedSecret);
final var token4 = (JwtAuthenticationToken) jwtRealm.token(threadContext4);
final String principal4 = Strings.format("'aud:%s,%s' 'azp:%s' 'email:%s' 'iss:%s'", aud.get(0), aud.get(1), sub, email, iss);
assertJwtToken(token4, principal4, sharedSecret, signedJWT4);

// JWT5 does not have an issuer.
final SignedJWT signedJWT5 = getSignedJWT(Map.of("aud", aud, "sub", sub));
final ThreadContext threadContext5 = prepareThreadContext(signedJWT5, sharedSecret);
final var token5 = (JwtAuthenticationToken) jwtRealm.token(threadContext5);
final String principal5 = Strings.format("'aud:%s,%s' 'sub:%s'", aud.get(0), aud.get(1), sub);
assertJwtToken(token5, principal5, sharedSecret, signedJWT5);
}
}

public void testJwtRealmReturnsNullTokenWhenJwtCredentialIsAbsent() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,35 @@

import java.text.ParseException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.TreeSet;

/**
* An {@link AuthenticationToken} to hold JWT authentication related content.
*/
public class JwtAuthenticationToken implements AuthenticationToken {
private final String principal;
private SignedJWT signedJWT;
private final String principal;
private final byte[] userCredentialsHash;
@Nullable
private final SecureString clientAuthenticationSharedSecret;

/**
* Store a mandatory JWT and optional Shared Secret.
* @param principal The token's principal, useful as a realm order cache key
* @param signedJWT The JWT parsed from the end-user credentials
* @param userCredentialsHash The hash of the end-user credentials is used to compute the key for user cache at the realm level.
* See also {@link JwtRealm#authenticate}.
* @param clientAuthenticationSharedSecret URL-safe Shared Secret for Client authentication. Required by some JWT realms.
*/
public JwtAuthenticationToken(
String principal,
SignedJWT signedJWT,
byte[] userCredentialsHash,
@Nullable final SecureString clientAuthenticationSharedSecret
) {
this.principal = Objects.requireNonNull(principal);
this.signedJWT = Objects.requireNonNull(signedJWT);
this.principal = buildTokenPrincipal();
this.userCredentialsHash = Objects.requireNonNull(userCredentialsHash);

if ((clientAuthenticationSharedSecret != null) && (clientAuthenticationSharedSecret.isEmpty())) {
throw new IllegalArgumentException("Client shared secret must be non-empty");
}
Expand All @@ -70,7 +69,7 @@ public JWTClaimsSet getJWTClaimsSet() {
return signedJWT.getJWTClaimsSet();
} catch (ParseException e) {
assert false : "The JWT claims set should have already been successfully parsed before building the JWT authentication token";
throw new IllegalArgumentException(e);
throw new IllegalStateException(e);
}
}

Expand All @@ -95,4 +94,47 @@ public void clearCredentials() {
public String toString() {
return JwtAuthenticationToken.class.getSimpleName() + "=" + this.principal;
}

private String buildTokenPrincipal() {
JWTClaimsSet jwtClaimsSet = getJWTClaimsSet();
StringBuilder principalBuilder = new StringBuilder();
claimsLoop: for (String claimName : new TreeSet<>(jwtClaimsSet.getClaims().keySet())) {
Object claimValue = jwtClaimsSet.getClaim(claimName);
if (claimValue == null) {
continue;
}
// only use String or String[] claim values to assemble the principal
if (claimValue instanceof String) {
if (principalBuilder.isEmpty() == false) {
principalBuilder.append(' ');
}
principalBuilder.append('\'').append(claimName).append(':').append((String) claimValue).append('\'');
} else if (claimValue instanceof List<?>) {
List<?> claimValuesList = (List<?>) claimValue;
if (claimValuesList.isEmpty()) {
continue;
}
for (Object claimValueElem : claimValuesList) {
if (claimValueElem instanceof String == false) {
continue claimsLoop;
}
}
if (principalBuilder.isEmpty() == false) {
principalBuilder.append(' ');
}
principalBuilder.append('\'').append(claimName).append(':');
for (int i = 0; i < claimValuesList.size(); i++) {
if (i > 0) {
principalBuilder.append(',');
}
principalBuilder.append((String) claimValuesList.get(i));
}
principalBuilder.append('\'');
}
}
if (principalBuilder.isEmpty()) {
return "<unrecognized JWT token>";
}
return principalBuilder.toString();
}
}
Loading

0 comments on commit 6b72def

Please sign in to comment.