Skip to content

Commit

Permalink
Collect and merge response headers in ESQL (elastic#99926)
Browse files Browse the repository at this point in the history
It seems that our infrastructure does not merge response headers across 
multiple asynchronous calls. I can reproduce this issue using the
TransportService. Response headers are not merged properly in this scenario:

1. The caller initiates two asynchronous calls, c1 and c2, which can involve
network requests.

2. c1 responded with a warning in the header responses. We merge these response
headers with the original ThreadContext of the calling thread and update the
ThreadContext of the current thread (leaving the calling thread untouched).

3. c2 responded with no warning in the header responses. Since the original
ThreadContext of the calling thread did not get updated after c1, as it's
immutable, we won't be able to merge response headers between c1 and c2.

4. The caller receives a response from the responding thread of c2 without any
warning.

This PR manually collect and merge response headers in DriverRunner. I think we 
should generalize this pattern for Elasticsearch.
  • Loading branch information
dnhatn authored Sep 27, 2023
1 parent 096cf81 commit 6c40a96
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,29 @@

import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.common.util.concurrent.CountDown;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.tasks.TaskCancelledException;

import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;

/**
* Run a set of drivers to completion.
*/
public abstract class DriverRunner {
private final ThreadContext threadContext;

public DriverRunner(ThreadContext threadContext) {
this.threadContext = threadContext;
}

/**
* Start a driver.
*/
Expand All @@ -30,8 +42,11 @@ public abstract class DriverRunner {
*/
public void runToCompletion(List<Driver> drivers, ActionListener<Void> listener) {
AtomicReference<Exception> failure = new AtomicReference<>();
AtomicArray<Map<String, List<String>>> responseHeaders = new AtomicArray<>(drivers.size());
CountDown counter = new CountDown(drivers.size());
for (Driver driver : drivers) {
for (int i = 0; i < drivers.size(); i++) {
Driver driver = drivers.get(i);
int driverIndex = i;
ActionListener<Void> driverListener = new ActionListener<>() {
@Override
public void onResponse(Void unused) {
Expand Down Expand Up @@ -66,7 +81,9 @@ public void onFailure(Exception e) {
}

private void done() {
responseHeaders.setOnce(driverIndex, threadContext.getResponseHeaders());
if (counter.countDown()) {
mergeResponseHeaders(responseHeaders);
for (Driver d : drivers) {
if (d.status().status() == DriverStatus.Status.QUEUED) {
d.close();
Expand All @@ -87,4 +104,23 @@ private void done() {
start(driver, driverListener);
}
}

private void mergeResponseHeaders(AtomicArray<Map<String, List<String>>> responseHeaders) {
final Map<String, Set<String>> merged = new HashMap<>();
for (int i = 0; i < responseHeaders.length(); i++) {
final Map<String, List<String>> resp = responseHeaders.get(i);
if (resp == null || resp.isEmpty()) {
continue;
}
for (Map.Entry<String, List<String>> e : resp.entrySet()) {
// Use LinkedHashSet to retain the order of the values
merged.computeIfAbsent(e.getKey(), k -> new LinkedHashSet<>(e.getValue().size())).addAll(e.getValue());
}
}
for (Map.Entry<String, Set<String>> e : merged.entrySet()) {
for (String v : e.getValue()) {
threadContext.addResponseHeader(e.getKey(), v);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public DriverTaskRunner(TransportService transportService, Executor executor) {
}

public void executeDrivers(Task parentTask, List<Driver> drivers, Executor executor, ActionListener<Void> listener) {
var runner = new DriverRunner() {
var runner = new DriverRunner(transportService.getThreadPool().getThreadContext()) {
@Override
protected void start(Driver driver, ActionListener<Void> driverListener) {
transportService.sendChildRequest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ public final void testManyInitialManyPartialFinalRunner() {
List<Page> results = new ArrayList<>();

List<Driver> drivers = createDriversForInput(bigArrays, input, results, false /* no throwing ops */);
var runner = new DriverRunner() {
var runner = new DriverRunner(threadPool.getThreadContext()) {
@Override
protected void start(Driver driver, ActionListener<Void> listener) {
Driver.start(threadPool.executor(ESQL_TEST_EXECUTOR), driver, between(1, 10000), listener);
Expand All @@ -182,7 +182,7 @@ public final void testManyInitialManyPartialFinalRunnerThrowing() {
List<Page> results = new ArrayList<>();

List<Driver> drivers = createDriversForInput(bigArrays, input, results, true /* one throwing op */);
var runner = new DriverRunner() {
var runner = new DriverRunner(threadPool.getThreadContext()) {
@Override
protected void start(Driver driver, ActionListener<Void> listener) {
Driver.start(threadPool.executor(ESQL_TEST_EXECUTOR), driver, between(1, 1000), listener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ public static void runDriver(List<Driver> drivers) {
getTestClass().getSimpleName(),
new FixedExecutorBuilder(Settings.EMPTY, "esql", numThreads, 1024, "esql", EsExecutors.TaskTrackingConfig.DEFAULT)
);
var driverRunner = new DriverRunner() {
var driverRunner = new DriverRunner(threadPool.getThreadContext()) {
@Override
protected void start(Driver driver, ActionListener<Void> driverListener) {
Driver.start(threadPool.executor("esql"), driver, between(1, 10000), driverListener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ void runConcurrentTest(
drivers.add(d);
}
PlainActionFuture<Void> future = new PlainActionFuture<>();
new DriverRunner() {
new DriverRunner(threadPool.getThreadContext()) {
@Override
protected void start(Driver driver, ActionListener<Void> listener) {
Driver.start(threadPool.executor(ESQL_TEST_EXECUTOR), driver, between(1, 10000), listener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ public void testSimple() {

DateFormatter dateFmt = DateFormatter.forPattern("yyyy-MM-dd");

var runner = new DriverRunner() {
var runner = new DriverRunner(transportService.getThreadPool().getThreadContext()) {
final Executor executor = transportService.getThreadPool().executor(EsqlPlugin.ESQL_THREAD_POOL_NAME);

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ private ActualResults executePlan() throws Exception {
Randomness.shuffle(drivers);
}
// Execute the driver
DriverRunner runner = new DriverRunner() {
DriverRunner runner = new DriverRunner(threadPool.getThreadContext()) {
@Override
protected void start(Driver driver, ActionListener<Void> driverListener) {
Driver.start(threadPool.executor(ESQL_THREAD_POOL_NAME), driver, between(1, 1000), driverListener);
Expand Down

0 comments on commit 6c40a96

Please sign in to comment.