Skip to content

Commit

Permalink
Support task resource tracking in OpenSearch (#3982)
Browse files Browse the repository at this point in the history
* Support task resource tracking in OpenSearch

* Reopens changes from #2639 (reverted in #3046) to add a framework for task resource tracking. Currently, SearchTask and SearchShardTask support resource tracking but it can be extended to any other task.

* Fixed a race-condition when Task is unregistered before its threads are stopped

* Improved error handling and simplified task resource tracking completion listener

* Avoid registering listeners on already completed tasks

Signed-off-by: Ketan Verma <ketan9495@gmail.com>
  • Loading branch information
ketanv3 committed Aug 2, 2022
1 parent b0080eb commit 5eac54d
Show file tree
Hide file tree
Showing 30 changed files with 1,513 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,9 @@ public void onTaskUnregistered(Task task) {}

@Override
public void waitForTaskCompletion(Task task) {}

@Override
public void taskExecutionStarted(Task task, Boolean closeableInvoked) {}
});
}
// Need to run the task in a separate thread because node client's .execute() is blocked by our task listener
Expand Down Expand Up @@ -651,6 +654,9 @@ public void waitForTaskCompletion(Task task) {
waitForWaitingToStart.countDown();
}

@Override
public void taskExecutionStarted(Task task, Boolean closeableInvoked) {}

@Override
public void onTaskRegistered(Task task) {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskInfo;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

Expand All @@ -65,8 +66,15 @@ public static long waitForCompletionTimeout(TimeValue timeout) {

private static final TimeValue DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT = timeValueSeconds(30);

private final TaskResourceTrackingService taskResourceTrackingService;

@Inject
public TransportListTasksAction(ClusterService clusterService, TransportService transportService, ActionFilters actionFilters) {
public TransportListTasksAction(
ClusterService clusterService,
TransportService transportService,
ActionFilters actionFilters,
TaskResourceTrackingService taskResourceTrackingService
) {
super(
ListTasksAction.NAME,
clusterService,
Expand All @@ -77,6 +85,7 @@ public TransportListTasksAction(ClusterService clusterService, TransportService
TaskInfo::new,
ThreadPool.Names.MANAGEMENT
);
this.taskResourceTrackingService = taskResourceTrackingService;
}

@Override
Expand Down Expand Up @@ -106,6 +115,8 @@ protected void processTasks(ListTasksRequest request, Consumer<Task> operation)
}
taskManager.waitForTaskCompletion(task, timeoutNanos);
});
} else {
operation = operation.andThen(taskResourceTrackingService::refreshResourceStats);
}
super.processTasks(request, operation);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ public SearchShardTask(long id, String type, String action, String description,
super(id, type, action, description, parentTaskId, headers);
}

@Override
public boolean supportsResourceTracking() {
return true;
}

@Override
public boolean shouldCancelChildrenOnCancellation() {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ public final String getDescription() {
return descriptionSupplier.get();
}

@Override
public boolean supportsResourceTracking() {
return true;
}

/**
* Attach a {@link SearchProgressListener} to this task.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.opensearch.action.ActionResponse;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskCancelledException;
import org.opensearch.tasks.TaskId;
Expand Down Expand Up @@ -93,31 +94,39 @@ public final Task execute(Request request, ActionListener<Response> listener) {
*/
final Releasable unregisterChildNode = registerChildNode(request.getParentTask());
final Task task;

try {
task = taskManager.register("transport", actionName, request);
} catch (TaskCancelledException e) {
unregisterChildNode.close();
throw e;
}
execute(task, request, new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onResponse(response);

ThreadContext.StoredContext storedContext = taskManager.taskExecutionStarted(task);
try {
execute(task, request, new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onResponse(response);
}
}
}

@Override
public void onFailure(Exception e) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onFailure(e);
@Override
public void onFailure(Exception e) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onFailure(e);
}
}
}
});
});
} finally {
storedContext.close();
}

return task;
}

