Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Retain ref to requests when running ActionFilterChain #104000

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskManager;

Expand Down Expand Up @@ -58,8 +60,13 @@ public final void execute(Task task, Request request, ActionListener<Response> l
listener = new TaskResultStoringActionListener<>(taskManager, task, listener);
}

RequestFilterChain<Request, Response> requestFilterChain = new RequestFilterChain<>(this, logger);
requestFilterChain.proceed(task, actionName, request, listener);
// Note on request refcounting: we can be sure that either we get to the end of the chain (and execute the actual action) or
// we complete the response listener and short-circuit the outer chain, so we release our request ref on both paths, using
// Releasables#releaseOnce to avoid a double-release.
request.mustIncRef();
final var releaseRef = Releasables.releaseOnce(request::decRef);
RequestFilterChain<Request, Response> requestFilterChain = new RequestFilterChain<>(this, logger, releaseRef);
requestFilterChain.proceed(task, actionName, request, ActionListener.runBefore(listener, releaseRef::close));
}

protected abstract void doExecute(Task task, Request request, ActionListener<Response> listener);
Expand All @@ -71,10 +78,12 @@ private static class RequestFilterChain<Request extends ActionRequest, Response
private final TransportAction<Request, Response> action;
private final AtomicInteger index = new AtomicInteger();
private final Logger logger;
private final Releasable releaseRef;

