Skip to content

Commit

Permalink
add additional tests for searchBackpressureService and refactor code
Browse files Browse the repository at this point in the history
Signed-off-by: Kaushal Kumar <ravi.kaushal97@gmail.com>
  • Loading branch information
kaushalmahi12 committed Apr 30, 2024
1 parent f4e1d6e commit aa4fd2b
Show file tree
Hide file tree
Showing 18 changed files with 459 additions and 212 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
import java.util.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.DoubleSupplier;
import java.util.function.LongSupplier;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -166,44 +170,24 @@ void doRun() {

List<CancellableTask> searchTasks = getTaskByType(SearchTask.class);
List<CancellableTask> searchShardTasks = getTaskByType(SearchShardTask.class);
List<CancellableTask> cancellableTasks = new ArrayList<>();

// Force-refresh usage stats of these tasks before making a cancellation decision.
taskResourceTrackingService.refreshResourceStats(searchTasks.toArray(new Task[0]));
taskResourceTrackingService.refreshResourceStats(searchShardTasks.toArray(new Task[0]));

List<TaskCancellation> taskCancellations = new ArrayList<>();

addHeapBasedTaskCancellations(taskCancellations, searchTasks, searchShardTasks);
taskCancellations = addHeapBasedTaskCancellations(taskCancellations, searchTasks, searchShardTasks);

addCPUBasedTaskCancellations(taskCancellations, searchTasks, searchShardTasks);
taskCancellations = addCPUBasedTaskCancellations(taskCancellations, searchTasks, searchShardTasks);

addElapsedTimeBasedTaskCancellations(taskCancellations, searchTasks, searchShardTasks);
taskCancellations = addElapsedTimeBasedTaskCancellations(taskCancellations, searchTasks, searchShardTasks);

// Since these cancellations might be duplicate due to multiple trackers causing cancellation for same task
// We need to merge them
taskCancellations = mergeTaskCancellations(taskCancellations);

// // Check if increase in heap usage is due to SearchTasks
// if (HeapUsageTracker.isHeapUsageDominatedBySearch(
// searchTasks,
// getSettings().getSearchTaskSettings().getTotalHeapPercentThreshold()
// )) {
// cancellableTasks.addAll(searchTasks);
// }
//
// // Check if increase in heap usage is due to SearchShardTasks
// if (HeapUsageTracker.isHeapUsageDominatedBySearch(
// searchShardTasks,
// getSettings().getSearchShardTaskSettings().getTotalHeapPercentThreshold()
// )) {
// cancellableTasks.addAll(searchShardTasks);
// }

// none of the task type is breaching the heap usage thresholds and hence we do not cancel any tasks
// if (taskCancellations.isEmpty()) {
// return;
// }
taskCancellations = mergeTaskCancellations(taskCancellations).stream()
.filter(TaskCancellation::isEligibleForCancellation)
.collect(Collectors.toList());

for (TaskCancellation taskCancellation : taskCancellations) {
logger.warn(
Expand Down Expand Up @@ -235,51 +219,86 @@ void doRun() {
}
}

private void addElapsedTimeBasedTaskCancellations(List<TaskCancellation> taskCancellations, List<CancellableTask> searchTasks, List<CancellableTask> searchShardTasks) {
final TaskResourceUsageTrackers.TaskResourceUsageTracker searchTaskElapsedTimeTracker = getTaskResourceUsageTrackersByType(SearchTask.class).getElapsedTimeTracker();
final TaskResourceUsageTrackers.TaskResourceUsageTracker searchShardTaskElapsedTimeTracker = getTaskResourceUsageTrackersByType(SearchShardTask.class).getElapsedTimeTracker();

taskCancellations.addAll(
searchTaskElapsedTimeTracker.getTaskCancellations(searchTasks, searchBackpressureStates.get(SearchTask.class)::incrementCancellationCount)
private List<TaskCancellation> addElapsedTimeBasedTaskCancellations(
List<TaskCancellation> taskCancellations,
List<CancellableTask> searchTasks,
List<CancellableTask> searchShardTasks
) {
final Optional<TaskResourceUsageTrackers.TaskResourceUsageTracker> searchTaskElapsedTimeTracker =
getTaskResourceUsageTrackersByType(SearchTask.class).getElapsedTimeTracker();
final Optional<TaskResourceUsageTrackers.TaskResourceUsageTracker> searchShardTaskElapsedTimeTracker =
getTaskResourceUsageTrackersByType(SearchShardTask.class).getElapsedTimeTracker();

addTaskCancellationsFromTaskResourceUsageTracker(taskCancellations, searchTasks, searchTaskElapsedTimeTracker, SearchTask.class);

addTaskCancellationsFromTaskResourceUsageTracker(
taskCancellations,
searchShardTasks,
searchShardTaskElapsedTimeTracker,
SearchShardTask.class
);

taskCancellations.addAll(
searchShardTaskElapsedTimeTracker.getTaskCancellations(searchShardTasks, searchBackpressureStates.get(SearchShardTask.class)::incrementCancellationCount)
);
return taskCancellations;
}

private void addCPUBasedTaskCancellations(List<TaskCancellation> taskCancellations, List<CancellableTask> searchTasks, List<CancellableTask> searchShardTasks) {
private List<TaskCancellation> addCPUBasedTaskCancellations(
List<TaskCancellation> taskCancellations,
List<CancellableTask> searchTasks,
List<CancellableTask> searchShardTasks
) {
if (nodeDuressTrackers.isCPUInDuress()) {
final TaskResourceUsageTrackers.TaskResourceUsageTracker searchTaskCPUUsageTracker = getTaskResourceUsageTrackersByType(SearchTask.class).getCpuUsageTracker();
final TaskResourceUsageTrackers.TaskResourceUsageTracker searchShardTaskCPUUsageTracker = getTaskResourceUsageTrackersByType(SearchShardTask.class).getCpuUsageTracker();

taskCancellations.addAll(
searchTaskCPUUsageTracker
.getTaskCancellations(searchTasks, searchBackpressureStates.get(SearchTask.class)::incrementCancellationCount)
);

taskCancellations.addAll(
searchShardTaskCPUUsageTracker
.getTaskCancellations(searchShardTasks, searchBackpressureStates.get(SearchTask.class)::incrementCancellationCount)
final Optional<TaskResourceUsageTrackers.TaskResourceUsageTracker> searchTaskCPUUsageTracker =
getTaskResourceUsageTrackersByType(SearchTask.class).getCpuUsageTracker();
final Optional<TaskResourceUsageTrackers.TaskResourceUsageTracker> searchShardTaskCPUUsageTracker =
getTaskResourceUsageTrackersByType(SearchShardTask.class).getCpuUsageTracker();

addTaskCancellationsFromTaskResourceUsageTracker(taskCancellations, searchTasks, searchTaskCPUUsageTracker, SearchTask.class);

addTaskCancellationsFromTaskResourceUsageTracker(
taskCancellations,
searchShardTasks,
searchShardTaskCPUUsageTracker,
SearchShardTask.class
);
}
return taskCancellations;
}

private void addHeapBasedTaskCancellations(List<TaskCancellation> taskCancellations, List<CancellableTask> searchTasks, List<CancellableTask> searchShardTasks) {
private List<TaskCancellation> addHeapBasedTaskCancellations(
List<TaskCancellation> taskCancellations,
List<CancellableTask> searchTasks,
List<CancellableTask> searchShardTasks
) {
if (isHeapTrackingSupported() && nodeDuressTrackers.isHeapInDuress()) {
final TaskResourceUsageTrackers.TaskResourceUsageTracker searchTaskHeapUsageTracker = getTaskResourceUsageTrackersByType(SearchTask.class).getHeapUsageTracker();
final TaskResourceUsageTrackers.TaskResourceUsageTracker searchShardTaskHeapUsageTracker = getTaskResourceUsageTrackersByType(SearchShardTask.class).getHeapUsageTracker();

taskCancellations = searchTaskHeapUsageTracker
.getTaskCancellations(searchTasks, searchBackpressureStates.get(SearchTask.class)::incrementCancellationCount);

taskCancellations.addAll(
searchShardTaskHeapUsageTracker
.getTaskCancellations(searchShardTasks, searchBackpressureStates.get(SearchShardTask.class)::incrementCompletionCount)
final Optional<TaskResourceUsageTrackers.TaskResourceUsageTracker> searchTaskHeapUsageTracker =
getTaskResourceUsageTrackersByType(SearchTask.class).getHeapUsageTracker();
final Optional<TaskResourceUsageTrackers.TaskResourceUsageTracker> searchShardTaskHeapUsageTracker =
getTaskResourceUsageTrackersByType(SearchShardTask.class).getHeapUsageTracker();

addTaskCancellationsFromTaskResourceUsageTracker(taskCancellations, searchTasks, searchTaskHeapUsageTracker, SearchTask.class);

addTaskCancellationsFromTaskResourceUsageTracker(
taskCancellations,
searchShardTasks,
searchShardTaskHeapUsageTracker,
SearchShardTask.class
);
}
return taskCancellations;
}

private void addTaskCancellationsFromTaskResourceUsageTracker(
List<TaskCancellation> taskCancellations,
List<CancellableTask> tasks,
Optional<TaskResourceUsageTrackers.TaskResourceUsageTracker> taskResourceUsageTracker,
Class<?> type
) {
taskResourceUsageTracker.ifPresent(
tracker -> taskCancellations.addAll(
tracker.getTaskCancellations(tasks, searchBackpressureStates.get(type)::incrementCancellationCount)
)
);
}

/**
* returns the taskTrackers for given type
Expand All @@ -290,7 +309,6 @@ private TaskResourceUsageTrackers getTaskResourceUsageTrackersByType(Class<? ext
return taskTrackers.get(type);
}


/**
* Method to reduce the taskCancellations into unique bunch
* @param taskCancellations
Expand All @@ -299,14 +317,14 @@ private TaskResourceUsageTrackers getTaskResourceUsageTrackersByType(Class<? ext
private List<TaskCancellation> mergeTaskCancellations(final List<TaskCancellation> taskCancellations) {
final Map<Long, TaskCancellation> uniqueTaskCancellations = new HashMap<>();

for (TaskCancellation taskCancellation: taskCancellations) {
for (TaskCancellation taskCancellation : taskCancellations) {
final long taskId = taskCancellation.getTask().getId();
uniqueTaskCancellations.put(taskId,
uniqueTaskCancellations.getOrDefault(taskId, taskCancellation).merge(taskCancellation));
uniqueTaskCancellations.put(taskId, uniqueTaskCancellations.getOrDefault(taskId, taskCancellation).merge(taskCancellation));
}

return new ArrayList<>(uniqueTaskCancellations.values());
}

/**
* Given a task, returns the type of the task
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import org.opensearch.search.backpressure.trackers.CpuUsageTracker;
import org.opensearch.search.backpressure.trackers.ElapsedTimeTracker;
import org.opensearch.search.backpressure.trackers.HeapUsageTracker;
import org.opensearch.search.backpressure.trackers.TaskResourceUsageTrackers;
import org.opensearch.search.backpressure.trackers.TaskResourceUsageTrackerType;
import org.opensearch.search.backpressure.trackers.TaskResourceUsageTrackers;

import java.io.IOException;
import java.util.Map;
Expand Down Expand Up @@ -67,7 +67,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.startObject();

builder.startObject("resource_tracker_stats");
for (Map.Entry<TaskResourceUsageTrackerType, TaskResourceUsageTrackers.TaskResourceUsageTracker.Stats> entry : resourceUsageTrackerStats.entrySet()) {
for (Map.Entry<
TaskResourceUsageTrackerType,
TaskResourceUsageTrackers.TaskResourceUsageTracker.Stats> entry : resourceUsageTrackerStats.entrySet()) {
builder.field(entry.getKey().getName(), entry.getValue());
}
builder.endObject();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.startObject();

builder.startObject("resource_tracker_stats");
for (Map.Entry<TaskResourceUsageTrackerType, TaskResourceUsageTrackers.TaskResourceUsageTracker.Stats> entry : resourceUsageTrackerStats.entrySet()) {
for (Map.Entry<
TaskResourceUsageTrackerType,
TaskResourceUsageTrackers.TaskResourceUsageTracker.Stats> entry : resourceUsageTrackerStats.entrySet()) {
builder.field(entry.getKey().getName(), entry.getValue());
}
builder.endObject();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,35 +34,37 @@ public class CpuUsageTracker extends TaskResourceUsageTrackers.TaskResourceUsage
private final LongSupplier thresholdSupplier;

public CpuUsageTracker(LongSupplier thresholdSupplier) {
this(thresholdSupplier, (task) -> {
long usage = task.getTotalResourceStats().getCpuTimeInNanos();
long threshold = thresholdSupplier.getAsLong();

if (usage < threshold) {
return Optional.empty();
}

return Optional.of(
new TaskCancellation.Reason(
"cpu usage exceeded ["
+ new TimeValue(usage, TimeUnit.NANOSECONDS)
+ " >= "
+ new TimeValue(threshold, TimeUnit.NANOSECONDS)
+ "]",
1 // TODO: fine-tune the cancellation score/weight
)
);
});
}

public CpuUsageTracker(LongSupplier thresholdSupplier, ResourceUsageBreachEvaluator resourceUsageBreachEvaluator) {
this.thresholdSupplier = thresholdSupplier;
this.resourceUsageBreachEvaluator = resourceUsageBreachEvaluator;
}

@Override
public String name() {
return CPU_USAGE_TRACKER.getName();
}

@Override
public Optional<TaskCancellation.Reason> checkAndMaybeGetCancellationReason(Task task) {
long usage = task.getTotalResourceStats().getCpuTimeInNanos();
long threshold = thresholdSupplier.getAsLong();

if (usage < threshold) {
return Optional.empty();
}

return Optional.of(
new TaskCancellation.Reason(
"cpu usage exceeded ["
+ new TimeValue(usage, TimeUnit.NANOSECONDS)
+ " >= "
+ new TimeValue(threshold, TimeUnit.NANOSECONDS)
+ "]",
1 // TODO: fine-tune the cancellation score/weight
)
);
}

@Override
public TaskResourceUsageTrackers.TaskResourceUsageTracker.Stats stats(List<? extends Task> activeTasks) {
long currentMax = activeTasks.stream().mapToLong(t -> t.getTotalResourceStats().getCpuTimeInNanos()).max().orElse(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,36 +34,42 @@ public class ElapsedTimeTracker extends TaskResourceUsageTrackers.TaskResourceUs
private final LongSupplier timeNanosSupplier;

public ElapsedTimeTracker(LongSupplier thresholdSupplier, LongSupplier timeNanosSupplier) {
this(thresholdSupplier, timeNanosSupplier, (Task task) -> {
long usage = timeNanosSupplier.getAsLong() - task.getStartTimeNanos();
long threshold = thresholdSupplier.getAsLong();

if (usage < threshold) {
return Optional.empty();
}

return Optional.of(
new TaskCancellation.Reason(
"elapsed time exceeded ["
+ new TimeValue(usage, TimeUnit.NANOSECONDS)
+ " >= "
+ new TimeValue(threshold, TimeUnit.NANOSECONDS)
+ "]",
1 // TODO: fine-tune the cancellation score/weight
)
);
});
}

public ElapsedTimeTracker(
LongSupplier thresholdSupplier,
LongSupplier timeNanosSupplier,
ResourceUsageBreachEvaluator resourceUsageBreachEvaluator
) {
this.thresholdSupplier = thresholdSupplier;
this.timeNanosSupplier = timeNanosSupplier;
this.resourceUsageBreachEvaluator = resourceUsageBreachEvaluator;
}

@Override
public String name() {
return ELAPSED_TIME_TRACKER.getName();
}

@Override
public Optional<TaskCancellation.Reason> checkAndMaybeGetCancellationReason(Task task) {
long usage = timeNanosSupplier.getAsLong() - task.getStartTimeNanos();
long threshold = thresholdSupplier.getAsLong();

if (usage < threshold) {
return Optional.empty();
}

return Optional.of(
new TaskCancellation.Reason(
"elapsed time exceeded ["
+ new TimeValue(usage, TimeUnit.NANOSECONDS)
+ " >= "
+ new TimeValue(threshold, TimeUnit.NANOSECONDS)
+ "]",
1 // TODO: fine-tune the cancellation score/weight
)
);
}

@Override
public TaskResourceUsageTrackers.TaskResourceUsageTracker.Stats stats(List<? extends Task> activeTasks) {
long now = timeNanosSupplier.getAsLong();
Expand Down
Loading

0 comments on commit aa4fd2b

Please sign in to comment.