Skip to content

Commit

Permalink
Fix ref counting when using futures in AbstractClient (#102498)
Browse files Browse the repository at this point in the history
Found this now that #102030 is getting closer to completion. The way the
client deals with ref-counted messages is fundamentally broken. We
were only saved by the fact that we currently do not handle any
messages with non-noop ref counting through this client interface.

We fundamentally need to increment the ref count by one before assigning
a ref counted value to a future result. Otherwise, transport actions
will have to understand the specific kind of listener they resolve and
cannot decrement sent/consumed transport messages themselves.

Marking as non-issue since this hasn't caused any production trouble yet.
  • Loading branch information
original-brownbear authored Nov 23, 2023
1 parent dee5b6f commit 659b236
Show file tree
Hide file tree
Showing 11 changed files with 73 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,15 @@ public int read() throws IOException {
public void testIndexChunksNoData() throws IOException {
client.addHandler(FlushAction.INSTANCE, (FlushRequest request, ActionListener<FlushResponse> flushResponseActionListener) -> {
assertArrayEquals(new String[] { GeoIpDownloader.DATABASES_INDEX }, request.indices());
flushResponseActionListener.onResponse(mock(FlushResponse.class));
var flushResponse = mock(FlushResponse.class);
when(flushResponse.hasReferences()).thenReturn(true);
flushResponseActionListener.onResponse(flushResponse);
});
client.addHandler(RefreshAction.INSTANCE, (RefreshRequest request, ActionListener<RefreshResponse> flushResponseActionListener) -> {
assertArrayEquals(new String[] { GeoIpDownloader.DATABASES_INDEX }, request.indices());
flushResponseActionListener.onResponse(mock(RefreshResponse.class));
var refreshResponse = mock(RefreshResponse.class);
when(refreshResponse.hasReferences()).thenReturn(true);
flushResponseActionListener.onResponse(refreshResponse);
});

InputStream empty = new ByteArrayInputStream(new byte[0]);
Expand All @@ -194,11 +198,15 @@ public void testIndexChunksNoData() throws IOException {
public void testIndexChunksMd5Mismatch() {
client.addHandler(FlushAction.INSTANCE, (FlushRequest request, ActionListener<FlushResponse> flushResponseActionListener) -> {
assertArrayEquals(new String[] { GeoIpDownloader.DATABASES_INDEX }, request.indices());
flushResponseActionListener.onResponse(mock(FlushResponse.class));
var flushResponse = mock(FlushResponse.class);
when(flushResponse.hasReferences()).thenReturn(true);
flushResponseActionListener.onResponse(flushResponse);
});
client.addHandler(RefreshAction.INSTANCE, (RefreshRequest request, ActionListener<RefreshResponse> flushResponseActionListener) -> {
assertArrayEquals(new String[] { GeoIpDownloader.DATABASES_INDEX }, request.indices());
flushResponseActionListener.onResponse(mock(RefreshResponse.class));
var refreshResponse = mock(RefreshResponse.class);
when(refreshResponse.hasReferences()).thenReturn(true);
flushResponseActionListener.onResponse(refreshResponse);
});

IOException exception = expectThrows(
Expand Down Expand Up @@ -230,15 +238,21 @@ public void testIndexChunks() throws IOException {
assertEquals("test", source.get("name"));
assertArrayEquals(chunksData[chunk], (byte[]) source.get("data"));
assertEquals(chunk + 15, source.get("chunk"));
listener.onResponse(mock(IndexResponse.class));
var indexResponse = mock(IndexResponse.class);
when(indexResponse.hasReferences()).thenReturn(true);
listener.onResponse(indexResponse);
});
client.addHandler(FlushAction.INSTANCE, (FlushRequest request, ActionListener<FlushResponse> flushResponseActionListener) -> {
assertArrayEquals(new String[] { GeoIpDownloader.DATABASES_INDEX }, request.indices());
flushResponseActionListener.onResponse(mock(FlushResponse.class));
var flushResponse = mock(FlushResponse.class);
when(flushResponse.hasReferences()).thenReturn(true);
flushResponseActionListener.onResponse(flushResponse);
});
client.addHandler(RefreshAction.INSTANCE, (RefreshRequest request, ActionListener<RefreshResponse> flushResponseActionListener) -> {
assertArrayEquals(new String[] { GeoIpDownloader.DATABASES_INDEX }, request.indices());
flushResponseActionListener.onResponse(mock(RefreshResponse.class));
var refreshResponse = mock(RefreshResponse.class);
when(refreshResponse.hasReferences()).thenReturn(true);
flushResponseActionListener.onResponse(refreshResponse);
});

InputStream big = new ByteArrayInputStream(bigArray);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,18 +199,21 @@ private void testCase(
logger.info("Starting request");
ActionFuture<BulkByScrollResponse> responseListener = builder.execute();

BulkByScrollResponse response = null;
try {
logger.info("Waiting for bulk rejections");
assertBusy(() -> assertThat(taskStatus(action).getBulkRetries(), greaterThan(0L)));
bulkBlock.await();

logger.info("Waiting for the request to finish");
BulkByScrollResponse response = responseListener.get();
response = responseListener.get();
assertThat(response, matcher);
assertThat(response.getBulkRetries(), greaterThan(0L));
} finally {
// Fetch the response just in case we blew up half way through. This will make sure the failure is thrown up to the top level.
BulkByScrollResponse response = responseListener.get();
if (response == null) {
response = responseListener.get();
}
assertThat(response.getSearchFailures(), empty());
assertThat(response.getBulkFailures(), empty());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,6 @@ public void testCancelTaskMultipleTimes() throws Exception {
assertFalse(cancelFuture.isDone());
allowEntireRequest(rootRequest);
assertThat(cancelFuture.actionGet().getTaskFailures(), empty());
assertThat(cancelFuture.actionGet().getTaskFailures(), empty());
waitForRootTask(mainTaskFuture, false);
CancelTasksResponse cancelError = clusterAdmin().prepareCancelTasks()
.setTargetTaskId(taskId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,9 @@ public void testNoRefreshInterval() throws InterruptedException, ExecutionExcept
while (false == index.isDone()) {
indicesAdmin().prepareRefresh("test").get();
}
assertEquals(RestStatus.CREATED, index.get().status());
assertFalse("request shouldn't have forced a refresh", index.get().forcedRefresh());
var response = index.get();
assertEquals(RestStatus.CREATED, response.status());
assertFalse("request shouldn't have forced a refresh", response.forcedRefresh());
assertSearchHits(prepareSearch("test").setQuery(matchQuery("foo", "bar")), "1");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -670,8 +670,9 @@ public void testConcurrentCreateAndStatusAPICalls() throws Exception {
}, 60, TimeUnit.SECONDS);

for (ActionFuture<SnapshotsStatusResponse> status : statuses) {
assertThat(status.get().getSnapshots(), hasSize(snapshots));
for (SnapshotStatus snapshot : status.get().getSnapshots()) {
var statusResponse = status.get();
assertThat(statusResponse.getSnapshots(), hasSize(snapshots));
for (SnapshotStatus snapshot : statusResponse.getSnapshots()) {
assertThat(snapshot.getState(), allOf(not(SnapshotsInProgress.State.FAILED), not(SnapshotsInProgress.State.ABORTED)));
for (final var shard : snapshot.getShards()) {
if (shard.getStage() == SnapshotIndexShardStage.DONE) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,11 +320,14 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;

import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;

public abstract class AbstractClient implements Client {

Expand Down Expand Up @@ -361,7 +364,7 @@ public final <Request extends ActionRequest, Response extends ActionResponse> Ac
ActionType<Response> action,
Request request
) {
PlainActionFuture<Response> actionFuture = new PlainActionFuture<>();
PlainActionFuture<Response> actionFuture = new RefCountedFuture<>();
execute(action, request, actionFuture);
return actionFuture;
}
Expand Down Expand Up @@ -1598,4 +1601,34 @@ protected <Request extends ActionRequest, Response extends ActionResponse> void
}
};
}

/**
 * Same as {@link PlainActionFuture} but for use with {@link RefCounted} result types. Unlike {@code PlainActionFuture} this future
 * acquires a reference to its result. This means that the result reference must be released by a call to {@link RefCounted#decRef()}
 * on the result before it goes out of scope.
 * <p>
 * {@link #get()} may be called at most once per instance: the future acquires exactly one reference on behalf of the consumer, so a
 * second call would hand out the result without a matching reference to release.
 *
 * @param <R> reference counted result type
 */
private static class RefCountedFuture<R extends RefCounted> extends PlainActionFuture<R> {

    /** Flipped to {@code true} by the first {@link #get()} call; used to reject any further calls. */
    private final AtomicBoolean getCalled = new AtomicBoolean(false);

    /**
     * Completes this future with {@code result}, acquiring one additional reference on it for the consumer of the future.
     *
     * @param result the response to complete with; asserted to still hold references when passed in
     */
    @Override
    public final void onResponse(R result) {
        assert result.hasReferences();
        // Acquire the consumer's reference BEFORE publishing the result through set(...). If the reference were taken only after
        // set(...) succeeded, a thread waiting in get() could observe the result and decRef() it before this thread's incRef() ran,
        // leaving the result released (or the incRef failing) while it is still in use.
        result.incRef();
        if (set(result) == false) {
            // set(...) did not accept this result (presumably the future was already completed), so nobody will ever read it from
            // this future: give back the reference acquired above to avoid leaking it.
            result.decRef();
        }
    }

    /**
     * Waits for and returns the result. The caller takes over the reference acquired in {@link #onResponse} and must release it via
     * {@link RefCounted#decRef()}.
     *
     * @return the result this future was completed with
     * @throws IllegalStateException if called more than once on the same instance (an {@link AssertionError} when assertions are
     *                               enabled)
     * @throws InterruptedException  propagated from {@code super.get()}
     * @throws ExecutionException    propagated from {@code super.get()}
     */
    @Override
    public R get() throws InterruptedException, ExecutionException {
        final boolean firstCall = getCalled.compareAndSet(false, true);
        if (firstCall == false) {
            final IllegalStateException ise = new IllegalStateException("must only call .get() once per instance to avoid leaks");
            assert false : ise;
            throw ise;
        }
        return super.get();
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ protected <Request extends ActionRequest, Response extends ActionResponse> void
ActionListener<Response> listener
) {
assertEquals(origin, threadPool().getThreadContext().getTransient(ThreadContext.ACTION_ORIGIN_TRANSIENT_NAME));
super.doExecute(action, request, listener);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ protected <Request extends ActionRequest, Response extends ActionResponse> void
ActionListener<Response> listener
) {
assertEquals(parentTaskId[0], request.getParentTask());
super.doExecute(action, request, listener);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ private void givenClearScrollRequest() {
doAnswer(invocationOnMock -> {
ActionListener<ClearScrollResponse> listener = (ActionListener<ClearScrollResponse>) invocationOnMock.getArguments()[2];
wasScrollCleared = true;
listener.onResponse(mock(ClearScrollResponse.class));
var clearScrollResponse = mock(ClearScrollResponse.class);
when(clearScrollResponse.hasReferences()).thenReturn(true);
listener.onResponse(clearScrollResponse);
return null;
}).when(client).execute(eq(ClearScrollAction.INSTANCE), any(), any());
}
Expand Down Expand Up @@ -171,6 +173,7 @@ ResponsesMocker addBatch(String... hits) {
protected SearchResponse createSearchResponseWithHits(String... hits) {
SearchHits searchHits = createHits(hits);
SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.hasReferences()).thenReturn(true);
when(searchResponse.getScrollId()).thenReturn(SCROLL_ID);
when(searchResponse.getHits()).thenReturn(searchHits);
return searchResponse;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ void stopExecutor() {}

// response setup, successful refresh response
RefreshResponse refreshResponse = mock(RefreshResponse.class);
when(refreshResponse.hasReferences()).thenReturn(true);
when(refreshResponse.getSuccessfulShards()).thenReturn(
clusterState.getMetadata().getIndices().get(Watch.INDEX).getNumberOfShards()
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ public void testFindTriggeredWatchesGoodCase() {
SearchResponse searchResponse1 = mock(SearchResponse.class);
when(searchResponse1.getSuccessfulShards()).thenReturn(1);
when(searchResponse1.getTotalShards()).thenReturn(1);
when(searchResponse1.hasReferences()).thenReturn(true);
BytesArray source = new BytesArray("{}");
SearchHit hit = new SearchHit(0, "first_foo");
hit.version(1L);
Expand Down Expand Up @@ -512,6 +513,7 @@ private RefreshResponse mockRefreshResponse(int total, int successful) {
RefreshResponse refreshResponse = mock(RefreshResponse.class);
when(refreshResponse.getTotalShards()).thenReturn(total);
when(refreshResponse.getSuccessfulShards()).thenReturn(successful);
when(refreshResponse.hasReferences()).thenReturn(true);
return refreshResponse;
}
}

0 comments on commit 659b236

Please sign in to comment.