Use system context for cluster state update tasks (#31241)
This commit makes cluster state update tasks always run under the system context, restoring the
original context only when the listener that was provided with the task is called. A notable
exception is the clusterStatePublished(...) callback, which still runs under the system context:
it is defined at the executor level rather than the task level, and is invoked only once for the
combined batch of tasks, so it cannot be uniquely associated with a single task's thread context.

Relates #30603
ywelsch authored Jun 18, 2018
1 parent 1502812 commit 02a4ef3
Showing 19 changed files with 236 additions and 92 deletions.
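
Before the per-file diffs, a minimal sketch of the pattern this commit introduces, using only the ThreadContext/ThreadPool calls that appear in the MasterService diff below; the helper method and the Runnable stand-ins are invented for illustration:

import java.util.function.Supplier;

import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.threadpool.ThreadPool;

class SystemContextSketch {
    // Illustrative helper (not part of the commit): capture a restorable view of
    // the submitter's context, run the task as "system", restore for the listener.
    static void runUnderSystemContext(ThreadPool threadPool, Runnable task, Runnable listener) {
        final ThreadContext threadContext = threadPool.getThreadContext();
        // Supplier that re-establishes the submitter's context when invoked.
        final Supplier<ThreadContext.StoredContext> restoreCaller = threadContext.newRestorableContext(false);
        try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
            threadContext.markAsSystemContext(); // the task itself now executes as "system"
            task.run();                          // stand-in for taskBatcher.submitTasks(...)
        }
        try (ThreadContext.StoredContext ignore = restoreCaller.get()) {
            listener.run();                      // listener sees the submitter's headers again
        }
    }
}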
@@ -41,6 +41,9 @@ default boolean runOnlyOnMaster() {
/**
* Callback invoked after new cluster state is published. Note that
* this method is not invoked if the cluster state was not updated.
*
* Note that this method will be executed using system context.
*
* @param clusterChangedEvent the change event for this cluster state change, containing
* both old and new states
*/

@@ -62,6 +62,12 @@ public String describeTasks(List<ClusterStateUpdateTask> tasks) {
*/
public abstract void onFailure(String source, Exception e);

@Override
public final void clusterStatePublished(ClusterChangedEvent clusterChangedEvent) {
// final, empty implementation here as this method should only be defined in combination
// with a batching executor as it will always be executed within the system context.
}

/**
* If the cluster state update task wasn't processed by the provided timeout, call
* {@link ClusterStateTaskListener#onFailure(String, Exception)}. May return null to indicate no timeout is needed (default).
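For contrast with the final, empty implementation above, a hedged sketch of an executor-level override, where clusterStatePublished legitimately lives; the executor and its task type are invented, and the callback always runs under the system context because it fires once per published batch:

import java.util.List;

import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateTaskExecutor;

// Invented batching executor for illustration only.
class LoggingExecutor implements ClusterStateTaskExecutor<String> {
    @Override
    public ClusterTasksResult<String> execute(ClusterState currentState, List<String> tasks) {
        // Accept the whole batch without changing the state.
        return ClusterTasksResult.<String>builder().successes(tasks).build(currentState);
    }

    @Override
    public void clusterStatePublished(ClusterChangedEvent clusterChangedEvent) {
        // Fires once for the combined batch, under the system context; there is
        // no single submitting task whose context could be restored here.
    }
}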

@@ -47,6 +47,7 @@
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.util.concurrent.FutureUtils;
import org.elasticsearch.common.util.concurrent.PrioritizedEsThreadPoolExecutor;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.discovery.Discovery;
import org.elasticsearch.threadpool.ThreadPool;

@@ -59,6 +60,7 @@
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static org.elasticsearch.cluster.service.ClusterService.CLUSTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING;
@@ -426,26 +428,28 @@ public TimeValue getMaxTaskWaitTime() {
return threadPoolExecutor.getMaxTaskWaitTime();
}

private SafeClusterStateTaskListener safe(ClusterStateTaskListener listener) {
private SafeClusterStateTaskListener safe(ClusterStateTaskListener listener, Supplier<ThreadContext.StoredContext> contextSupplier) {
if (listener instanceof AckedClusterStateTaskListener) {
return new SafeAckedClusterStateTaskListener((AckedClusterStateTaskListener) listener, logger);
return new SafeAckedClusterStateTaskListener((AckedClusterStateTaskListener) listener, contextSupplier, logger);
} else {
return new SafeClusterStateTaskListener(listener, logger);
return new SafeClusterStateTaskListener(listener, contextSupplier, logger);
}
}

private static class SafeClusterStateTaskListener implements ClusterStateTaskListener {
private final ClusterStateTaskListener listener;
protected final Supplier<ThreadContext.StoredContext> context;
private final Logger logger;

SafeClusterStateTaskListener(ClusterStateTaskListener listener, Logger logger) {
SafeClusterStateTaskListener(ClusterStateTaskListener listener, Supplier<ThreadContext.StoredContext> context, Logger logger) {
this.listener = listener;
this.context = context;
this.logger = logger;
}

@Override
public void onFailure(String source, Exception e) {
try {
try (ThreadContext.StoredContext ignore = context.get()) {
listener.onFailure(source, e);
} catch (Exception inner) {
inner.addSuppressed(e);
@@ -456,7 +460,7 @@ public void onFailure(String source, Exception e) {

@Override
public void onNoLongerMaster(String source) {
try {
try (ThreadContext.StoredContext ignore = context.get()) {
listener.onNoLongerMaster(source);
} catch (Exception e) {
logger.error(() -> new ParameterizedMessage(
@@ -466,7 +470,7 @@ public void onNoLongerMaster(String source) {

@Override
public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) {
try {
try (ThreadContext.StoredContext ignore = context.get()) {
listener.clusterStateProcessed(source, oldState, newState);
} catch (Exception e) {
logger.error(() -> new ParameterizedMessage(
Expand All @@ -480,8 +484,9 @@ private static class SafeAckedClusterStateTaskListener extends SafeClusterStateT
private final AckedClusterStateTaskListener listener;
private final Logger logger;

SafeAckedClusterStateTaskListener(AckedClusterStateTaskListener listener, Logger logger) {
super(listener, logger);
SafeAckedClusterStateTaskListener(AckedClusterStateTaskListener listener, Supplier<ThreadContext.StoredContext> context,
Logger logger) {
super(listener, context, logger);
this.listener = listener;
this.logger = logger;
}
@@ -493,7 +498,7 @@ public boolean mustAck(DiscoveryNode discoveryNode) {

@Override
public void onAllNodesAcked(@Nullable Exception e) {
try {
try (ThreadContext.StoredContext ignore = context.get()) {
listener.onAllNodesAcked(e);
} catch (Exception inner) {
inner.addSuppressed(e);
@@ -503,7 +508,7 @@ public void onAckTimeout() {

@Override
public void onAckTimeout() {
try {
try (ThreadContext.StoredContext ignore = context.get()) {
listener.onAckTimeout();
} catch (Exception e) {
logger.error("exception thrown by listener while notifying on ack timeout", e);
@@ -724,9 +729,13 @@ public <T> void submitStateUpdateTasks(final String source,
if (!lifecycle.started()) {
return;
}
try {
final ThreadContext threadContext = threadPool.getThreadContext();
final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(false);
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
threadContext.markAsSystemContext();

List<Batcher.UpdateTask> safeTasks = tasks.entrySet().stream()
.map(e -> taskBatcher.new UpdateTask(config.priority(), source, e.getKey(), safe(e.getValue()), executor))
.map(e -> taskBatcher.new UpdateTask(config.priority(), source, e.getKey(), safe(e.getValue(), supplier), executor))
.collect(Collectors.toList());
taskBatcher.submitTasks(safeTasks, config.timeout());
} catch (EsRejectedExecutionException e) {
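A usage sketch of the observable effect of the MasterService change above, assuming a ClusterService and ThreadPool are at hand; the header name and the task itself are made up:

threadPool.getThreadContext().putHeader("trace-id", "abc42"); // illustrative caller header
clusterService.submitStateUpdateTask("illustrative-task", new ClusterStateUpdateTask() {
    @Override
    public ClusterState execute(ClusterState currentState) {
        // Runs under the system context: "trace-id" is not visible here.
        return currentState;
    }

    @Override
    public void onFailure(String source, Exception e) {
        // Submitter's context restored before this callback: "trace-id" is visible again.
    }
});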

@@ -556,7 +556,6 @@ public ClusterStateResponse newInstance() {

@Override
public void handleResponse(ClusterStateResponse response) {
assert transportService.getThreadPool().getThreadContext().isSystemContext() == false : "context is a system context";
try {
if (remoteClusterName.get() == null) {
assert response.getClusterName().value() != null;
@@ -597,7 +596,6 @@ public void handleResponse(ClusterStateResponse response) {

@Override
public void handleException(TransportException exp) {
assert transportService.getThreadPool().getThreadContext().isSystemContext() == false : "context is a system context";
logger.warn(() -> new ParameterizedMessage("fetching nodes from external cluster {} failed", clusterAlias), exp);
try {
IOUtils.closeWhileHandlingException(connection);

@@ -34,12 +34,14 @@
import org.elasticsearch.cluster.block.ClusterBlocks;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Priority;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.BaseFuture;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.discovery.Discovery;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.MockLogAppender;
@@ -52,6 +54,7 @@
import org.junit.BeforeClass;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@@ -168,6 +171,85 @@ public void onFailure(String source, Exception e) {
nonMaster.close();
}

public void testThreadContext() throws InterruptedException {
final TimedMasterService master = createTimedMasterService(true);
final CountDownLatch latch = new CountDownLatch(1);

try (ThreadContext.StoredContext ignored = threadPool.getThreadContext().stashContext()) {
final Map<String, String> expectedHeaders = Collections.singletonMap("test", "test");
threadPool.getThreadContext().putHeader(expectedHeaders);

final TimeValue ackTimeout = randomBoolean() ? TimeValue.ZERO : TimeValue.timeValueMillis(randomInt(10000));
final TimeValue masterTimeout = randomBoolean() ? TimeValue.ZERO : TimeValue.timeValueMillis(randomInt(10000));

master.submitStateUpdateTask("test", new AckedClusterStateUpdateTask<Void>(null, null) {
@Override
public ClusterState execute(ClusterState currentState) {
assertTrue(threadPool.getThreadContext().isSystemContext());
assertEquals(Collections.emptyMap(), threadPool.getThreadContext().getHeaders());

if (randomBoolean()) {
return ClusterState.builder(currentState).build();
} else if (randomBoolean()) {
return currentState;
} else {
throw new IllegalArgumentException("mock failure");
}
}

@Override
public void onFailure(String source, Exception e) {
assertFalse(threadPool.getThreadContext().isSystemContext());
assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
latch.countDown();
}

@Override
public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) {
assertFalse(threadPool.getThreadContext().isSystemContext());
assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
latch.countDown();
}

@Override
protected Void newResponse(boolean acknowledged) {
return null;
}

public TimeValue ackTimeout() {
return ackTimeout;
}

@Override
public TimeValue timeout() {
return masterTimeout;
}

@Override
public void onAllNodesAcked(@Nullable Exception e) {
assertFalse(threadPool.getThreadContext().isSystemContext());
assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
latch.countDown();
}

@Override
public void onAckTimeout() {
assertFalse(threadPool.getThreadContext().isSystemContext());
assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
latch.countDown();
}

});

assertFalse(threadPool.getThreadContext().isSystemContext());
assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
}

latch.await();

master.close();
}

/*
* test that a listener throwing an exception while handling a
* notification does not prevent publication notification to the

@@ -20,7 +20,6 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -293,7 +292,7 @@ public Builder deleteJob(String jobId, PersistentTasksCustomMetaData tasks) {
return this;
}

public Builder putDatafeed(DatafeedConfig datafeedConfig, ThreadContext threadContext) {
public Builder putDatafeed(DatafeedConfig datafeedConfig, Map<String, String> headers) {
if (datafeeds.containsKey(datafeedConfig.getId())) {
throw new ResourceAlreadyExistsException("A datafeed with id [" + datafeedConfig.getId() + "] already exists");
}
@@ -302,13 +301,13 @@ public Builder putDatafeed(DatafeedConfig datafeedConfig, ThreadContext threadContext) {
Job job = jobs.get(jobId);
DatafeedJobValidator.validate(datafeedConfig, job);

if (threadContext != null) {
if (headers.isEmpty() == false) {
// Adjust the request, adding security headers from the current thread context
DatafeedConfig.Builder builder = new DatafeedConfig.Builder(datafeedConfig);
Map<String, String> headers = threadContext.getHeaders().entrySet().stream()
Map<String, String> securityHeaders = headers.entrySet().stream()
.filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
builder.setHeaders(headers);
builder.setHeaders(securityHeaders);
datafeedConfig = builder.build();
}

@@ -328,15 +327,15 @@ private void checkJobIsAvailableForDatafeed(String jobId) {
}
}

public Builder updateDatafeed(DatafeedUpdate update, PersistentTasksCustomMetaData persistentTasks, ThreadContext threadContext) {
public Builder updateDatafeed(DatafeedUpdate update, PersistentTasksCustomMetaData persistentTasks, Map<String, String> headers) {
String datafeedId = update.getId();
DatafeedConfig oldDatafeedConfig = datafeeds.get(datafeedId);
if (oldDatafeedConfig == null) {
throw ExceptionsHelper.missingDatafeedException(datafeedId);
}
checkDatafeedIsStopped(() -> Messages.getMessage(Messages.DATAFEED_CANNOT_UPDATE_IN_CURRENT_STATE, datafeedId,
DatafeedState.STARTED), datafeedId, persistentTasks);
DatafeedConfig newDatafeedConfig = update.apply(oldDatafeedConfig, threadContext);
DatafeedConfig newDatafeedConfig = update.apply(oldDatafeedConfig, headers);
if (newDatafeedConfig.getJobId().equals(oldDatafeedConfig.getJobId()) == false) {
checkJobIsAvailableForDatafeed(newDatafeedConfig.getJobId());
}
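A hedged sketch of the call-site adjustment the new Builder signatures imply: callers now read the header map from the thread context once and pass it as a plain map, with an empty map playing the role of the previous null ThreadContext. The MlMetadata.Builder construction and variable names are assumptions, not shown in this diff:

// Hypothetical caller: read the headers once, then pass them as a plain map.
Map<String, String> headers = threadPool.getThreadContext().getHeaders();
MlMetadata.Builder builder = new MlMetadata.Builder(currentMlMetadata)
        .putDatafeed(datafeedConfig, headers)
        .updateDatafeed(update, persistentTasks, headers);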

@@ -12,7 +12,6 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -264,7 +263,7 @@ ChunkingConfig getChunkingConfig() {
* Applies the update to the given {@link DatafeedConfig}
* @return a new {@link DatafeedConfig} that contains the update
*/
public DatafeedConfig apply(DatafeedConfig datafeedConfig, ThreadContext threadContext) {
public DatafeedConfig apply(DatafeedConfig datafeedConfig, Map<String, String> headers) {
if (id.equals(datafeedConfig.getId()) == false) {
throw new IllegalArgumentException("Cannot apply update to datafeedConfig with different id");
}
@@ -301,12 +300,12 @@ public DatafeedConfig apply(DatafeedConfig datafeedConfig, ThreadContext threadContext) {
builder.setChunkingConfig(chunkingConfig);
}

if (threadContext != null) {
if (headers.isEmpty() == false) {
// Adjust the request, adding security headers from the current thread context
Map<String, String> headers = threadContext.getHeaders().entrySet().stream()
Map<String, String> securityHeaders = headers.entrySet().stream()
.filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
builder.setHeaders(headers);
builder.setHeaders(securityHeaders);
}

return builder.build();
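The same header-map signature on apply(...), in a two-line usage sketch; existingConfig and update are assumed to exist, and Collections.emptyMap() matches what the updated tests below pass:

DatafeedConfig withHeaders = update.apply(existingConfig, threadPool.getThreadContext().getHeaders());
DatafeedConfig noHeaders = update.apply(existingConfig, Collections.emptyMap());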

@@ -114,7 +114,7 @@ public void testApply_failBecauseTargetDatafeedHasDifferentId() {

public void testApply_givenEmptyUpdate() {
DatafeedConfig datafeed = DatafeedConfigTests.createRandomizedDatafeedConfig("foo");
DatafeedConfig updatedDatafeed = new DatafeedUpdate.Builder(datafeed.getId()).build().apply(datafeed, null);
DatafeedConfig updatedDatafeed = new DatafeedUpdate.Builder(datafeed.getId()).build().apply(datafeed, Collections.emptyMap());
assertThat(datafeed, equalTo(updatedDatafeed));
}

@@ -125,7 +125,7 @@ public void testApply_givenPartialUpdate() {

DatafeedUpdate.Builder updated = new DatafeedUpdate.Builder(datafeed.getId());
updated.setScrollSize(datafeed.getScrollSize() + 1);
DatafeedConfig updatedDatafeed = update.build().apply(datafeed, null);
DatafeedConfig updatedDatafeed = update.build().apply(datafeed, Collections.emptyMap());

DatafeedConfig.Builder expectedDatafeed = new DatafeedConfig.Builder(datafeed);
expectedDatafeed.setScrollSize(datafeed.getScrollSize() + 1);
@@ -149,7 +149,7 @@ public void testApply_givenFullUpdateNoAggregations() {
update.setScrollSize(8000);
update.setChunkingConfig(ChunkingConfig.newManual(TimeValue.timeValueHours(1)));

DatafeedConfig updatedDatafeed = update.build().apply(datafeed, null);
DatafeedConfig updatedDatafeed = update.build().apply(datafeed, Collections.emptyMap());

assertThat(updatedDatafeed.getJobId(), equalTo("bar"));
assertThat(updatedDatafeed.getIndices(), equalTo(Collections.singletonList("i_2")));
@@ -175,7 +175,7 @@ public void testApply_givenAggregations() {
update.setAggregations(new AggregatorFactories.Builder().addAggregator(
AggregationBuilders.histogram("a").interval(300000).field("time").subAggregation(maxTime)));

DatafeedConfig updatedDatafeed = update.build().apply(datafeed, null);
DatafeedConfig updatedDatafeed = update.build().apply(datafeed, Collections.emptyMap());

assertThat(updatedDatafeed.getIndices(), equalTo(Collections.singletonList("i_1")));
assertThat(updatedDatafeed.getTypes(), equalTo(Collections.singletonList("t_1")));