Skip to content

Commit

Permalink
Update authc failure headers on license change (#61734) (#62442)
Browse files Browse the repository at this point in the history
Backport of #61734
  • Loading branch information
BigPandaToo authored Sep 16, 2020
1 parent 8d89a28 commit 167172a
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
* response headers like 'WWW-Authenticate'
*/
public class DefaultAuthenticationFailureHandler implements AuthenticationFailureHandler {
private final Map<String, List<String>> defaultFailureResponseHeaders;
private volatile Map<String, List<String>> defaultFailureResponseHeaders;

/**
* Constructs default authentication failure handler with provided default
Expand Down Expand Up @@ -55,6 +55,15 @@ public DefaultAuthenticationFailureHandler(final Map<String, List<String>> failu
}
}

/**
* This method is called when failureResponseHeaders need to be set (at startup) or updated (if license state changes)
*
* @param failureResponseHeaders the Map of failure response headers to be set
*/
public void setHeaders(Map<String, List<String>> failureResponseHeaders){
defaultFailureResponseHeaders = failureResponseHeaders;
}

/**
* For given 'WWW-Authenticate' header value returns the priority based on
* the auth-scheme. Lower number denotes more secure and preferred
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -577,32 +577,40 @@ private AuthenticationFailureHandler createAuthenticationFailureHandler(final Re
}
if (failureHandler == null) {
logger.debug("Using default authentication failure handler");
final Map<String, List<String>> defaultFailureResponseHeaders = new HashMap<>();
realms.asList().stream().forEach((realm) -> {
Map<String, List<String>> realmFailureHeaders = realm.getAuthenticationFailureHeaders();
realmFailureHeaders.entrySet().stream().forEach((e) -> {
String key = e.getKey();
e.getValue().stream()
.filter(v -> defaultFailureResponseHeaders.computeIfAbsent(key, x -> new ArrayList<>()).contains(v) == false)
.forEach(v -> defaultFailureResponseHeaders.get(key).add(v));
Supplier<Map<String, List<String>>> headersSupplier = () -> {
final Map<String, List<String>> defaultFailureResponseHeaders = new HashMap<>();
realms.asList().stream().forEach((realm) -> {
Map<String, List<String>> realmFailureHeaders = realm.getAuthenticationFailureHeaders();
realmFailureHeaders.entrySet().stream().forEach((e) -> {
String key = e.getKey();
e.getValue().stream()
.filter(v -> defaultFailureResponseHeaders.computeIfAbsent(key, x -> new ArrayList<>()).contains(v)
== false)
.forEach(v -> defaultFailureResponseHeaders.get(key).add(v));
});
});
});

if (TokenService.isTokenServiceEnabled(settings)) {
String bearerScheme = "Bearer realm=\"" + XPackField.SECURITY + "\"";
if (defaultFailureResponseHeaders.computeIfAbsent("WWW-Authenticate", x -> new ArrayList<>())
.contains(bearerScheme) == false) {
defaultFailureResponseHeaders.get("WWW-Authenticate").add(bearerScheme);
if (TokenService.isTokenServiceEnabled(settings)) {
String bearerScheme = "Bearer realm=\"" + XPackField.SECURITY + "\"";
if (defaultFailureResponseHeaders.computeIfAbsent("WWW-Authenticate", x -> new ArrayList<>())
.contains(bearerScheme) == false) {
defaultFailureResponseHeaders.get("WWW-Authenticate").add(bearerScheme);
}
}
}
if (API_KEY_SERVICE_ENABLED_SETTING.get(settings)) {
final String apiKeyScheme = "ApiKey";
if (defaultFailureResponseHeaders.computeIfAbsent("WWW-Authenticate", x -> new ArrayList<>())
.contains(apiKeyScheme) == false) {
defaultFailureResponseHeaders.get("WWW-Authenticate").add(apiKeyScheme);
if (API_KEY_SERVICE_ENABLED_SETTING.get(settings)) {
final String apiKeyScheme = "ApiKey";
if (defaultFailureResponseHeaders.computeIfAbsent("WWW-Authenticate", x -> new ArrayList<>())
.contains(apiKeyScheme) == false) {
defaultFailureResponseHeaders.get("WWW-Authenticate").add(apiKeyScheme);
}
}
}
failureHandler = new DefaultAuthenticationFailureHandler(defaultFailureResponseHeaders);
return defaultFailureResponseHeaders;
};
DefaultAuthenticationFailureHandler finalDefaultFailureHandler = new DefaultAuthenticationFailureHandler(headersSupplier.get());
failureHandler = finalDefaultFailureHandler;
getLicenseState().addListener(() -> {
finalDefaultFailureHandler.setHeaders(headersSupplier.get());
});
} else {
logger.debug("Using authentication failure handler from extension [" + extensionName + "]");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.security;

import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.Version;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterName;
Expand Down Expand Up @@ -33,6 +34,7 @@
import org.elasticsearch.test.VersionUtils;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.watcher.ResourceWatcherService;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.core.security.SecurityExtension;
import org.elasticsearch.xpack.core.security.SecurityField;
Expand Down Expand Up @@ -72,6 +74,8 @@
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.notNullValue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand All @@ -94,18 +98,7 @@ public Map<String, Realm.Factory> getRealms(SecurityComponents components) {
}
}

private Collection<Object> createComponents(Settings testSettings, SecurityExtension... extensions) throws Exception {
if (security != null) {
throw new IllegalStateException("Security object already exists (" + security + ")");
}
Settings.Builder builder = Settings.builder()
.put("xpack.security.enabled", true)
.put(testSettings)
.put("path.home", createTempDir());
if (inFipsJvm()) {
builder.put(XPackSettings.DIAGNOSE_TRUST_EXCEPTIONS_SETTING.getKey(), false);
}
Settings settings = builder.build();
private Collection<Object> createComponentsUtil(Settings settings, SecurityExtension... extensions) throws Exception {
Environment env = TestEnvironment.newEnvironment(settings);
licenseState = new TestUtils.UpdatableLicenseState(settings);
SSLService sslService = new SSLService(settings, env);
Expand Down Expand Up @@ -137,6 +130,36 @@ protected SSLService getSslService() {
xContentRegistry(), env, new IndexNameExpressionResolver());
}

private Collection<Object> createComponentsWithSecurityNotExplicitlyEnabled(Settings testSettings, SecurityExtension... extensions)
throws Exception {
if (security != null) {
throw new IllegalStateException("Security object already exists (" + security + ")");
}
Settings.Builder builder = Settings.builder()
.put(testSettings)
.put("path.home", createTempDir());
if (inFipsJvm()) {
builder.put(XPackSettings.DIAGNOSE_TRUST_EXCEPTIONS_SETTING.getKey(), false);
}
Settings settings = builder.build();
return createComponentsUtil(settings, extensions);
}

private Collection<Object> createComponents(Settings testSettings, SecurityExtension... extensions) throws Exception {
if (security != null) {
throw new IllegalStateException("Security object already exists (" + security + ")");
}
Settings.Builder builder = Settings.builder()
.put("xpack.security.enabled", true)
.put(testSettings)
.put("path.home", createTempDir());
if (inFipsJvm()) {
builder.put(XPackSettings.DIAGNOSE_TRUST_EXCEPTIONS_SETTING.getKey(), false);
}
Settings settings = builder.build();
return createComponentsUtil(settings, extensions);
}

private static <T> T findComponent(Class<T> type, Collection<Object> components) {
for (Object obj : components) {
if (type.isInstance(obj)) {
Expand Down Expand Up @@ -490,4 +513,16 @@ public void testValidateForFipsNoErrors() {
Security.validateForFips(settings);
// no exception thrown
}

private void logAndFail(Exception e) {
logger.error("unexpected exception", e);
fail("unexpected exception " + e.getMessage());
}

private void VerifyBasicAuthenticationHeader(Exception e) {
assertThat(e, instanceOf(ElasticsearchSecurityException.class));
assertThat(((ElasticsearchSecurityException) e).getHeader("WWW-Authenticate"), notNullValue());
assertThat(((ElasticsearchSecurityException) e).getHeader("WWW-Authenticate"),
hasItem("Basic realm=\"" + XPackField.SECURITY + "\" charset=\"UTF-8\""));
}
}

0 comments on commit 167172a

Please sign in to comment.