From 0cb27e0e6ec28158142703f0097704fe051ddce8 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 12 Oct 2023 16:18:25 -0400 Subject: [PATCH] Using a custom method for sending requests in the queue --- .../sender/HttpRequestExecutorService.java | 39 +++++---- .../http/sender/HttpRequestSenderFactory.java | 7 +- .../external/http/sender/HttpTask.java | 9 +- .../external/http/sender/RequestTask.java | 9 +- .../external/http/sender/ShutdownTask.java | 2 +- .../HttpRequestExecutorServiceTests.java | 87 ++++++++++--------- .../sender/HttpRequestSenderFactoryTests.java | 4 +- .../http/sender/RequestTaskTests.java | 15 +--- 8 files changed, 85 insertions(+), 87 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestExecutorService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestExecutorService.java index 572d550177dc3..33242cb9f4298 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestExecutorService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestExecutorService.java @@ -7,14 +7,18 @@ package org.elasticsearch.xpack.inference.external.http.sender; +import org.apache.http.client.methods.HttpRequestBase; import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.xpack.inference.external.http.HttpClient; +import org.elasticsearch.xpack.inference.external.http.HttpResult; import java.util.ArrayList; import java.util.List; @@ -40,7 +44,7 @@ * attempting to execute a task (aka waiting for the connection manager to lease a connection). See * {@link org.apache.http.client.config.RequestConfig.Builder#setConnectionRequestTimeout} for more info. */ -public class HttpRequestExecutorService extends AbstractExecutorService { +class HttpRequestExecutorService extends AbstractExecutorService { private static final Logger logger = LogManager.getLogger(HttpRequestExecutorService.class); private final ThreadContext contextHolder; @@ -49,16 +53,18 @@ public class HttpRequestExecutorService extends AbstractExecutorService { private final AtomicBoolean running = new AtomicBoolean(true); private final CountDownLatch terminationLatch = new CountDownLatch(1); private final HttpClientContext httpContext; + private final HttpClient httpClient; @SuppressForbidden(reason = "properly rethrowing errors, see EsExecutors.rethrowErrors") - public HttpRequestExecutorService(ThreadContext contextHolder, String serviceName) { - this(contextHolder, serviceName, null); + HttpRequestExecutorService(ThreadContext contextHolder, String serviceName, HttpClient httpClient) { + this(contextHolder, serviceName, httpClient, null); } @SuppressForbidden(reason = "properly rethrowing errors, see EsExecutors.rethrowErrors") - public HttpRequestExecutorService(ThreadContext contextHolder, String serviceName, @Nullable Integer capacity) { + HttpRequestExecutorService(ThreadContext contextHolder, String serviceName, HttpClient httpClient, @Nullable Integer capacity) { this.contextHolder = Objects.requireNonNull(contextHolder); this.serviceName = Objects.requireNonNull(serviceName); + this.httpClient = Objects.requireNonNull(httpClient); this.httpContext = HttpClientContext.create(); if (capacity == null) { @@ -132,18 +138,12 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE } /** - * Execute the task at some point in the future. - * @param command the runnable task, must be a class that extends {@link HttpTask} + * Send the request at some point in the future. + * @param request the http request to send + * @param listener an {@link ActionListener} for the response or failure */ - @Override - public void execute(Runnable command) { - if (command == null) { - return; - } - - assert command instanceof HttpTask; - HttpTask task = (HttpTask) command; - task.setContext(httpContext); + public void send(HttpRequestBase request, ActionListener listener) { + RequestTask task = new RequestTask(request, httpClient, httpContext, listener); if (isShutdown()) { EsRejectedExecutionException rejected = new EsRejectedExecutionException( @@ -165,4 +165,13 @@ public void execute(Runnable command) { task.onRejection(rejected); } } + + /** + * This method is not supported. Use {@link #send} instead. + * @param runnable the runnable task + */ + @Override + public void execute(Runnable runnable) { + throw new UnsupportedOperationException("use send instead"); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderFactory.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderFactory.java index 535d1034bc5bc..062f684acfec8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderFactory.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderFactory.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.external.http.sender; +import org.apache.http.client.methods.HttpRequestBase; import org.apache.http.client.methods.HttpUriRequest; import org.elasticsearch.action.ActionListener; import org.elasticsearch.threadpool.ThreadPool; @@ -49,7 +50,7 @@ public static final class HttpRequestSender implements Closeable { private HttpRequestSender(String serviceName, ThreadPool threadPool, HttpClientManager httpClientManager) { this.threadPool = Objects.requireNonNull(threadPool); this.manager = Objects.requireNonNull(httpClientManager); - service = new HttpRequestExecutorService(threadPool.getThreadContext(), serviceName); + service = new HttpRequestExecutorService(threadPool.getThreadContext(), serviceName, manager.getHttpClient()); } /** @@ -68,9 +69,9 @@ public void close() throws IOException { service.shutdown(); } - public void send(HttpUriRequest request, ActionListener listener) { + public void send(HttpRequestBase request, ActionListener listener) { assert started.get() : "call start() before sending a request"; - service.execute(new RequestTask(request, manager.getHttpClient(), listener)); + service.send(request, listener); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpTask.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpTask.java index 67c573c1bb79d..6881d75524bda 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpTask.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpTask.java @@ -7,17 +7,10 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.elasticsearch.common.util.concurrent.AbstractRunnable; -public abstract class HttpTask extends AbstractRunnable { - protected HttpClientContext context; - +abstract class HttpTask extends AbstractRunnable { public boolean shouldShutdown() { return false; } - - public void setContext(HttpClientContext context) { - this.context = context; - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java index 4d3665c217a56..2a1fe653cb8cd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.external.http.sender; import org.apache.http.client.methods.HttpUriRequest; +import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchException; @@ -20,17 +21,19 @@ import static org.elasticsearch.core.Strings.format; -public class RequestTask extends HttpTask { +class RequestTask extends HttpTask { private static final Logger logger = LogManager.getLogger(RequestTask.class); private final HttpUriRequest request; private final ActionListener listener; private final HttpClient httpClient; + private final HttpClientContext context; - public RequestTask(HttpUriRequest request, HttpClient httpClient, ActionListener listener) { + RequestTask(HttpUriRequest request, HttpClient httpClient, HttpClientContext context, ActionListener listener) { this.request = Objects.requireNonNull(request); this.httpClient = Objects.requireNonNull(httpClient); this.listener = Objects.requireNonNull(listener); + this.context = Objects.requireNonNull(context); } @Override @@ -40,8 +43,6 @@ public void onFailure(Exception e) { @Override protected void doRun() throws Exception { - assert context != null : "the http context must be set before calling doRun"; - try { httpClient.send(request, context, listener); } catch (IOException e) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ShutdownTask.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ShutdownTask.java index 40956d96a47e6..9ec2edf514e80 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ShutdownTask.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ShutdownTask.java @@ -7,7 +7,7 @@ package org.elasticsearch.xpack.inference.external.http.sender; -public class ShutdownTask extends HttpTask { +class ShutdownTask extends HttpTask { @Override public boolean shouldShutdown() { return true; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestExecutorServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestExecutorServiceTests.java index 4f77e62601eea..f87bca12a7f39 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestExecutorServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestExecutorServiceTests.java @@ -7,10 +7,11 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.methods.HttpUriRequest; +import org.apache.http.client.methods.HttpRequestBase; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; import org.elasticsearch.core.TimeValue; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; @@ -18,7 +19,6 @@ import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.junit.After; import org.junit.Before; -import org.mockito.ArgumentCaptor; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Future; @@ -29,8 +29,6 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; public class HttpRequestExecutorServiceTests extends ESTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); @@ -47,28 +45,34 @@ public void shutdown() { } public void testQueueSize_IsEmpty() { - var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName()); + var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName(), mock(HttpClient.class)); assertThat(service.queueSize(), is(0)); } public void testQueueSize_IsOne() { - var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName()); - var noopTask = mock(RequestTask.class); - - service.execute(noopTask); + var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName(), mock(HttpClient.class)); + service.send(mock(HttpRequestBase.class), new PlainActionFuture<>()); assertThat(service.queueSize(), is(1)); } + public void testExecute_ThrowsUnsupported() { + var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName(), mock(HttpClient.class)); + var noopTask = mock(RequestTask.class); + + var thrownException = expectThrows(UnsupportedOperationException.class, () -> service.execute(noopTask)); + assertThat(thrownException.getMessage(), is("use send instead")); + } + public void testIsTerminated_IsFalse() { - var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName()); + var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName(), mock(HttpClient.class)); assertFalse(service.isTerminated()); } public void testIsTerminated_IsTrue() { - var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName()); + var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName(), mock(HttpClient.class)); service.shutdown(); service.start(); @@ -77,9 +81,16 @@ public void testIsTerminated_IsTrue() { } public void testIsTerminated_AfterStopFromSeparateThread() throws Exception { - var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName()); - var waitToShutdown = new CountDownLatch(1); + + var mockHttpClient = mock(HttpClient.class); + doAnswer(invocation -> { + waitToShutdown.countDown(); + return Void.TYPE; + }).when(mockHttpClient).send(any(), any(), any()); + + var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName(), mockHttpClient); + Future executorTermination = threadPool.generic().submit(() -> { try { // wait for a task to be added to be executed before beginning shutdown @@ -91,13 +102,8 @@ public void testIsTerminated_AfterStopFromSeparateThread() throws Exception { } }); - var blockingTask = mock(RequestTask.class); - doAnswer(invocation -> { - waitToShutdown.countDown(); - return Void.TYPE; - }).when(blockingTask).doRun(); - - service.execute(blockingTask); + PlainActionFuture listener = new PlainActionFuture<>(); + service.send(mock(HttpRequestBase.class), listener); service.start(); @@ -112,50 +118,49 @@ public void testIsTerminated_AfterStopFromSeparateThread() throws Exception { } public void testExecute_AfterShutdown_Throws() { - var service = new HttpRequestExecutorService(threadPool.getThreadContext(), "test_service"); + var service = new HttpRequestExecutorService(threadPool.getThreadContext(), "test_service", mock(HttpClient.class)); service.shutdown(); - var task = mock(RequestTask.class); - service.execute(task); + var listener = new PlainActionFuture(); + service.send(mock(HttpRequestBase.class), listener); - ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); - verify(task, times(1)).onRejection(exceptionCaptor.capture()); + var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); assertThat( - exceptionCaptor.getValue().getMessage(), + thrownException.getMessage(), is("Failed to execute task because the http executor service [test_service] has shutdown") ); } public void testExecute_Throws_WhenQueueIsFull() { - var service = new HttpRequestExecutorService(threadPool.getThreadContext(), "test_service", 1); + var service = new HttpRequestExecutorService(threadPool.getThreadContext(), "test_service", mock(HttpClient.class), 1); - var task = mock(RequestTask.class); - var rejectedTask = mock(RequestTask.class); - service.execute(task); - service.execute(rejectedTask); + service.send(mock(HttpRequestBase.class), new PlainActionFuture<>()); + var listener = new PlainActionFuture(); + service.send(mock(HttpRequestBase.class), listener); - ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); - verify(rejectedTask, times(1)).onRejection(exceptionCaptor.capture()); + var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); assertThat( - exceptionCaptor.getValue().getMessage(), + thrownException.getMessage(), is("Failed to execute task because the http executor service [test_service] queue is full") ); } public void testTaskThrowsError_CallsOnFailure() throws Exception { - var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName()); - var httpClient = mock(HttpClient.class); - doAnswer(invocation -> { throw new ElasticsearchException("failed"); }).when(httpClient).send(any(), any(), any()); + + var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName(), httpClient); + + doAnswer(invocation -> { + service.shutdown(); + throw new ElasticsearchException("failed"); + }).when(httpClient).send(any(), any(), any()); PlainActionFuture listener = new PlainActionFuture<>(); - var failingTask = new RequestTask(mock(HttpUriRequest.class), httpClient, listener); - service.execute(failingTask); - service.execute(new ShutdownTask()); + service.send(mock(HttpRequestBase.class), listener); service.start(); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -164,7 +169,7 @@ public void testTaskThrowsError_CallsOnFailure() throws Exception { } public void testShutdown_AllowsMultipleCalls() { - var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName()); + var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName(), mock(HttpClient.class)); service.shutdown(); service.shutdown(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderFactoryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderFactoryTests.java index b0df1a934d399..47ecde6ae535c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderFactoryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderFactoryTests.java @@ -8,7 +8,7 @@ package org.elasticsearch.xpack.inference.external.http.sender; import org.apache.http.HttpHeaders; -import org.apache.http.client.methods.HttpUriRequest; +import org.apache.http.client.methods.HttpRequestBase; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; @@ -112,7 +112,7 @@ public void testHttpRequestSender_Throws_WhenCallingSendBeforeStart() throws Exc try (var sender = senderFactory.createSender("test_service")) { PlainActionFuture listener = new PlainActionFuture<>(); - var thrownException = expectThrows(AssertionError.class, () -> sender.send(mock(HttpUriRequest.class), listener)); + var thrownException = expectThrows(AssertionError.class, () -> sender.send(mock(HttpRequestBase.class), listener)); assertThat(thrownException.getMessage(), is("call start() before sending a request")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java index 447df0f7236e7..37f4fb8cce4cb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; import org.apache.http.HttpHeaders; -import org.apache.http.client.methods.HttpUriRequest; import org.apache.http.client.protocol.HttpClientContext; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.support.PlainActionFuture; @@ -69,8 +68,7 @@ public void testDoRun_SendsRequestAndReceivesResponse() throws Exception { httpClient.start(); PlainActionFuture listener = new PlainActionFuture<>(); - var requestTask = new RequestTask(httpPost, httpClient, listener); - requestTask.setContext(HttpClientContext.create()); + var requestTask = new RequestTask(httpPost, httpClient, HttpClientContext.create(), listener); requestTask.doRun(); var result = listener.actionGet(TIMEOUT); @@ -83,14 +81,6 @@ public void testDoRun_SendsRequestAndReceivesResponse() throws Exception { } } - public void testDoRun_Throws_WhenContextIsNotSet() { - PlainActionFuture listener = new PlainActionFuture<>(); - - var requestTask = new RequestTask(mock(HttpUriRequest.class), mock(HttpClient.class), listener); - var thrownException = expectThrows(AssertionError.class, requestTask::doRun); - assertThat(thrownException.getMessage(), is("the http context must be set before calling doRun")); - } - public void testDoRun_SendThrowsIOException() throws Exception { var httpClient = mock(HttpClient.class); doThrow(new IOException("exception")).when(httpClient).send(any(), any(), any()); @@ -100,8 +90,7 @@ public void testDoRun_SendThrowsIOException() throws Exception { var httpPost = createHttpPost(webServer.getPort(), paramKey, paramValue); PlainActionFuture listener = new PlainActionFuture<>(); - var requestTask = new RequestTask(httpPost, httpClient, listener); - requestTask.setContext(HttpClientContext.create()); + var requestTask = new RequestTask(httpPost, httpClient, HttpClientContext.create(), listener); requestTask.doRun(); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));