From e26ad8f9d21700b79850fdbf188310224501f0c2 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com> Date: Mon, 16 Oct 2023 14:15:02 -0400 Subject: [PATCH] [ML] Adding request queuing for http requests (#100674) * Tests are really slow * Closing services * Cleaning up code * Fixing spotless * Adding some logging for evictor thread * Using a custom method for sending requests in the queue * Adding timeout and rejection logic * Fixing merge failure * Revert "Adding timeout and rejection logic" This reverts commit acc8ba0c0b2ece48a7416a87cf9724fdb1408fd4. * Removing rethrow * Reverting node.java changes --------- Co-authored-by: Elastic Machine --- .../inference/InferenceService.java | 3 +- .../TestInferenceServicePlugin.java | 3 + .../xpack/inference/InferencePlugin.java | 27 +-- .../inference/external/http/HttpClient.java | 19 +- .../external/http/HttpClientManager.java | 25 ++- .../external/http/IdleConnectionEvictor.java | 26 ++- .../sender/HttpRequestExecutorService.java | 172 +++++++++++++++++ .../http/sender/HttpRequestSenderFactory.java | 77 ++++++++ .../external/http/sender/HttpTask.java | 16 ++ .../external/http/sender/RequestTask.java | 58 ++++++ .../external/http/sender/ShutdownTask.java | 21 ++ .../services/elser/ElserMlNodeService.java | 4 + .../external/http/HttpClientManagerTests.java | 23 +-- .../external/http/HttpClientTests.java | 47 +++-- .../http/IdleConnectionEvictorTests.java | 24 +-- .../HttpRequestExecutorServiceTests.java | 182 ++++++++++++++++++ .../sender/HttpRequestSenderFactoryTests.java | 119 ++++++++++++ .../http/sender/RequestTaskTests.java | 99 ++++++++++ 18 files changed, 837 insertions(+), 108 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestExecutorService.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderFactory.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpTask.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ShutdownTask.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestExecutorServiceTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderFactoryTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index 8d8bc08f5b0fa..82ce13e591b6c 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -10,10 +10,11 @@ import org.elasticsearch.action.ActionListener; +import java.io.Closeable; import java.util.Map; import java.util.Set; -public interface InferenceService { +public interface InferenceService extends Closeable { String name(); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/TestInferenceServicePlugin.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/TestInferenceServicePlugin.java index 96625e6bec031..102436b37524c 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/TestInferenceServicePlugin.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/TestInferenceServicePlugin.java @@ -161,6 +161,9 @@ public void infer(Model model, String input, Map taskSettings, A public void start(Model model, ActionListener listener) { listener.onResponse(true); } + + @Override + public void close() throws IOException {} } public static class TestServiceModel extends Model { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index cc84a5c53c81c..f8232b2572b47 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -52,6 +52,7 @@ import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.HttpSettings; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction; @@ -69,9 +70,10 @@ public class InferencePlugin extends Plugin implements ActionPlugin, InferenceSe public static final String NAME = "inference"; public static final String UTILITY_THREAD_POOL_NAME = "inference_utility"; - public static final String HTTP_CLIENT_SENDER_THREAD_POOL_NAME = "inference_http_client_sender"; private final Settings settings; - private final SetOnce httpClientManager = new SetOnce<>(); + private final SetOnce httpRequestSenderFactory = new SetOnce<>(); + // We'll keep a reference to the http manager just in case the inference services don't get closed individually + private final SetOnce httpManager = new SetOnce<>(); public InferencePlugin(Settings settings) { this.settings = settings; @@ -122,8 +124,8 @@ public Collection createComponents( AllocationService allocationService, IndicesService indicesService ) { - httpClientManager.set(HttpClientManager.create(settings, threadPool, clusterService)); - + httpManager.set(HttpClientManager.create(settings, threadPool, clusterService)); + httpRequestSenderFactory.set(new HttpRequestSenderFactory(threadPool, httpManager.get())); ModelRegistry modelRegistry = new ModelRegistry(client); return List.of(modelRegistry); } @@ -165,19 +167,6 @@ public List> getExecutorBuilders(Settings settingsToUse) { TimeValue.timeValueMinutes(10), false, "xpack.inference.utility_thread_pool" - ), - /* - * This executor is specifically for enqueuing requests to be sent. The underlying - * connection pool used by the http client will block if there are no available connections to lease. - * See here for more info: https://hc.apache.org/httpcomponents-client-4.5.x/current/tutorial/html/connmgmt.html - */ - new ScalingExecutorBuilder( - HTTP_CLIENT_SENDER_THREAD_POOL_NAME, - 0, - 1, - TimeValue.timeValueMinutes(10), - false, - "xpack.inference.http_client_sender_thread_pool" ) ); } @@ -209,8 +198,8 @@ public List getInferenceServiceNamedWriteables() { @Override public void close() { - if (httpClientManager.get() != null) { - IOUtils.closeWhileHandlingException(httpClientManager.get()); + if (httpManager.get() != null) { + IOUtils.closeWhileHandlingException(httpManager.get()); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClient.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClient.java index 5622ac51ba187..2de7ae8442ee6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClient.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClient.java @@ -9,13 +9,13 @@ import org.apache.http.HttpResponse; import org.apache.http.client.methods.HttpUriRequest; +import org.apache.http.client.protocol.HttpClientContext; import org.apache.http.concurrent.FutureCallback; import org.apache.http.impl.nio.client.CloseableHttpAsyncClient; import org.apache.http.impl.nio.client.HttpAsyncClientBuilder; import org.apache.http.impl.nio.conn.PoolingNHttpClientConnectionManager; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.common.socket.SocketAccess; @@ -26,7 +26,6 @@ import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.core.Strings.format; -import static org.elasticsearch.xpack.inference.InferencePlugin.HTTP_CLIENT_SENDER_THREAD_POOL_NAME; import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; public class HttpClient implements Closeable { @@ -72,21 +71,11 @@ public void start() { } } - public void send(HttpUriRequest request, ActionListener listener) { + public void send(HttpUriRequest request, HttpClientContext context, ActionListener listener) throws IOException { // The caller must call start() first before attempting to send a request - assert status.get() == Status.STARTED; + assert status.get() == Status.STARTED : "call start() before attempting to send a request"; - threadPool.executor(HTTP_CLIENT_SENDER_THREAD_POOL_NAME).execute(() -> { - try { - doPrivilegedSend(request, listener); - } catch (IOException e) { - listener.onFailure(new ElasticsearchException(format("Failed to send request [%s]", request.getRequestLine()), e)); - } - }); - } - - private void doPrivilegedSend(HttpUriRequest request, ActionListener listener) throws IOException { - SocketAccess.doPrivileged(() -> client.execute(request, new FutureCallback<>() { + SocketAccess.doPrivileged(() -> client.execute(request, context, new FutureCallback<>() { @Override public void completed(HttpResponse response) { respondUsingUtilityThread(response, request, listener); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java index 860caec3e8019..862170a229b41 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java @@ -24,6 +24,8 @@ import java.io.IOException; import java.util.List; +import static org.elasticsearch.core.Strings.format; + public class HttpClientManager implements Closeable { private static final Logger logger = LogManager.getLogger(HttpClientManager.class); /** @@ -31,7 +33,7 @@ public class HttpClientManager implements Closeable { * * https://stackoverflow.com/questions/30989637/how-to-decide-optimal-settings-for-setmaxtotal-and-setdefaultmaxperroute */ - static final Setting MAX_CONNECTIONS = Setting.intSetting( + public static final Setting MAX_CONNECTIONS = Setting.intSetting( "xpack.inference.http.max_connections", // TODO pick a reasonable values here 20, @@ -42,7 +44,7 @@ public class HttpClientManager implements Closeable { ); private static final TimeValue DEFAULT_CONNECTION_EVICTION_THREAD_INTERVAL_TIME = TimeValue.timeValueSeconds(10); - static final Setting CONNECTION_EVICTION_THREAD_INTERVAL_SETTING = Setting.timeSetting( + public static final Setting CONNECTION_EVICTION_THREAD_INTERVAL_SETTING = Setting.timeSetting( "xpack.inference.http.connection_eviction_interval", DEFAULT_CONNECTION_EVICTION_THREAD_INTERVAL_TIME, Setting.Property.NodeScope, @@ -50,7 +52,7 @@ public class HttpClientManager implements Closeable { ); private static final TimeValue DEFAULT_CONNECTION_EVICTION_MAX_IDLE_TIME_SETTING = DEFAULT_CONNECTION_EVICTION_THREAD_INTERVAL_TIME; - static final Setting CONNECTION_EVICTION_MAX_IDLE_TIME_SETTING = Setting.timeSetting( + public static final Setting CONNECTION_EVICTION_MAX_IDLE_TIME_SETTING = Setting.timeSetting( "xpack.inference.http.connection_eviction_max_idle_time", DEFAULT_CONNECTION_EVICTION_MAX_IDLE_TIME_SETTING, Setting.Property.NodeScope, @@ -128,7 +130,7 @@ public HttpClient getHttpClient() { @Override public void close() throws IOException { httpClient.close(); - connectionEvictor.stop(); + connectionEvictor.close(); } private void setMaxConnections(int maxConnections) { @@ -136,21 +138,26 @@ private void setMaxConnections(int maxConnections) { connectionManager.setDefaultMaxPerRoute(maxConnections); } + // This is only used for testing + boolean isEvictionThreadRunning() { + return connectionEvictor.isRunning(); + } + // default for testing void setEvictionInterval(TimeValue evictionInterval) { + logger.debug(() -> format("Eviction thread's interval time updated to [%s]", evictionInterval)); + evictorSettings = new EvictorSettings(evictionInterval, evictorSettings.evictionMaxIdle); - connectionEvictor.stop(); + connectionEvictor.close(); connectionEvictor = createConnectionEvictor(); connectionEvictor.start(); } void setEvictionMaxIdle(TimeValue evictionMaxIdle) { + logger.debug(() -> format("Eviction thread's max idle time updated to [%s]", evictionMaxIdle)); evictorSettings = new EvictorSettings(evictorSettings.evictionInterval, evictionMaxIdle); - - connectionEvictor.stop(); - connectionEvictor = createConnectionEvictor(); - connectionEvictor.start(); + connectionEvictor.setMaxIdleTime(evictionMaxIdle); } private static class EvictorSettings { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/IdleConnectionEvictor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/IdleConnectionEvictor.java index f326661adc6f4..295c9b7b17946 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/IdleConnectionEvictor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/IdleConnectionEvictor.java @@ -14,10 +14,12 @@ import org.elasticsearch.threadpool.Scheduler; import org.elasticsearch.threadpool.ThreadPool; +import java.io.Closeable; import java.util.Objects; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; /** @@ -30,13 +32,13 @@ * * See here for more info. */ -public class IdleConnectionEvictor { +public class IdleConnectionEvictor implements Closeable { private static final Logger logger = LogManager.getLogger(IdleConnectionEvictor.class); private final ThreadPool threadPool; private final NHttpClientConnectionManager connectionManager; private final TimeValue sleepTime; - private final TimeValue maxIdleTime; + private final AtomicReference maxIdleTime = new AtomicReference<>(); private final AtomicReference cancellableTask = new AtomicReference<>(); public IdleConnectionEvictor( @@ -45,10 +47,14 @@ public IdleConnectionEvictor( TimeValue sleepTime, TimeValue maxIdleTime ) { - this.threadPool = threadPool; + this.threadPool = Objects.requireNonNull(threadPool); this.connectionManager = Objects.requireNonNull(connectionManager); - this.sleepTime = sleepTime; - this.maxIdleTime = maxIdleTime; + this.sleepTime = Objects.requireNonNull(sleepTime); + this.maxIdleTime.set(maxIdleTime); + } + + public void setMaxIdleTime(TimeValue maxIdleTime) { + this.maxIdleTime.set(maxIdleTime); } public synchronized void start() { @@ -58,11 +64,13 @@ public synchronized void start() { } private void startInternal() { + logger.debug(() -> format("Idle connection evictor started with wait time: [%s] max idle: [%s]", sleepTime, maxIdleTime)); + Scheduler.Cancellable task = threadPool.scheduleWithFixedDelay(() -> { try { connectionManager.closeExpiredConnections(); - if (maxIdleTime != null) { - connectionManager.closeIdleConnections(maxIdleTime.millis(), TimeUnit.MILLISECONDS); + if (maxIdleTime.get() != null) { + connectionManager.closeIdleConnections(maxIdleTime.get().millis(), TimeUnit.MILLISECONDS); } } catch (Exception e) { logger.warn("HTTP connection eviction failed", e); @@ -72,8 +80,10 @@ private void startInternal() { cancellableTask.set(task); } - public void stop() { + @Override + public void close() { if (cancellableTask.get() != null) { + logger.debug("Idle connection evictor closing"); cancellableTask.get().cancel(); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestExecutorService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestExecutorService.java new file mode 100644 index 0000000000000..86065f35fd882 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestExecutorService.java @@ -0,0 +1,172 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.http.client.methods.HttpRequestBase; +import org.apache.http.client.protocol.HttpClientContext; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.xpack.inference.external.http.HttpClient; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.AbstractExecutorService; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.elasticsearch.core.Strings.format; + +/** + * An {@link java.util.concurrent.ExecutorService} for queuing and executing {@link RequestTask} containing + * {@link org.apache.http.client.methods.HttpUriRequest}. This class is useful because the + * {@link org.apache.http.impl.nio.conn.PoolingNHttpClientConnectionManager} will block when leasing a connection if no + * connections are available. To avoid blocking the inference transport threads, this executor will queue up the + * requests until connections are available. + * + * NOTE: It is the responsibility of the class constructing the + * {@link org.apache.http.client.methods.HttpUriRequest} to set a timeout for how long this executor will wait + * attempting to execute a task (aka waiting for the connection manager to lease a connection). See + * {@link org.apache.http.client.config.RequestConfig.Builder#setConnectionRequestTimeout} for more info. + */ +class HttpRequestExecutorService extends AbstractExecutorService { + private static final Logger logger = LogManager.getLogger(HttpRequestExecutorService.class); + + private final String serviceName; + private final BlockingQueue queue; + private final AtomicBoolean running = new AtomicBoolean(true); + private final CountDownLatch terminationLatch = new CountDownLatch(1); + private final HttpClientContext httpContext; + private final HttpClient httpClient; + + @SuppressForbidden(reason = "properly rethrowing errors, see EsExecutors.rethrowErrors") + HttpRequestExecutorService(String serviceName, HttpClient httpClient) { + this(serviceName, httpClient, null); + } + + @SuppressForbidden(reason = "properly rethrowing errors, see EsExecutors.rethrowErrors") + HttpRequestExecutorService(String serviceName, HttpClient httpClient, @Nullable Integer capacity) { + this.serviceName = Objects.requireNonNull(serviceName); + this.httpClient = Objects.requireNonNull(httpClient); + this.httpContext = HttpClientContext.create(); + + if (capacity == null) { + this.queue = new LinkedBlockingQueue<>(); + } else { + this.queue = new LinkedBlockingQueue<>(capacity); + } + } + + /** + * Begin servicing tasks. + */ + public void start() { + try { + while (running.get()) { + HttpTask task = queue.take(); + if (task.shouldShutdown() || running.get() == false) { + running.set(false); + logger.debug(() -> format("Http executor service [%s] exiting", serviceName)); + } else { + executeTask(task); + } + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } finally { + terminationLatch.countDown(); + } + } + + private void executeTask(HttpTask task) { + try { + task.run(); + } catch (Exception e) { + logger.error(format("Http executor service [%s] failed to execute request [%s]", serviceName, task), e); + } + } + + public int queueSize() { + return queue.size(); + } + + @Override + public void shutdown() { + if (running.compareAndSet(true, false)) { + // if this fails because the queue is full, that's ok, we just want to ensure that queue.take() returns + queue.offer(new ShutdownTask()); + } + } + + @Override + public List shutdownNow() { + shutdown(); + return new ArrayList<>(queue); + } + + @Override + public boolean isShutdown() { + return running.get() == false; + } + + @Override + public boolean isTerminated() { + return terminationLatch.getCount() == 0; + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return terminationLatch.await(timeout, unit); + } + + /** + * Send the request at some point in the future. + * @param request the http request to send + * @param listener an {@link ActionListener} for the response or failure + */ + public void send(HttpRequestBase request, ActionListener listener) { + RequestTask task = new RequestTask(request, httpClient, httpContext, listener); + + if (isShutdown()) { + EsRejectedExecutionException rejected = new EsRejectedExecutionException( + format("Failed to execute task because the http executor service [%s] has shutdown", serviceName), + true + ); + + task.onRejection(rejected); + return; + } + + boolean added = queue.offer(task); + if (added == false) { + EsRejectedExecutionException rejected = new EsRejectedExecutionException( + format("Failed to execute task because the http executor service [%s] queue is full", serviceName), + false + ); + + task.onRejection(rejected); + } + } + + /** + * This method is not supported. Use {@link #send} instead. + * @param runnable the runnable task + */ + @Override + public void execute(Runnable runnable) { + throw new UnsupportedOperationException("use send instead"); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderFactory.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderFactory.java new file mode 100644 index 0000000000000..71ddb9e0849c8 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderFactory.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.http.client.methods.HttpRequestBase; +import org.apache.http.client.methods.HttpUriRequest; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.io.Closeable; +import java.io.IOException; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; + +/** + * A helper class for constructing a {@link HttpRequestSender}. + */ +public class HttpRequestSenderFactory { + private final ThreadPool threadPool; + private final HttpClientManager httpClientManager; + + public HttpRequestSenderFactory(ThreadPool threadPool, HttpClientManager httpClientManager) { + this.threadPool = Objects.requireNonNull(threadPool); + this.httpClientManager = Objects.requireNonNull(httpClientManager); + } + + public HttpRequestSender createSender(String serviceName) { + return new HttpRequestSender(serviceName, threadPool, httpClientManager); + } + + /** + * A class for providing a more friendly interface for sending an {@link HttpUriRequest}. This leverages the queuing logic for sending + * a request. + */ + public static final class HttpRequestSender implements Closeable { + private final ThreadPool threadPool; + private final HttpClientManager manager; + private final HttpRequestExecutorService service; + private final AtomicBoolean started = new AtomicBoolean(false); + + private HttpRequestSender(String serviceName, ThreadPool threadPool, HttpClientManager httpClientManager) { + this.threadPool = Objects.requireNonNull(threadPool); + this.manager = Objects.requireNonNull(httpClientManager); + service = new HttpRequestExecutorService(serviceName, manager.getHttpClient()); + } + + /** + * Start various internal services. This is required before sending requests. + */ + public void start() { + if (started.compareAndSet(false, true)) { + manager.start(); + threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(service::start); + } + } + + @Override + public void close() throws IOException { + manager.close(); + service.shutdown(); + } + + public void send(HttpRequestBase request, ActionListener listener) { + assert started.get() : "call start() before sending a request"; + service.send(request, listener); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpTask.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpTask.java new file mode 100644 index 0000000000000..6881d75524bda --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpTask.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.common.util.concurrent.AbstractRunnable; + +abstract class HttpTask extends AbstractRunnable { + public boolean shouldShutdown() { + return false; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java new file mode 100644 index 0000000000000..2a1fe653cb8cd --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.http.client.methods.HttpUriRequest; +import org.apache.http.client.protocol.HttpClientContext; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.xpack.inference.external.http.HttpClient; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.core.Strings.format; + +class RequestTask extends HttpTask { + private static final Logger logger = LogManager.getLogger(RequestTask.class); + + private final HttpUriRequest request; + private final ActionListener listener; + private final HttpClient httpClient; + private final HttpClientContext context; + + RequestTask(HttpUriRequest request, HttpClient httpClient, HttpClientContext context, ActionListener listener) { + this.request = Objects.requireNonNull(request); + this.httpClient = Objects.requireNonNull(httpClient); + this.listener = Objects.requireNonNull(listener); + this.context = Objects.requireNonNull(context); + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + + @Override + protected void doRun() throws Exception { + try { + httpClient.send(request, context, listener); + } catch (IOException e) { + logger.error(format("Failed to send request [%s] via the http client", request.getRequestLine()), e); + listener.onFailure(new ElasticsearchException(format("Failed to send request [%s]", request.getRequestLine()), e)); + } + } + + @Override + public String toString() { + return request.getRequestLine().toString(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ShutdownTask.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ShutdownTask.java new file mode 100644 index 0000000000000..9ec2edf514e80 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ShutdownTask.java @@ -0,0 +1,21 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +class ShutdownTask extends HttpTask { + @Override + public boolean shouldShutdown() { + return true; + } + + @Override + public void onFailure(Exception e) {} + + @Override + protected void doRun() throws Exception {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java index 907c76e02f53c..48b6952bcc8af 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java @@ -24,6 +24,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; +import java.io.IOException; import java.util.List; import java.util.Map; import java.util.Set; @@ -192,4 +193,7 @@ private static ElserMlNodeTaskSettings taskSettingsFromMap(TaskType taskType, Ma public String name() { return NAME; } + + @Override + public void close() throws IOException {} } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/HttpClientManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/HttpClientManagerTests.java index a9bdee95de5fc..d04a0c185d2a2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/HttpClientManagerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/HttpClientManagerTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.external.http; import org.apache.http.HttpHeaders; +import org.apache.http.client.protocol.HttpClientContext; import org.apache.http.impl.nio.conn.PoolingNHttpClientConnectionManager; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.service.ClusterService; @@ -23,7 +24,6 @@ import org.junit.Before; import java.nio.charset.StandardCharsets; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -34,11 +34,8 @@ import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -74,7 +71,7 @@ public void testSend_MockServerReceivesRequest() throws Exception { httpClient.start(); PlainActionFuture listener = new PlainActionFuture<>(); - httpClient.send(httpPost, listener); + httpClient.send(httpPost, HttpClientContext.create(), listener); var result = listener.actionGet(TIMEOUT); @@ -96,7 +93,7 @@ public void testStartsANewEvictor_WithNewEvictionInterval() { verify(threadPool).scheduleWithFixedDelay(any(Runnable.class), eq(evictionInterval), any()); } - public void testStartsANewEvictor_WithNewEvictionMaxIdle() throws InterruptedException { + public void test_DoesNotStartANewEvictor_WithNewEvictionMaxIdle() throws InterruptedException { var mockConnectionManager = mock(PoolingNHttpClientConnectionManager.class); Settings settings = Settings.builder() @@ -104,25 +101,17 @@ public void testStartsANewEvictor_WithNewEvictionMaxIdle() throws InterruptedExc .build(); var manager = new HttpClientManager(settings, mockConnectionManager, threadPool, mockClusterService(settings)); - CountDownLatch runLatch = new CountDownLatch(1); - doAnswer(invocation -> { - manager.close(); - runLatch.countDown(); - return Void.TYPE; - }).when(mockConnectionManager).closeIdleConnections(anyLong(), any()); - var evictionMaxIdle = TimeValue.timeValueSeconds(1); manager.setEvictionMaxIdle(evictionMaxIdle); - runLatch.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS); - verify(mockConnectionManager, times(1)).closeIdleConnections(eq(evictionMaxIdle.millis()), eq(TimeUnit.MILLISECONDS)); + assertFalse(manager.isEvictionThreadRunning()); } - private static ClusterService mockClusterServiceEmpty() { + public static ClusterService mockClusterServiceEmpty() { return mockClusterService(Settings.EMPTY); } - private static ClusterService mockClusterService(Settings settings) { + public static ClusterService mockClusterService(Settings settings) { var clusterService = mock(ClusterService.class); var registeredSettings = Stream.concat(HttpClientManager.getSettings().stream(), HttpSettings.getSettings().stream()) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/HttpClientTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/HttpClientTests.java index b0b0a34aabf97..0fee565716304 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/HttpClientTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/HttpClientTests.java @@ -10,6 +10,8 @@ import org.apache.http.HttpHeaders; import org.apache.http.HttpResponse; import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpUriRequest; +import org.apache.http.client.protocol.HttpClientContext; import org.apache.http.client.utils.URIBuilder; import org.apache.http.concurrent.FutureCallback; import org.apache.http.entity.ByteArrayEntity; @@ -45,7 +47,6 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.core.Strings.format; -import static org.elasticsearch.xpack.inference.InferencePlugin.HTTP_CLIENT_SENDER_THREAD_POOL_NAME; import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -87,7 +88,7 @@ public void testSend_MockServerReceivesRequest() throws Exception { httpClient.start(); PlainActionFuture listener = new PlainActionFuture<>(); - httpClient.send(httpPost, listener); + httpClient.send(httpPost, HttpClientContext.create(), listener); var result = listener.actionGet(TIMEOUT); @@ -100,15 +101,27 @@ public void testSend_MockServerReceivesRequest() throws Exception { } } + public void testSend_ThrowsErrorIfCalledBeforeStart() throws Exception { + try (var httpClient = HttpClient.create(emptyHttpSettings(), threadPool, createConnectionManager())) { + PlainActionFuture listener = new PlainActionFuture<>(); + var thrownException = expectThrows( + AssertionError.class, + () -> httpClient.send(mock(HttpUriRequest.class), HttpClientContext.create(), listener) + ); + + assertThat(thrownException.getMessage(), is("call start() before attempting to send a request")); + } + } + public void testSend_FailedCallsOnFailure() throws Exception { var asyncClient = mock(CloseableHttpAsyncClient.class); doAnswer(invocation -> { @SuppressWarnings("unchecked") - FutureCallback listener = (FutureCallback) invocation.getArguments()[1]; + FutureCallback listener = (FutureCallback) invocation.getArguments()[2]; listener.failed(new ElasticsearchException("failure")); return mock(Future.class); - }).when(asyncClient).execute(any(), any()); + }).when(asyncClient).execute(any(HttpUriRequest.class), any(), any()); var httpPost = createHttpPost(webServer.getPort(), "a", "b"); @@ -116,7 +129,7 @@ public void testSend_FailedCallsOnFailure() throws Exception { client.start(); PlainActionFuture listener = new PlainActionFuture<>(); - client.send(httpPost, listener); + client.send(httpPost, HttpClientContext.create(), listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat(thrownException.getMessage(), is("failure")); @@ -128,10 +141,10 @@ public void testSend_CancelledCallsOnFailure() throws Exception { doAnswer(invocation -> { @SuppressWarnings("unchecked") - FutureCallback listener = (FutureCallback) invocation.getArguments()[1]; + FutureCallback listener = (FutureCallback) invocation.getArguments()[2]; listener.cancelled(); return mock(Future.class); - }).when(asyncClient).execute(any(), any()); + }).when(asyncClient).execute(any(HttpUriRequest.class), any(), any()); var httpPost = createHttpPost(webServer.getPort(), "a", "b"); @@ -139,7 +152,7 @@ public void testSend_CancelledCallsOnFailure() throws Exception { client.start(); PlainActionFuture listener = new PlainActionFuture<>(); - client.send(httpPost, listener); + client.send(httpPost, HttpClientContext.create(), listener); var thrownException = expectThrows(CancellationException.class, () -> listener.actionGet(TIMEOUT)); assertThat(thrownException.getMessage(), is(format("Request [%s] was cancelled", httpPost.getRequestLine()))); @@ -149,7 +162,7 @@ public void testSend_CancelledCallsOnFailure() throws Exception { @SuppressWarnings("unchecked") public void testStart_MultipleCallsOnlyStartTheClientOnce() throws Exception { var asyncClient = mock(CloseableHttpAsyncClient.class); - when(asyncClient.execute(any(), any())).thenReturn(mock(Future.class)); + when(asyncClient.execute(any(HttpUriRequest.class), any(), any())).thenReturn(mock(Future.class)); var httpPost = createHttpPost(webServer.getPort(), "a", "b"); @@ -157,8 +170,8 @@ public void testStart_MultipleCallsOnlyStartTheClientOnce() throws Exception { client.start(); PlainActionFuture listener = new PlainActionFuture<>(); - client.send(httpPost, listener); - client.send(httpPost, listener); + client.send(httpPost, HttpClientContext.create(), listener); + client.send(httpPost, HttpClientContext.create(), listener); verify(asyncClient, times(1)).start(); } @@ -180,7 +193,7 @@ public void testSend_FailsWhenMaxBytesReadIsExceeded() throws Exception { httpClient.start(); PlainActionFuture listener = new PlainActionFuture<>(); - httpClient.send(httpPost, listener); + httpClient.send(httpPost, HttpClientContext.create(), listener); var throwException = expectThrows(UncategorizedExecutionException.class, () -> listener.actionGet(TIMEOUT)); assertThat(throwException.getCause().getCause().getMessage(), is("Maximum limit of [1] bytes reached")); @@ -217,19 +230,11 @@ public static ThreadPool createThreadPool(String name) { TimeValue.timeValueMinutes(10), false, "xpack.inference.utility_thread_pool" - ), - new ScalingExecutorBuilder( - HTTP_CLIENT_SENDER_THREAD_POOL_NAME, - 1, - 4, - TimeValue.timeValueMinutes(10), - false, - "xpack.inference.utility_thread_pool" ) ); } - private static PoolingNHttpClientConnectionManager createConnectionManager() throws IOReactorException { + public static PoolingNHttpClientConnectionManager createConnectionManager() throws IOReactorException { return new PoolingNHttpClientConnectionManager(new DefaultConnectingIOReactor()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/IdleConnectionEvictorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/IdleConnectionEvictorTests.java index 2cc00f35f9af6..0f03003589073 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/IdleConnectionEvictorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/IdleConnectionEvictorTests.java @@ -12,9 +12,7 @@ import org.apache.http.nio.reactor.IOReactorException; import org.elasticsearch.core.TimeValue; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.ScalingExecutorBuilder; import org.elasticsearch.threadpool.Scheduler; -import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.junit.After; import org.junit.Before; @@ -22,7 +20,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; +import static org.elasticsearch.xpack.inference.external.http.HttpClientTests.createThreadPool; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.doAnswer; @@ -38,17 +36,7 @@ public class IdleConnectionEvictorTests extends ESTestCase { @Before public void init() { - threadPool = new TestThreadPool( - getTestName(), - new ScalingExecutorBuilder( - UTILITY_THREAD_POOL_NAME, - 1, - 4, - TimeValue.timeValueMinutes(10), - false, - "xpack.inference.utility_thread_pool" - ) - ); + threadPool = createThreadPool(getTestName()); } @After @@ -103,7 +91,7 @@ public void testCloseExpiredConnections_IsCalled() throws InterruptedException { CountDownLatch runLatch = new CountDownLatch(1); doAnswer(invocation -> { - evictor.stop(); + evictor.close(); runLatch.countDown(); return Void.TYPE; }).when(manager).closeExpiredConnections(); @@ -126,7 +114,7 @@ public void testCloseIdleConnections_IsCalled() throws InterruptedException { CountDownLatch runLatch = new CountDownLatch(1); doAnswer(invocation -> { - evictor.stop(); + evictor.close(); runLatch.countDown(); return Void.TYPE; }).when(manager).closeIdleConnections(anyLong(), any()); @@ -147,7 +135,7 @@ public void testIsRunning_ReturnsTrue() throws IOReactorException { evictor.start(); assertTrue(evictor.isRunning()); - evictor.stop(); + evictor.close(); } public void testIsRunning_ReturnsFalse() throws IOReactorException { @@ -161,7 +149,7 @@ public void testIsRunning_ReturnsFalse() throws IOReactorException { evictor.start(); assertTrue(evictor.isRunning()); - evictor.stop(); + evictor.close(); assertFalse(evictor.isRunning()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestExecutorServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestExecutorServiceTests.java new file mode 100644 index 0000000000000..4a658c4aa00ef --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestExecutorServiceTests.java @@ -0,0 +1,182 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.http.client.methods.HttpRequestBase; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.HttpClient; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.junit.After; +import org.junit.Before; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.external.http.HttpClientTests.createThreadPool; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; + +public class HttpRequestExecutorServiceTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private ThreadPool threadPool; + + @Before + public void init() { + threadPool = createThreadPool(getTestName()); + } + + @After + public void shutdown() { + terminate(threadPool); + } + + public void testQueueSize_IsEmpty() { + var service = new HttpRequestExecutorService(getTestName(), mock(HttpClient.class)); + + assertThat(service.queueSize(), is(0)); + } + + public void testQueueSize_IsOne() { + var service = new HttpRequestExecutorService(getTestName(), mock(HttpClient.class)); + service.send(mock(HttpRequestBase.class), new PlainActionFuture<>()); + + assertThat(service.queueSize(), is(1)); + } + + public void testExecute_ThrowsUnsupported() { + var service = new HttpRequestExecutorService(getTestName(), mock(HttpClient.class)); + var noopTask = mock(RequestTask.class); + + var thrownException = expectThrows(UnsupportedOperationException.class, () -> service.execute(noopTask)); + assertThat(thrownException.getMessage(), is("use send instead")); + } + + public void testIsTerminated_IsFalse() { + var service = new HttpRequestExecutorService(getTestName(), mock(HttpClient.class)); + + assertFalse(service.isTerminated()); + } + + public void testIsTerminated_IsTrue() { + var service = new HttpRequestExecutorService(getTestName(), mock(HttpClient.class)); + + service.shutdown(); + service.start(); + + assertTrue(service.isTerminated()); + } + + public void testIsTerminated_AfterStopFromSeparateThread() throws Exception { + var waitToShutdown = new CountDownLatch(1); + + var mockHttpClient = mock(HttpClient.class); + doAnswer(invocation -> { + waitToShutdown.countDown(); + return Void.TYPE; + }).when(mockHttpClient).send(any(), any(), any()); + + var service = new HttpRequestExecutorService(getTestName(), mockHttpClient); + + Future executorTermination = threadPool.generic().submit(() -> { + try { + // wait for a task to be added to be executed before beginning shutdown + waitToShutdown.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS); + service.shutdown(); + service.awaitTermination(TIMEOUT.getSeconds(), TimeUnit.SECONDS); + } catch (Exception e) { + fail(Strings.format("Failed to shutdown executor: %s", e)); + } + }); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.send(mock(HttpRequestBase.class), listener); + + service.start(); + + try { + executorTermination.get(1, TimeUnit.SECONDS); + } catch (Exception e) { + fail(Strings.format("Executor finished before it was signaled to shutdown: %s", e)); + } + + assertTrue(service.isShutdown()); + assertTrue(service.isTerminated()); + } + + public void testExecute_AfterShutdown_Throws() { + var service = new HttpRequestExecutorService("test_service", mock(HttpClient.class)); + + service.shutdown(); + + var listener = new PlainActionFuture(); + service.send(mock(HttpRequestBase.class), listener); + + var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat( + thrownException.getMessage(), + is("Failed to execute task because the http executor service [test_service] has shutdown") + ); + } + + public void testExecute_Throws_WhenQueueIsFull() { + var service = new HttpRequestExecutorService("test_service", mock(HttpClient.class), 1); + + service.send(mock(HttpRequestBase.class), new PlainActionFuture<>()); + var listener = new PlainActionFuture(); + service.send(mock(HttpRequestBase.class), listener); + + var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat( + thrownException.getMessage(), + is("Failed to execute task because the http executor service [test_service] queue is full") + ); + } + + public void testTaskThrowsError_CallsOnFailure() throws Exception { + var httpClient = mock(HttpClient.class); + + var service = new HttpRequestExecutorService(getTestName(), httpClient); + + doAnswer(invocation -> { + service.shutdown(); + throw new ElasticsearchException("failed"); + }).when(httpClient).send(any(), any(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + service.send(mock(HttpRequestBase.class), listener); + service.start(); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(thrownException.getMessage(), is("failed")); + assertTrue(service.isTerminated()); + } + + public void testShutdown_AllowsMultipleCalls() { + var service = new HttpRequestExecutorService(getTestName(), mock(HttpClient.class)); + + service.shutdown(); + service.shutdown(); + service.shutdownNow(); + service.start(); + + assertTrue(service.isTerminated()); + assertTrue(service.isShutdown()); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderFactoryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderFactoryTests.java new file mode 100644 index 0000000000000..47ecde6ae535c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderFactoryTests.java @@ -0,0 +1,119 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpRequestBase; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.external.http.HttpClientManagerTests.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.HttpClientTests.createHttpPost; +import static org.elasticsearch.xpack.inference.external.http.HttpClientTests.createThreadPool; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class HttpRequestSenderFactoryTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + private Thread thread; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(getTestName()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty()); + thread = null; + } + + @After + public void shutdown() throws IOException, InterruptedException { + if (thread != null) { + thread.join(TIMEOUT.millis()); + } + + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception { + var mockExecutorService = mock(ExecutorService.class); + doAnswer(invocation -> { + Runnable runnable = (Runnable) invocation.getArguments()[0]; + thread = new Thread(runnable); + thread.start(); + + return Void.TYPE; + }).when(mockExecutorService).execute(any(Runnable.class)); + + var mockThreadPool = mock(ThreadPool.class); + when(mockThreadPool.executor(anyString())).thenReturn(mockExecutorService); + when(mockThreadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); + + var senderFactory = new HttpRequestSenderFactory(mockThreadPool, clientManager); + + try (var sender = senderFactory.createSender("test_service")) { + sender.start(); + + int responseCode = randomIntBetween(200, 203); + String body = randomAlphaOfLengthBetween(2, 8096); + webServer.enqueue(new MockResponse().setResponseCode(responseCode).setBody(body)); + + String paramKey = randomAlphaOfLength(3); + String paramValue = randomAlphaOfLength(3); + var httpPost = createHttpPost(webServer.getPort(), paramKey, paramValue); + + PlainActionFuture listener = new PlainActionFuture<>(); + sender.send(httpPost, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.response().getStatusLine().getStatusCode(), equalTo(responseCode)); + assertThat(new String(result.body(), StandardCharsets.UTF_8), is(body)); + assertThat(webServer.requests(), hasSize(1)); + assertThat(webServer.requests().get(0).getUri().getPath(), equalTo(httpPost.getURI().getPath())); + assertThat(webServer.requests().get(0).getUri().getQuery(), equalTo(paramKey + "=" + paramValue)); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + } + } + + public void testHttpRequestSender_Throws_WhenCallingSendBeforeStart() throws Exception { + var senderFactory = new HttpRequestSenderFactory(threadPool, clientManager); + + try (var sender = senderFactory.createSender("test_service")) { + PlainActionFuture listener = new PlainActionFuture<>(); + var thrownException = expectThrows(AssertionError.class, () -> sender.send(mock(HttpRequestBase.class), listener)); + assertThat(thrownException.getMessage(), is("call start() before sending a request")); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java new file mode 100644 index 0000000000000..37f4fb8cce4cb --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java @@ -0,0 +1,99 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.protocol.HttpClientContext; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.HttpClient; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.external.http.HttpClientTests.createConnectionManager; +import static org.elasticsearch.xpack.inference.external.http.HttpClientTests.createHttpPost; +import static org.elasticsearch.xpack.inference.external.http.HttpClientTests.createThreadPool; +import static org.elasticsearch.xpack.inference.external.http.HttpClientTests.emptyHttpSettings; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + +public class RequestTaskTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(getTestName()); + } + + @After + public void shutdown() { + terminate(threadPool); + webServer.close(); + } + + public void testDoRun_SendsRequestAndReceivesResponse() throws Exception { + int responseCode = randomIntBetween(200, 203); + String body = randomAlphaOfLengthBetween(2, 8096); + webServer.enqueue(new MockResponse().setResponseCode(responseCode).setBody(body)); + + String paramKey = randomAlphaOfLength(3); + String paramValue = randomAlphaOfLength(3); + var httpPost = createHttpPost(webServer.getPort(), paramKey, paramValue); + + try (var httpClient = HttpClient.create(emptyHttpSettings(), threadPool, createConnectionManager())) { + httpClient.start(); + + PlainActionFuture listener = new PlainActionFuture<>(); + var requestTask = new RequestTask(httpPost, httpClient, HttpClientContext.create(), listener); + requestTask.doRun(); + var result = listener.actionGet(TIMEOUT); + + assertThat(result.response().getStatusLine().getStatusCode(), equalTo(responseCode)); + assertThat(new String(result.body(), StandardCharsets.UTF_8), is(body)); + assertThat(webServer.requests(), hasSize(1)); + assertThat(webServer.requests().get(0).getUri().getPath(), equalTo(httpPost.getURI().getPath())); + assertThat(webServer.requests().get(0).getUri().getQuery(), equalTo(paramKey + "=" + paramValue)); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + } + } + + public void testDoRun_SendThrowsIOException() throws Exception { + var httpClient = mock(HttpClient.class); + doThrow(new IOException("exception")).when(httpClient).send(any(), any(), any()); + + String paramKey = randomAlphaOfLength(3); + String paramValue = randomAlphaOfLength(3); + var httpPost = createHttpPost(webServer.getPort(), paramKey, paramValue); + + PlainActionFuture listener = new PlainActionFuture<>(); + var requestTask = new RequestTask(httpPost, httpClient, HttpClientContext.create(), listener); + requestTask.doRun(); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(thrownException.getMessage(), is(format("Failed to send request [%s]", httpPost.getRequestLine()))); + } +}