private RequestFilterChain(TransportAction<Request, Response> action, Logger logger) {
private RequestFilterChain(TransportAction<Request, Response> action, Logger logger, Releasable releaseRef) {
this.action = action;
this.logger = logger;
this.releaseRef = releaseRef;
}

@Override
Expand All @@ -84,7 +93,9 @@ public void proceed(Task task, String actionName, Request request, ActionListene
if (i < this.action.filters.length) {
this.action.filters[i].apply(task, actionName, request, listener, this);
} else if (i == this.action.filters.length) {
this.action.doExecute(task, request, listener);
try (releaseRef) {
this.action.doExecute(task, request, listener);
}
} else {
listener.onFailure(new IllegalStateException("proceed was called too many times"));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.action.support;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.AbstractRefCounted;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ESSingleNodeTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.LeakTracker;
import org.elasticsearch.transport.TransportService;

import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;

public class TransportActionFilterChainRefCountingTests extends ESSingleNodeTestCase {

@Override
protected Collection<Class<? extends Plugin>> getPlugins() {
return List.of(TestPlugin.class);
}

static final ActionType<Response> TYPE = ActionType.localOnly("test:action");

public void testAsyncActionFilterRefCounting() {
final var countDownLatch = new CountDownLatch(2);
final var request = new Request();
try {
client().execute(TYPE, request, ActionListener.<Response>running(countDownLatch::countDown).delegateResponse((delegate, e) -> {
assertEquals("short-circuit failure", asInstanceOf(ElasticsearchException.class, e).getMessage());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this guaranteed to always be true ? (due to the random boolean below)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's ok if everything succeeds too, we're just checking that nothing else went wrong. I added a code comment in 84a194c.

delegate.onResponse(null);
}));
} finally {
request.decRef();
}
request.addCloseListener(ActionListener.running(countDownLatch::countDown));
safeAwait(countDownLatch);
}

public static class TestPlugin extends Plugin implements ActionPlugin {

private ThreadPool threadPool;

@Override
public Collection<?> createComponents(PluginServices services) {
threadPool = services.threadPool();
return List.of();
}

@Override
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
return List.of(new ActionHandler<>(TYPE, TestAction.class));
}

@Override
public List<ActionFilter> getActionFilters() {
return randomSubsetOf(
List.of(
new TestAsyncActionFilter(threadPool),
new TestAsyncActionFilter(threadPool),
new TestAsyncMappedActionFilter(threadPool),
new TestAsyncMappedActionFilter(threadPool)
)
);
}
}

private static class TestAsyncActionFilter implements ActionFilter {

private final ThreadPool threadPool;
private final int order = randomInt();

private TestAsyncActionFilter(ThreadPool threadPool) {
this.threadPool = Objects.requireNonNull(threadPool);
}

@Override
public int order() {
return order;
}

@Override
public <Req extends ActionRequest, Rsp extends ActionResponse> void apply(
Task task,
String action,
Req request,
ActionListener<Rsp> listener,
ActionFilterChain<Req, Rsp> chain
) {
if (action.equals(TYPE.name())) {
randomFrom(EsExecutors.DIRECT_EXECUTOR_SERVICE, threadPool.generic()).execute(new AbstractRunnable() {
@Override
public void onFailure(Exception e) {
fail(e);
}

@Override
protected void doRun() {
Thread.yield();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this meant to mimic work?

Copy link
Contributor Author

@DaveCTurner DaveCTurner Jan 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kinda, but really I was just using it to help make sure I had the right things covered by the tests and forgot to remove it. Gone now.

assertTrue(request.hasReferences());
if (randomBoolean()) {
chain.proceed(task, action, request, listener);
} else {
listener.onFailure(new ElasticsearchException("short-circuit failure"));
}
}
});
} else {
chain.proceed(task, action, request, listener);
}
}
}

private static class TestAsyncMappedActionFilter extends TestAsyncActionFilter implements MappedActionFilter {

private TestAsyncMappedActionFilter(ThreadPool threadPool) {
super(threadPool);
}

@Override
public String actionName() {
return TYPE.name();
}
}

public static class TestAction extends TransportAction<Request, Response> {

private final ThreadPool threadPool;

@Inject
public TestAction(TransportService transportService, ActionFilters actionFilters) {
super(TYPE.name(), actionFilters, transportService.getTaskManager());
threadPool = transportService.getThreadPool();
}

@Override
protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
request.mustIncRef();
threadPool.generic().execute(ActionRunnable.supply(ActionListener.runBefore(listener, request::decRef), () -> {
Thread.yield();
assert request.hasReferences();
return new Response();
}));
}
}

private static class Request extends ActionRequest {
private final SubscribableListener<Void> closeListeners = new SubscribableListener<>();
private final RefCounted refs = LeakTracker.wrap(AbstractRefCounted.of(() -> closeListeners.onResponse(null)));

@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void incRef() {
refs.incRef();
}

@Override
public boolean tryIncRef() {
return refs.tryIncRef();
}

@Override
public boolean decRef() {
return refs.decRef();
}

@Override
public boolean hasReferences() {
return refs.hasReferences();
}

void addCloseListener(ActionListener<Void> listener) {
closeListeners.addListener(listener);
}
}

private static class Response extends ActionResponse {
@Override
public void writeTo(StreamOutput out) {}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,6 @@ protected void doExecute(Task task, TestRequest request, ActionListener<TestResp
}

public void testTooManyContinueProcessingRequest() throws InterruptedException {
final int additionalContinueCount = randomInt(10);

RequestTestFilter testFilter = new RequestTestFilter(randomInt(), new RequestCallback() {
@Override
public <Request extends ActionRequest, Response extends ActionResponse> void execute(
Expand All @@ -146,15 +144,18 @@ public <Request extends ActionRequest, Response extends ActionResponse> void exe
ActionListener<Response> listener,
ActionFilterChain<Request, Response> actionFilterChain
) {
for (int i = 0; i <= additionalContinueCount; i++) {
actionFilterChain.proceed(task, action, request, listener);
}
// expected proceed() call:
actionFilterChain.proceed(task, action, request, listener);

// extra, invalid, proceed() call:
actionFilterChain.proceed(task, action, request, listener);
}
});

Set<ActionFilter> filters = new HashSet<>();
filters.add(testFilter);

final CountDownLatch latch = new CountDownLatch(2);
String actionName = randomAlphaOfLength(randomInt(30));
ActionFilters actionFilters = new ActionFilters(filters);
TransportAction<TestRequest, TestResponse> transportAction = new TransportAction<TestRequest, TestResponse>(
Expand All @@ -164,18 +165,16 @@ public <Request extends ActionRequest, Response extends ActionResponse> void exe
) {
@Override
protected void doExecute(Task task, TestRequest request, ActionListener<TestResponse> listener) {
listener.onResponse(new TestResponse());
latch.countDown();
}
};

final CountDownLatch latch = new CountDownLatch(additionalContinueCount + 1);
final AtomicInteger responses = new AtomicInteger();
final List<Throwable> failures = new CopyOnWriteArrayList<>();

ActionTestUtils.execute(transportAction, null, new TestRequest(), new LatchedActionListener<>(new ActionListener<>() {
@Override
public void onResponse(TestResponse testResponse) {
responses.incrementAndGet();
fail("should not complete listener");
}

@Override
Expand All @@ -191,8 +190,7 @@ public void onFailure(Exception e) {
assertThat(testFilter.runs.get(), equalTo(1));
assertThat(testFilter.lastActionName, equalTo(actionName));

assertThat(responses.get(), equalTo(1));
assertThat(failures.size(), equalTo(additionalContinueCount));
assertThat(failures.size(), equalTo(1));
for (Throwable failure : failures) {
assertThat(failure, instanceOf(IllegalStateException.class));
}
Expand Down