Expand All @@ -134,25 +143,30 @@ public final Task execute(Request request, TaskListener<Response> listener) {
unregisterChildNode.close();
throw e;
}
execute(task, request, new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onResponse(task, response);
ThreadContext.StoredContext storedContext = taskManager.taskExecutionStarted(task);
try {
execute(task, request, new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onResponse(task, response);
}
}
}

@Override
public void onFailure(Exception e) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onFailure(task, e);
@Override
public void onFailure(Exception e) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onFailure(task, e);
}
}
}
});
});
} finally {
storedContext.close();
}
return task;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
import org.opensearch.script.ScriptMetadata;
import org.opensearch.snapshots.SnapshotsInfoService;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.tasks.TaskResultsService;

import java.util.ArrayList;
Expand Down Expand Up @@ -396,6 +397,7 @@ protected void configure() {
bind(NodeMappingRefreshAction.class).asEagerSingleton();
bind(MappingUpdatedAction.class).asEagerSingleton();
bind(TaskResultsService.class).asEagerSingleton();
bind(TaskResourceTrackingService.class).asEagerSingleton();
bind(AllocationDeciders.class).toInstance(allocationDeciders);
bind(ShardsAllocator.class).toInstance(shardsAllocator);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.opensearch.index.ShardIndexingPressureMemoryManager;
import org.opensearch.index.ShardIndexingPressureSettings;
import org.opensearch.index.ShardIndexingPressureStore;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.watcher.ResourceWatcherService;
import org.opensearch.action.admin.cluster.configuration.TransportAddVotingConfigExclusionsAction;
import org.opensearch.action.admin.indices.close.TransportCloseIndexAction;
Expand Down Expand Up @@ -575,7 +576,8 @@ public void apply(Settings value, Settings current, Settings previous) {
ShardIndexingPressureMemoryManager.THROUGHPUT_DEGRADATION_LIMITS,
ShardIndexingPressureMemoryManager.SUCCESSFUL_REQUEST_ELAPSED_TIMEOUT,
ShardIndexingPressureMemoryManager.MAX_OUTSTANDING_REQUESTS,
IndexingPressure.MAX_INDEXING_BYTES
IndexingPressure.MAX_INDEXING_BYTES,
TaskResourceTrackingService.TASK_RESOURCE_TRACKING_ENABLED
)
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.node.Node;
import org.opensearch.threadpool.RunnableTaskExecutionListener;
import org.opensearch.threadpool.TaskAwareRunnable;

import java.util.List;
import java.util.Optional;
Expand All @@ -55,6 +57,7 @@
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

/**
Expand Down Expand Up @@ -182,21 +185,32 @@ public static OpenSearchThreadPoolExecutor newResizable(
int size,
int queueCapacity,
ThreadFactory threadFactory,
ThreadContext contextHolder
ThreadContext contextHolder,
AtomicReference<RunnableTaskExecutionListener> runnableTaskListener
) {

if (queueCapacity <= 0) {
throw new IllegalArgumentException("queue capacity for [" + name + "] executor must be positive, got: " + queueCapacity);
}

Function<Runnable, WrappedRunnable> runnableWrapper;
if (runnableTaskListener != null) {
runnableWrapper = (runnable) -> {
TaskAwareRunnable taskAwareRunnable = new TaskAwareRunnable(contextHolder, runnable, runnableTaskListener);
return new TimedRunnable(taskAwareRunnable);
};
} else {
runnableWrapper = TimedRunnable::new;
}

return new QueueResizableOpenSearchThreadPoolExecutor(
name,
size,
size,
0,
TimeUnit.MILLISECONDS,
new ResizableBlockingQueue<>(ConcurrentCollections.<Runnable>newBlockingQueue(), queueCapacity),
TimedRunnable::new,
runnableWrapper,
threadFactory,
new OpenSearchAbortPolicy(),
contextHolder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@

import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_WARNING_HEADER_COUNT;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_WARNING_HEADER_SIZE;
import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID;

/**
* A ThreadContext is a map of string headers and a transient map of keyed objects that are associated with
Expand Down Expand Up @@ -135,16 +136,23 @@ public StoredContext stashContext() {
* This is needed so the DeprecationLogger in another thread can see the value of X-Opaque-ID provided by a user.
* Otherwise when context is stash, it should be empty.
*/

ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT;

if (context.requestHeaders.containsKey(Task.X_OPAQUE_ID)) {
ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT.putHeaders(
threadContextStruct = threadContextStruct.putHeaders(
MapBuilder.<String, String>newMapBuilder()
.put(Task.X_OPAQUE_ID, context.requestHeaders.get(Task.X_OPAQUE_ID))
.immutableMap()
);
threadLocal.set(threadContextStruct);
} else {
threadLocal.set(DEFAULT_CONTEXT);
}

if (context.transientHeaders.containsKey(TASK_ID)) {
threadContextStruct = threadContextStruct.putTransient(TASK_ID, context.transientHeaders.get(TASK_ID));
}

threadLocal.set(threadContextStruct);

return () -> {
// If the node and thus the threadLocal get closed while this task
// is still executing, we don't want this runnable to fail with an
Expand Down
14 changes: 12 additions & 2 deletions server/src/main/java/org/opensearch/node/Node.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@
import org.opensearch.common.util.FeatureFlags;
import org.opensearch.cluster.routing.allocation.AwarenessReplicaBalance;
import org.opensearch.index.IndexingPressureService;
import org.opensearch.index.store.RemoteDirectoryFactory;
import org.opensearch.indices.replication.SegmentReplicationSourceFactory;
import org.opensearch.indices.replication.SegmentReplicationTargetService;
import org.opensearch.indices.replication.SegmentReplicationSourceService;
import org.opensearch.index.store.RemoteDirectoryFactory;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.threadpool.RunnableTaskExecutionListener;
import org.opensearch.watcher.ResourceWatcherService;
import org.opensearch.Assertions;
import org.opensearch.Build;
Expand Down Expand Up @@ -219,6 +221,7 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -338,6 +341,7 @@ public static class DiscoverySettings {
private final LocalNodeFactory localNodeFactory;
private final NodeService nodeService;
final NamedWriteableRegistry namedWriteableRegistry;
private final AtomicReference<RunnableTaskExecutionListener> runnableTaskListener;

public Node(Environment environment) {
this(environment, Collections.emptyList(), true);
Expand Down Expand Up @@ -447,7 +451,8 @@ protected Node(

final List<ExecutorBuilder<?>> executorBuilders = pluginsService.getExecutorBuilders(settings);

final ThreadPool threadPool = new ThreadPool(settings, executorBuilders.toArray(new ExecutorBuilder[0]));
runnableTaskListener = new AtomicReference<>();
final ThreadPool threadPool = new ThreadPool(settings, runnableTaskListener, executorBuilders.toArray(new ExecutorBuilder[0]));
resourcesToClose.add(() -> ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS));
final ResourceWatcherService resourceWatcherService = new ResourceWatcherService(settings, threadPool);
resourcesToClose.add(resourceWatcherService);
Expand Down Expand Up @@ -1095,6 +1100,11 @@ public Node start() throws NodeValidationException {
TransportService transportService = injector.getInstance(TransportService.class);
transportService.getTaskManager().setTaskResultsService(injector.getInstance(TaskResultsService.class));
transportService.getTaskManager().setTaskCancellationService(new TaskCancellationService(transportService));

TaskResourceTrackingService taskResourceTrackingService = injector.getInstance(TaskResourceTrackingService.class);
transportService.getTaskManager().setTaskResourceTrackingService(taskResourceTrackingService);
runnableTaskListener.set(taskResourceTrackingService);

transportService.start();
assert localNodeFactory.getNode() != null;
assert transportService.getLocalNode().equals(localNodeFactory.getNode())
Expand Down
Loading

0 comments on commit 5eac54d

Please sign in to comment.