Skip to content

Commit

Permalink
Using a custom method for sending requests in the queue
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-buttner committed Oct 12, 2023
1 parent b70f418 commit 0cb27e0
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,18 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.xpack.inference.external.http.HttpClient;
import org.elasticsearch.xpack.inference.external.http.HttpResult;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -40,7 +44,7 @@
* attempting to execute a task (aka waiting for the connection manager to lease a connection). See
* {@link org.apache.http.client.config.RequestConfig.Builder#setConnectionRequestTimeout} for more info.
*/
public class HttpRequestExecutorService extends AbstractExecutorService {
class HttpRequestExecutorService extends AbstractExecutorService {
private static final Logger logger = LogManager.getLogger(HttpRequestExecutorService.class);

private final ThreadContext contextHolder;
Expand All @@ -49,16 +53,18 @@ public class HttpRequestExecutorService extends AbstractExecutorService {
private final AtomicBoolean running = new AtomicBoolean(true);
private final CountDownLatch terminationLatch = new CountDownLatch(1);
private final HttpClientContext httpContext;
private final HttpClient httpClient;

@SuppressForbidden(reason = "properly rethrowing errors, see EsExecutors.rethrowErrors")
public HttpRequestExecutorService(ThreadContext contextHolder, String serviceName) {
this(contextHolder, serviceName, null);
HttpRequestExecutorService(ThreadContext contextHolder, String serviceName, HttpClient httpClient) {
this(contextHolder, serviceName, httpClient, null);
}

@SuppressForbidden(reason = "properly rethrowing errors, see EsExecutors.rethrowErrors")
public HttpRequestExecutorService(ThreadContext contextHolder, String serviceName, @Nullable Integer capacity) {
HttpRequestExecutorService(ThreadContext contextHolder, String serviceName, HttpClient httpClient, @Nullable Integer capacity) {
this.contextHolder = Objects.requireNonNull(contextHolder);
this.serviceName = Objects.requireNonNull(serviceName);
this.httpClient = Objects.requireNonNull(httpClient);
this.httpContext = HttpClientContext.create();

if (capacity == null) {
Expand Down Expand Up @@ -132,18 +138,12 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE
}

/**
* Execute the task at some point in the future.
* @param command the runnable task, must be a class that extends {@link HttpTask}
* Send the request at some point in the future.
* @param request the http request to send
* @param listener an {@link ActionListener<HttpResult>} for the response or failure
*/
@Override
public void execute(Runnable command) {
if (command == null) {
return;
}

assert command instanceof HttpTask;
HttpTask task = (HttpTask) command;
task.setContext(httpContext);
public void send(HttpRequestBase request, ActionListener<HttpResult> listener) {
RequestTask task = new RequestTask(request, httpClient, httpContext, listener);

if (isShutdown()) {
EsRejectedExecutionException rejected = new EsRejectedExecutionException(
Expand All @@ -165,4 +165,13 @@ public void execute(Runnable command) {
task.onRejection(rejected);
}
}

/**
* This method is not supported. Use {@link #send} instead.
* @param runnable the runnable task
*/
@Override
public void execute(Runnable runnable) {
throw new UnsupportedOperationException("use send instead");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.client.methods.HttpUriRequest;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -49,7 +50,7 @@ public static final class HttpRequestSender implements Closeable {
private HttpRequestSender(String serviceName, ThreadPool threadPool, HttpClientManager httpClientManager) {
this.threadPool = Objects.requireNonNull(threadPool);
this.manager = Objects.requireNonNull(httpClientManager);
service = new HttpRequestExecutorService(threadPool.getThreadContext(), serviceName);
service = new HttpRequestExecutorService(threadPool.getThreadContext(), serviceName, manager.getHttpClient());
}

/**
Expand All @@ -68,9 +69,9 @@ public void close() throws IOException {
service.shutdown();
}

public void send(HttpUriRequest request, ActionListener<HttpResult> listener) {
public void send(HttpRequestBase request, ActionListener<HttpResult> listener) {
assert started.get() : "call start() before sending a request";
service.execute(new RequestTask(request, manager.getHttpClient(), listener));
service.send(request, listener);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,10 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;

public abstract class HttpTask extends AbstractRunnable {
protected HttpClientContext context;

abstract class HttpTask extends AbstractRunnable {
public boolean shouldShutdown() {
return false;
}

public void setContext(HttpClientContext context) {
this.context = context;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
Expand All @@ -20,17 +21,19 @@

import static org.elasticsearch.core.Strings.format;

public class RequestTask extends HttpTask {
class RequestTask extends HttpTask {
private static final Logger logger = LogManager.getLogger(RequestTask.class);

private final HttpUriRequest request;
private final ActionListener<HttpResult> listener;
private final HttpClient httpClient;
private final HttpClientContext context;

public RequestTask(HttpUriRequest request, HttpClient httpClient, ActionListener<HttpResult> listener) {
RequestTask(HttpUriRequest request, HttpClient httpClient, HttpClientContext context, ActionListener<HttpResult> listener) {
this.request = Objects.requireNonNull(request);
this.httpClient = Objects.requireNonNull(httpClient);
this.listener = Objects.requireNonNull(listener);
this.context = Objects.requireNonNull(context);
}

@Override
Expand All @@ -40,8 +43,6 @@ public void onFailure(Exception e) {

@Override
protected void doRun() throws Exception {
assert context != null : "the http context must be set before calling doRun";

try {
httpClient.send(request, context, listener);
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

package org.elasticsearch.xpack.inference.external.http.sender;

public class ShutdownTask extends HttpTask {
class ShutdownTask extends HttpTask {
@Override
public boolean shouldShutdown() {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.client.methods.HttpRequestBase;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.HttpClient;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Future;
Expand All @@ -29,8 +29,6 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

public class HttpRequestExecutorServiceTests extends ESTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
Expand All @@ -47,28 +45,34 @@ public void shutdown() {
}

public void testQueueSize_IsEmpty() {
var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName());
var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName(), mock(HttpClient.class));

assertThat(service.queueSize(), is(0));
}

public void testQueueSize_IsOne() {
var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName());
var noopTask = mock(RequestTask.class);

service.execute(noopTask);
var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName(), mock(HttpClient.class));
service.send(mock(HttpRequestBase.class), new PlainActionFuture<>());

assertThat(service.queueSize(), is(1));
}

public void testExecute_ThrowsUnsupported() {
var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName(), mock(HttpClient.class));
var noopTask = mock(RequestTask.class);

var thrownException = expectThrows(UnsupportedOperationException.class, () -> service.execute(noopTask));
assertThat(thrownException.getMessage(), is("use send instead"));
}

public void testIsTerminated_IsFalse() {
var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName());
var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName(), mock(HttpClient.class));

assertFalse(service.isTerminated());
}

public void testIsTerminated_IsTrue() {
var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName());
var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName(), mock(HttpClient.class));

service.shutdown();
service.start();
Expand All @@ -77,9 +81,16 @@ public void testIsTerminated_IsTrue() {
}

public void testIsTerminated_AfterStopFromSeparateThread() throws Exception {
var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName());

var waitToShutdown = new CountDownLatch(1);

var mockHttpClient = mock(HttpClient.class);
doAnswer(invocation -> {
waitToShutdown.countDown();
return Void.TYPE;
}).when(mockHttpClient).send(any(), any(), any());

var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName(), mockHttpClient);

Future<?> executorTermination = threadPool.generic().submit(() -> {
try {
// wait for a task to be added to be executed before beginning shutdown
Expand All @@ -91,13 +102,8 @@ public void testIsTerminated_AfterStopFromSeparateThread() throws Exception {
}
});

var blockingTask = mock(RequestTask.class);
doAnswer(invocation -> {
waitToShutdown.countDown();
return Void.TYPE;
}).when(blockingTask).doRun();

service.execute(blockingTask);
PlainActionFuture<HttpResult> listener = new PlainActionFuture<>();
service.send(mock(HttpRequestBase.class), listener);

service.start();

Expand All @@ -112,50 +118,49 @@ public void testIsTerminated_AfterStopFromSeparateThread() throws Exception {
}

public void testExecute_AfterShutdown_Throws() {
var service = new HttpRequestExecutorService(threadPool.getThreadContext(), "test_service");
var service = new HttpRequestExecutorService(threadPool.getThreadContext(), "test_service", mock(HttpClient.class));

service.shutdown();

var task = mock(RequestTask.class);
service.execute(task);
var listener = new PlainActionFuture<HttpResult>();
service.send(mock(HttpRequestBase.class), listener);

ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
verify(task, times(1)).onRejection(exceptionCaptor.capture());
var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT));

assertThat(
exceptionCaptor.getValue().getMessage(),
thrownException.getMessage(),
is("Failed to execute task because the http executor service [test_service] has shutdown")
);
}

public void testExecute_Throws_WhenQueueIsFull() {
var service = new HttpRequestExecutorService(threadPool.getThreadContext(), "test_service", 1);
var service = new HttpRequestExecutorService(threadPool.getThreadContext(), "test_service", mock(HttpClient.class), 1);

var task = mock(RequestTask.class);
var rejectedTask = mock(RequestTask.class);
service.execute(task);
service.execute(rejectedTask);
service.send(mock(HttpRequestBase.class), new PlainActionFuture<>());
var listener = new PlainActionFuture<HttpResult>();
service.send(mock(HttpRequestBase.class), listener);

ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
verify(rejectedTask, times(1)).onRejection(exceptionCaptor.capture());
var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT));

assertThat(
exceptionCaptor.getValue().getMessage(),
thrownException.getMessage(),
is("Failed to execute task because the http executor service [test_service] queue is full")
);
}

public void testTaskThrowsError_CallsOnFailure() throws Exception {
var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName());

var httpClient = mock(HttpClient.class);
doAnswer(invocation -> { throw new ElasticsearchException("failed"); }).when(httpClient).send(any(), any(), any());

var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName(), httpClient);

doAnswer(invocation -> {
service.shutdown();
throw new ElasticsearchException("failed");
}).when(httpClient).send(any(), any(), any());

PlainActionFuture<HttpResult> listener = new PlainActionFuture<>();
var failingTask = new RequestTask(mock(HttpUriRequest.class), httpClient, listener);

service.execute(failingTask);
service.execute(new ShutdownTask());
service.send(mock(HttpRequestBase.class), listener);
service.start();

var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
Expand All @@ -164,7 +169,7 @@ public void testTaskThrowsError_CallsOnFailure() throws Exception {
}

public void testShutdown_AllowsMultipleCalls() {
var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName());
var service = new HttpRequestExecutorService(threadPool.getThreadContext(), getTestName(), mock(HttpClient.class));

service.shutdown();
service.shutdown();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.client.methods.HttpRequestBase;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
Expand Down Expand Up @@ -112,7 +112,7 @@ public void testHttpRequestSender_Throws_WhenCallingSendBeforeStart() throws Exc

try (var sender = senderFactory.createSender("test_service")) {
PlainActionFuture<HttpResult> listener = new PlainActionFuture<>();
var thrownException = expectThrows(AssertionError.class, () -> sender.send(mock(HttpUriRequest.class), listener));
var thrownException = expectThrows(AssertionError.class, () -> sender.send(mock(HttpRequestBase.class), listener));
assertThat(thrownException.getMessage(), is("call start() before sending a request"));
}
}
Expand Down
Loading

0 comments on commit 0cb27e0

Please sign in to comment.