Skip to content

Commit

Permalink
Check license state before preparing the request (elastic#98484)
Browse files Browse the repository at this point in the history
* Remove unused RestRequest parameter from license check

* Check the license state first

* Apply suggestions from code review

Co-authored-by: Yang Wang <ywangd@gmail.com>

* Add innerCheckFeatureAvailable

* Modify test

* Apply suggestions from code review

Co-authored-by: Yang Wang <ywangd@gmail.com>

* Remove unneeded mock

---------

Co-authored-by: Yang Wang <ywangd@gmail.com>
Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 22, 2023
1 parent 440b182 commit 366a9a2
Show file tree
Hide file tree
Showing 16 changed files with 55 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,8 @@ public List<Route> routes() {
}

@Override
protected Exception checkFeatureAvailable(RestRequest request) {
Exception failedFeature = super.checkFeatureAvailable(request);
if (failedFeature != null) {
return failedFeature;
} else if (Security.PKI_REALM_FEATURE.checkWithoutTracking(licenseState)) {
protected Exception innerCheckFeatureAvailable() {
if (Security.PKI_REALM_FEATURE.checkWithoutTracking(licenseState)) {
return null;
} else {
logger.info("The '{}' realm is not available under the current license", PkiRealmSettings.TYPE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,21 @@ protected SecurityBaseRestHandler(Settings settings, XPackLicenseState licenseSt
}

/**
* Calls the {@link #innerPrepareRequest(RestRequest, NodeClient)} method and then checks the
* license state. If the license state allows auth, the result from
* Calls the {@link #checkFeatureAvailable()} method to check whether the feature is available based
* on settings and license state. If allowed, the result from
* {@link #innerPrepareRequest(RestRequest, NodeClient)} is returned, otherwise a default error
* response will be returned indicating that security is not licensed.
*
* Note: the implementing rest handler is called before the license is checked so that we do not
* Note: If the license check fails we consume the request content and parameters so that we do not
* trip the unused parameters check
*/
protected final RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
RestChannelConsumer consumer = innerPrepareRequest(request, client);
final Exception failedFeature = checkFeatureAvailable(request);
final Exception failedFeature = checkFeatureAvailable();
if (failedFeature == null) {
return consumer;
return innerPrepareRequest(request, client);
} else {
request.params().keySet().forEach(key -> request.param(key, ""));
request.content();
return channel -> channel.sendResponse(new RestResponse(channel, failedFeature));
}
}
Expand All @@ -57,24 +58,31 @@ protected final RestChannelConsumer prepareRequest(RestRequest request, NodeClie
* Check whether the given request is allowed within the current license state and setup,
* and return the name of any unlicensed feature.
* By default this returns an exception if security is not enabled.
* Sub-classes can override this method if they have additional requirements.
* Sub-classes can override {@link #innerCheckFeatureAvailable()} if they have additional requirements.
*
* @return {@code null} if all required features are available, otherwise an exception to be
* sent to the requestor
* sent to the requester
*/
protected Exception checkFeatureAvailable(RestRequest request) {
public final Exception checkFeatureAvailable() {
if (XPackSettings.SECURITY_ENABLED.get(settings) == false) {
return new IllegalStateException("Security is not enabled but a security rest handler is registered");
} else {
return null;
return innerCheckFeatureAvailable();
}
}

/**
* Implementers should implement this method when sub-classes have additional license requirements.
*/
protected Exception innerCheckFeatureAvailable() {
return null;
}

/**
* Implementers should implement this method as they normally would for
* {@link BaseRestHandler#prepareRequest(RestRequest, NodeClient)} and ensure that all request
* parameters are consumed prior to returning a value. The returned value is not guaranteed to
* be executed unless security is licensed and all request parameters are known
* parameters are consumed prior to returning a value. This method is executed only if the
* check from {@link #checkFeatureAvailable()} passes.
*/
protected abstract RestChannelConsumer innerPrepareRequest(RestRequest request, NodeClient client) throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.security.rest.action.SecurityBaseRestHandler;
import org.elasticsearch.xpack.security.support.FeatureNotEnabledException;
Expand All @@ -20,11 +19,8 @@ abstract class ApiKeyBaseRestHandler extends SecurityBaseRestHandler {
}

@Override
protected Exception checkFeatureAvailable(RestRequest request) {
final Exception failedFeature = super.checkFeatureAvailable(request);
if (failedFeature != null) {
return failedFeature;
} else if (XPackSettings.API_KEY_SERVICE_ENABLED_SETTING.get(settings) == false) {
protected Exception innerCheckFeatureAvailable() {
if (XPackSettings.API_KEY_SERVICE_ENABLED_SETTING.get(settings) == false) {
return new FeatureNotEnabledException(FeatureNotEnabledException.Feature.API_KEY_SERVICE, "api keys are not enabled");
} else {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,8 @@ protected RestChannelConsumer innerPrepareRequest(final RestRequest request, fin
}

@Override
protected Exception checkFeatureAvailable(RestRequest request) {
final Exception failedFeature = super.checkFeatureAvailable(request);
if (failedFeature != null) {
return failedFeature;
} else if (ADVANCED_REMOTE_CLUSTER_SECURITY_FEATURE.checkWithoutTracking(licenseState)) {
protected Exception innerCheckFeatureAvailable() {
if (ADVANCED_REMOTE_CLUSTER_SECURITY_FEATURE.checkWithoutTracking(licenseState)) {
return null;
} else {
return LicenseUtils.newComplianceException(ADVANCED_REMOTE_CLUSTER_SECURITY_FEATURE.getName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,8 @@ protected RestChannelConsumer innerPrepareRequest(final RestRequest request, fin
}

@Override
protected Exception checkFeatureAvailable(RestRequest request) {
final Exception failedFeature = super.checkFeatureAvailable(request);
if (failedFeature != null) {
return failedFeature;
} else if (ADVANCED_REMOTE_CLUSTER_SECURITY_FEATURE.checkWithoutTracking(licenseState)) {
protected Exception innerCheckFeatureAvailable() {
if (ADVANCED_REMOTE_CLUSTER_SECURITY_FEATURE.checkWithoutTracking(licenseState)) {
return null;
} else {
return LicenseUtils.newComplianceException(ADVANCED_REMOTE_CLUSTER_SECURITY_FEATURE.getName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.security.rest.action.SecurityBaseRestHandler;
Expand All @@ -25,11 +24,8 @@ public EnrollmentBaseRestHandler(Settings settings, XPackLicenseState licenseSta
}

@Override
protected Exception checkFeatureAvailable(RestRequest request) {
Exception failedFeature = super.checkFeatureAvailable(request);
if (failedFeature != null) {
return failedFeature;
} else if (XPackSettings.ENROLLMENT_ENABLED.get(settings) == false) {
protected Exception innerCheckFeatureAvailable() {
if (XPackSettings.ENROLLMENT_ENABLED.get(settings) == false) {
return new ElasticsearchSecurityException(
"Enrollment mode is not enabled. Set ["
+ XPackSettings.ENROLLMENT_ENABLED.getKey()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.xpack.security.Security;
import org.elasticsearch.xpack.security.rest.action.SecurityBaseRestHandler;

Expand All @@ -28,11 +27,8 @@ abstract class TokenBaseRestHandler extends SecurityBaseRestHandler {
}

@Override
protected Exception checkFeatureAvailable(RestRequest request) {
Exception failedFeature = super.checkFeatureAvailable(request);
if (failedFeature != null) {
return failedFeature;
} else if (Security.TOKEN_SERVICE_FEATURE.check(licenseState)) {
protected Exception innerCheckFeatureAvailable() {
if (Security.TOKEN_SERVICE_FEATURE.check(licenseState)) {
return null;
} else {
logger.info("Security tokens are not available under the current [{}] license", licenseState.getOperationMode().description());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.xpack.core.security.authc.oidc.OpenIdConnectRealmSettings;
import org.elasticsearch.xpack.security.authc.Realms;
import org.elasticsearch.xpack.security.rest.action.SecurityBaseRestHandler;
Expand All @@ -30,11 +29,8 @@ protected OpenIdConnectBaseRestHandler(Settings settings, XPackLicenseState lice
}

@Override
protected Exception checkFeatureAvailable(RestRequest request) {
Exception failedFeature = super.checkFeatureAvailable(request);
if (failedFeature != null) {
return failedFeature;
} else if (Realms.isRealmTypeAvailable(licenseState, OIDC_REALM_TYPE)) {
protected Exception innerCheckFeatureAvailable() {
if (Realms.isRealmTypeAvailable(licenseState, OIDC_REALM_TYPE)) {
return null;
} else {
logger.info("The '{}' realm is not available under the current license", OIDC_REALM_TYPE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,7 @@ protected RestChannelConsumer innerPrepareRequest(RestRequest request, NodeClien
}

@Override
protected Exception checkFeatureAvailable(RestRequest request) {
final Exception failedFeature = super.checkFeatureAvailable(request);
if (failedFeature != null) {
return failedFeature;
}
protected Exception innerCheckFeatureAvailable() {
if (Security.USER_PROFILE_COLLABORATION_FEATURE.check(licenseState)) {
return null;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.xpack.core.security.authc.saml.SamlRealmSettings;
import org.elasticsearch.xpack.security.authc.Realms;
import org.elasticsearch.xpack.security.rest.action.SecurityBaseRestHandler;
Expand All @@ -29,11 +28,8 @@ public SamlBaseRestHandler(Settings settings, XPackLicenseState licenseState) {
}

@Override
protected Exception checkFeatureAvailable(RestRequest request) {
Exception failedFeature = super.checkFeatureAvailable(request);
if (failedFeature != null) {
return failedFeature;
} else if (Realms.isRealmTypeAvailable(licenseState, SAML_REALM_TYPE)) {
protected Exception innerCheckFeatureAvailable() {
if (Realms.isRealmTypeAvailable(licenseState, SAML_REALM_TYPE)) {
return null;
} else {
logger.info("The '{}' realm is not available under the current license", SAML_REALM_TYPE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,7 @@ protected RestChannelConsumer innerPrepareRequest(RestRequest restRequest, NodeC
}

@Override
protected Exception checkFeatureAvailable(RestRequest request) {
final Exception failedFeature = super.checkFeatureAvailable(request);
if (failedFeature != null) {
return failedFeature;
}
protected Exception innerCheckFeatureAvailable() {
if (Security.USER_PROFILE_COLLABORATION_FEATURE.check(licenseState)) {
return null;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.license.License;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.test.ESTestCase;
Expand All @@ -24,18 +23,15 @@

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

public class SecurityBaseRestHandlerTests extends ESTestCase {

public void testSecurityBaseRestHandlerChecksLicenseState() throws Exception {
public void testSecurityBaseRestHandlerChecksFeatureAvailableBeforePreparingRequest() throws Exception {
final boolean securityEnabled = randomBoolean();
Settings settings = Settings.builder().put(XPackSettings.SECURITY_ENABLED.getKey(), securityEnabled).build();
final AtomicBoolean consumerCalled = new AtomicBoolean(false);
final AtomicBoolean innerPrepareRequestCalled = new AtomicBoolean(false);
final XPackLicenseState licenseState = mock(XPackLicenseState.class);
when(licenseState.getOperationMode()).thenReturn(
randomFrom(License.OperationMode.BASIC, License.OperationMode.STANDARD, License.OperationMode.GOLD)
);
SecurityBaseRestHandler handler = new SecurityBaseRestHandler(settings, licenseState) {

@Override
Expand All @@ -50,6 +46,10 @@ public List<Route> routes() {

@Override
protected RestChannelConsumer innerPrepareRequest(RestRequest request, NodeClient client) throws IOException {
if (innerPrepareRequestCalled.compareAndSet(false, true) == false) {
fail("innerPrepareRequestCalled was not false");
}

return channel -> {
if (consumerCalled.compareAndSet(false, true) == false) {
fail("consumerCalled was not false");
Expand All @@ -66,10 +66,12 @@ protected RestChannelConsumer innerPrepareRequest(RestRequest request, NodeClien
handler.handleRequest(fakeRestRequest, fakeRestChannel, client);

if (securityEnabled) {
assertTrue(innerPrepareRequestCalled.get());
assertTrue(consumerCalled.get());
assertEquals(0, fakeRestChannel.responses().get());
assertEquals(0, fakeRestChannel.errors().get());
} else {
assertFalse(innerPrepareRequestCalled.get());
assertFalse(consumerCalled.get());
assertEquals(0, fakeRestChannel.responses().get());
assertEquals(1, fakeRestChannel.errors().get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.rest.FakeRestRequest;
import org.elasticsearch.xpack.core.XPackSettings;
import org.hamcrest.Matchers;

Expand All @@ -28,13 +27,13 @@ public class EnrollmentBaseRestHandlerTests extends ESTestCase {
public void testInEnrollmentMode() {
final Settings settings = Settings.builder().put(XPackSettings.ENROLLMENT_ENABLED.getKey(), true).build();
final EnrollmentBaseRestHandler handler = buildHandler(settings);
assertThat(handler.checkFeatureAvailable(new FakeRestRequest()), Matchers.nullValue());
assertThat(handler.checkFeatureAvailable(), Matchers.nullValue());
}

public void testNotInEnrollmentMode() {
final Settings settings = Settings.builder().put(XPackSettings.ENROLLMENT_ENABLED.getKey(), false).build();
final EnrollmentBaseRestHandler handler = buildHandler(settings);
Exception ex = handler.checkFeatureAvailable(new FakeRestRequest());
Exception ex = handler.checkFeatureAvailable();
assertThat(ex, instanceOf(ElasticsearchSecurityException.class));
assertThat(
ex.getMessage(),
Expand All @@ -51,7 +50,7 @@ public void testSecurityExplicitlyDisabled() {
.put(XPackSettings.ENROLLMENT_ENABLED.getKey(), true)
.build();
final EnrollmentBaseRestHandler handler = buildHandler(settings);
Exception ex = handler.checkFeatureAvailable(new FakeRestRequest());
Exception ex = handler.checkFeatureAvailable();
assertThat(ex, instanceOf(IllegalStateException.class));
assertThat(ex.getMessage(), Matchers.containsString("Security is not enabled but a security rest handler is registered"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ public void testLicenseEnforcement() {
final boolean featureAllowed = randomBoolean();
when(licenseState.isAllowed(Security.USER_PROFILE_COLLABORATION_FEATURE)).thenReturn(featureAllowed);
if (featureAllowed) {
assertThat(restSuggestProfilesAction.checkFeatureAvailable(new FakeRestRequest()), nullValue());
assertThat(restSuggestProfilesAction.checkFeatureAvailable(), nullValue());
verify(licenseState).featureUsed(Security.USER_PROFILE_COLLABORATION_FEATURE);
} else {
final Exception e = restSuggestProfilesAction.checkFeatureAvailable(new FakeRestRequest());
final Exception e = restSuggestProfilesAction.checkFeatureAvailable();
assertThat(e, instanceOf(ElasticsearchSecurityException.class));
assertThat(e.getMessage(), containsString("current license is non-compliant for [user-profile-collaboration]"));
assertThat(((ElasticsearchSecurityException) e).status(), equalTo(RestStatus.FORBIDDEN));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.elasticsearch.license.internal.XPackLicenseStatus;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.rest.FakeRestRequest;
import org.elasticsearch.xpack.core.XPackSettings;
import org.hamcrest.Matchers;

Expand All @@ -31,14 +30,14 @@ public void testSamlAvailableOnTrialAndPlatinum() {
final SamlBaseRestHandler handler = buildHandler(
randomFrom(License.OperationMode.TRIAL, License.OperationMode.PLATINUM, License.OperationMode.ENTERPRISE)
);
assertThat(handler.checkFeatureAvailable(new FakeRestRequest()), Matchers.nullValue());
assertThat(handler.checkFeatureAvailable(), Matchers.nullValue());
}

public void testSamlNotAvailableOnBasicStandardOrGold() {
final SamlBaseRestHandler handler = buildHandler(
randomFrom(License.OperationMode.BASIC, License.OperationMode.STANDARD, License.OperationMode.GOLD)
);
Exception e = handler.checkFeatureAvailable(new FakeRestRequest());
Exception e = handler.checkFeatureAvailable();
assertThat(e, instanceOf(ElasticsearchException.class));
ElasticsearchException elasticsearchException = (ElasticsearchException) e;
assertThat(elasticsearchException.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), contains("saml"));
Expand Down
Loading

0 comments on commit 366a9a2

Please sign in to comment.