diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverRunner.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverRunner.java index c687ce7f864f1..788fc2887ebd9 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverRunner.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverRunner.java @@ -9,17 +9,29 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.util.concurrent.CountDown; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.Releasables; import org.elasticsearch.tasks.TaskCancelledException; +import java.util.HashMap; +import java.util.LinkedHashSet; import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.concurrent.atomic.AtomicReference; /** * Run a set of drivers to completion. */ public abstract class DriverRunner { + private final ThreadContext threadContext; + + public DriverRunner(ThreadContext threadContext) { + this.threadContext = threadContext; + } + /** * Start a driver. */ @@ -30,8 +42,11 @@ public abstract class DriverRunner { */ public void runToCompletion(List drivers, ActionListener listener) { AtomicReference failure = new AtomicReference<>(); + AtomicArray>> responseHeaders = new AtomicArray<>(drivers.size()); CountDown counter = new CountDown(drivers.size()); - for (Driver driver : drivers) { + for (int i = 0; i < drivers.size(); i++) { + Driver driver = drivers.get(i); + int driverIndex = i; ActionListener driverListener = new ActionListener<>() { @Override public void onResponse(Void unused) { @@ -66,7 +81,9 @@ public void onFailure(Exception e) { } private void done() { + responseHeaders.setOnce(driverIndex, threadContext.getResponseHeaders()); if (counter.countDown()) { + mergeResponseHeaders(responseHeaders); for (Driver d : drivers) { if (d.status().status() == DriverStatus.Status.QUEUED) { d.close(); @@ -87,4 +104,23 @@ private void done() { start(driver, driverListener); } } + + private void mergeResponseHeaders(AtomicArray>> responseHeaders) { + final Map> merged = new HashMap<>(); + for (int i = 0; i < responseHeaders.length(); i++) { + final Map> resp = responseHeaders.get(i); + if (resp == null || resp.isEmpty()) { + continue; + } + for (Map.Entry> e : resp.entrySet()) { + // Use LinkedHashSet to retain the order of the values + merged.computeIfAbsent(e.getKey(), k -> new LinkedHashSet<>(e.getValue().size())).addAll(e.getValue()); + } + } + for (Map.Entry> e : merged.entrySet()) { + for (String v : e.getValue()) { + threadContext.addResponseHeader(e.getKey(), v); + } + } + } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverTaskRunner.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverTaskRunner.java index 53d5a66de7b66..221be19cc2871 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverTaskRunner.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverTaskRunner.java @@ -42,7 +42,7 @@ public DriverTaskRunner(TransportService transportService, Executor executor) { } public void executeDrivers(Task parentTask, List drivers, Executor executor, ActionListener listener) { - var runner = new DriverRunner() { + var runner = new DriverRunner(transportService.getThreadPool().getThreadContext()) { @Override protected void start(Driver driver, ActionListener driverListener) { transportService.sendChildRequest( diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ForkingOperatorTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ForkingOperatorTestCase.java index 1c12fbf4bcd52..9d1084fcc4cf3 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ForkingOperatorTestCase.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ForkingOperatorTestCase.java @@ -160,7 +160,7 @@ public final void testManyInitialManyPartialFinalRunner() { List results = new ArrayList<>(); List drivers = createDriversForInput(bigArrays, input, results, false /* no throwing ops */); - var runner = new DriverRunner() { + var runner = new DriverRunner(threadPool.getThreadContext()) { @Override protected void start(Driver driver, ActionListener listener) { Driver.start(threadPool.executor(ESQL_TEST_EXECUTOR), driver, between(1, 10000), listener); @@ -182,7 +182,7 @@ public final void testManyInitialManyPartialFinalRunnerThrowing() { List results = new ArrayList<>(); List drivers = createDriversForInput(bigArrays, input, results, true /* one throwing op */); - var runner = new DriverRunner() { + var runner = new DriverRunner(threadPool.getThreadContext()) { @Override protected void start(Driver driver, ActionListener listener) { Driver.start(threadPool.executor(ESQL_TEST_EXECUTOR), driver, between(1, 1000), listener); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/OperatorTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/OperatorTestCase.java index 3cbab148e3073..3b2fac5271aa6 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/OperatorTestCase.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/OperatorTestCase.java @@ -194,7 +194,7 @@ public static void runDriver(List drivers) { getTestClass().getSimpleName(), new FixedExecutorBuilder(Settings.EMPTY, "esql", numThreads, 1024, "esql", EsExecutors.TaskTrackingConfig.DEFAULT) ); - var driverRunner = new DriverRunner() { + var driverRunner = new DriverRunner(threadPool.getThreadContext()) { @Override protected void start(Driver driver, ActionListener driverListener) { Driver.start(threadPool.executor("esql"), driver, between(1, 10000), driverListener); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java index 5b6b33ea0b80a..c94320a9d406a 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java @@ -301,7 +301,7 @@ void runConcurrentTest( drivers.add(d); } PlainActionFuture future = new PlainActionFuture<>(); - new DriverRunner() { + new DriverRunner(threadPool.getThreadContext()) { @Override protected void start(Driver driver, ActionListener listener) { Driver.start(threadPool.executor(ESQL_TEST_EXECUTOR), driver, between(1, 10000), listener); diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/lookup/EnrichLookupIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/lookup/EnrichLookupIT.java index d6611881f8546..fa5a1617e9d61 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/lookup/EnrichLookupIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/lookup/EnrichLookupIT.java @@ -141,7 +141,7 @@ public void testSimple() { DateFormatter dateFmt = DateFormatter.forPattern("yyyy-MM-dd"); - var runner = new DriverRunner() { + var runner = new DriverRunner(transportService.getThreadPool().getThreadContext()) { final Executor executor = transportService.getThreadPool().executor(EsqlPlugin.ESQL_THREAD_POOL_NAME); @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index 0f1d9257a6c0a..8a5b021addae5 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -378,7 +378,7 @@ private ActualResults executePlan() throws Exception { Randomness.shuffle(drivers); } // Execute the driver - DriverRunner runner = new DriverRunner() { + DriverRunner runner = new DriverRunner(threadPool.getThreadContext()) { @Override protected void start(Driver driver, ActionListener driverListener) { Driver.start(threadPool.executor(ESQL_THREAD_POOL_NAME), driver, between(1, 1000), driverListener);