Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rationalize ref-counting around ChannelActionListener (2nd attempt) #102638

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public ChannelActionListener(TransportChannel channel) {

@Override
public void onResponse(Response response) {
response.incRef(); // acquire reference that will be released by channel.sendResponse below
ActionListener.run(this, l -> l.channel.sendResponse(response));
}

Expand Down
30 changes: 27 additions & 3 deletions server/src/main/java/org/elasticsearch/search/SearchService.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@
import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.TimeValue;
Expand Down Expand Up @@ -544,7 +546,7 @@ public void executeQueryPhase(ShardSearchRequest request, SearchShardTask task,
}));
}

private <T> void ensureAfterSeqNoRefreshed(
private <T extends RefCounted> void ensureAfterSeqNoRefreshed(
IndexShard shard,
ShardSearchRequest request,
CheckedSupplier<T, Exception> executable,
Expand Down Expand Up @@ -648,8 +650,27 @@ private IndexShard getShard(ShardSearchRequest request) {
return indicesService.indexServiceSafe(request.shardId().getIndex()).getShard(request.shardId().id());
}

private static <T> void runAsync(Executor executor, CheckedSupplier<T, Exception> executable, ActionListener<T> listener) {
executor.execute(ActionRunnable.supply(listener, executable));
private static <T extends RefCounted> void runAsync(
Executor executor,
CheckedSupplier<T, Exception> executable,
ActionListener<T> listener
) {
executor.execute(ActionRunnable.wrap(listener, new CheckedConsumer<>() {
@Override
public void accept(ActionListener<T> l) throws Exception {
var res = executable.get();
try {
l.onResponse(res);
} finally {
res.decRef();
}
}

@Override
public String toString() {
return executable.toString();
}
}));
}

/**
Expand Down Expand Up @@ -686,6 +707,7 @@ private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, SearchSh
final RescoreDocIds rescoreDocIds = context.rescoreDocIds();
context.queryResult().setRescoreDocIds(rescoreDocIds);
readerContext.setRescoreDocIds(rescoreDocIds);
// inc-ref query result because we close the SearchContext that references it in this try-with-resources block
context.queryResult().incRef();
return context.queryResult();
}
Expand Down Expand Up @@ -783,6 +805,7 @@ public void executeQueryPhase(QuerySearchRequest request, SearchShardTask task,
final RescoreDocIds rescoreDocIds = searchContext.rescoreDocIds();
queryResult.setRescoreDocIds(rescoreDocIds);
readerContext.setRescoreDocIds(rescoreDocIds);
// inc-ref query result because we close the SearchContext that references it in this try-with-resources block
queryResult.incRef();
return queryResult;
} catch (Exception e) {
Expand Down Expand Up @@ -866,6 +889,7 @@ public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, A
executor.success();
}
var fetchResult = searchContext.fetchResult();
// inc-ref fetch result because we close the SearchContext that references it in this try-with-resources block
fetchResult.incRef();
return fetchResult;
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,10 @@ public void onFailure(Exception e) {
null
),
new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()),
result
result.delegateFailure((l, r) -> {
r.incRef();
l.onResponse(r);
})
);
final SearchPhaseResult searchPhaseResult = result.get();
try {
Expand All @@ -391,7 +394,7 @@ public void onFailure(Exception e) {
);
PlainActionFuture<FetchSearchResult> listener = new PlainActionFuture<>();
service.executeFetchPhase(req, new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()), listener);
listener.get().decRef();
listener.get();
if (useScroll) {
// have to free context since this test does not remove the index from IndicesService.
service.freeReaderContext(searchPhaseResult.getContextId());
Expand Down Expand Up @@ -1048,7 +1051,6 @@ public void onResponse(SearchPhaseResult searchPhaseResult) {
// make sure that the wrapper is called when the query is actually executed
assertEquals(6, numWrapInvocations.get());
} finally {
searchPhaseResult.decRef();
latch.countDown();
}
}
Expand Down Expand Up @@ -1348,7 +1350,6 @@ public void onResponse(SearchPhaseResult result) {
assertNotNull(result.queryResult().topDocs());
assertNotNull(result.queryResult().aggregations());
} finally {
result.decRef();
latch.countDown();
}
}
Expand Down Expand Up @@ -1379,7 +1380,6 @@ public void onResponse(SearchPhaseResult result) {
assertNotNull(result.queryResult().topDocs());
assertNotNull(result.queryResult().aggregations());
} finally {
result.decRef();
latch.countDown();
}
}
Expand Down Expand Up @@ -1408,7 +1408,6 @@ public void onResponse(SearchPhaseResult result) {
assertThat(result, instanceOf(QuerySearchResult.class));
assertTrue(result.queryResult().isNull());
} finally {
result.decRef();
latch.countDown();
}
}
Expand Down Expand Up @@ -1549,7 +1548,6 @@ public void testCancelQueryPhaseEarly() throws Exception {
@Override
public void onResponse(SearchPhaseResult searchPhaseResult) {
service.freeReaderContext(searchPhaseResult.getContextId());
searchPhaseResult.decRef();
latch1.countDown();
}

Expand Down Expand Up @@ -1691,7 +1689,7 @@ public void onFailure(Exception e) {
client().clearScroll(clearScrollRequest);
}

public void testWaitOnRefresh() {
public void testWaitOnRefresh() throws ExecutionException, InterruptedException {
createIndex("index");
final SearchService service = getInstanceFromNode(SearchService.class);
final IndicesService indicesService = getInstanceFromNode(IndicesService.class);
Expand All @@ -1705,7 +1703,6 @@ public void testWaitOnRefresh() {
assertEquals(RestStatus.CREATED, response.status());

SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap());
PlainActionFuture<SearchPhaseResult> future = new PlainActionFuture<>();
ShardSearchRequest request = new ShardSearchRequest(
OriginalIndices.NONE,
searchRequest,
Expand All @@ -1719,13 +1716,12 @@ public void testWaitOnRefresh() {
null,
null
);
service.executeQueryPhase(request, task, future);
SearchPhaseResult searchPhaseResult = future.actionGet();
try {
assertEquals(1, searchPhaseResult.queryResult().getTotalHits().value);
} finally {
searchPhaseResult.decRef();
}
PlainActionFuture<Void> future = new PlainActionFuture<>();
service.executeQueryPhase(request, task, future.delegateFailure((l, r) -> {
assertEquals(1, r.queryResult().getTotalHits().value);
l.onResponse(null);
}));
future.get();
}

public void testWaitOnRefreshFailsWithRefreshesDisabled() {
Expand Down Expand Up @@ -1889,7 +1885,6 @@ public void testDfsQueryPhaseRewrite() {
-1,
null
);
PlainActionFuture<QuerySearchResult> plainActionFuture = new PlainActionFuture<>();
final Engine.SearcherSupplier reader = indexShard.acquireSearcherSupplier();
ReaderContext context = service.createAndPutReaderContext(
request,
Expand All @@ -1898,13 +1893,14 @@ public void testDfsQueryPhaseRewrite() {
reader,
SearchService.KEEPALIVE_INTERVAL_SETTING.get(Settings.EMPTY).millis()
);
PlainActionFuture<QuerySearchResult> plainActionFuture = new PlainActionFuture<>();
service.executeQueryPhase(
new QuerySearchRequest(null, context.id(), request, new AggregatedDfs(Map.of(), Map.of(), 10)),
new SearchShardTask(42L, "", "", "", null, Collections.emptyMap()),
plainActionFuture
);

plainActionFuture.actionGet().decRef();
plainActionFuture.actionGet();
assertThat(((TestRewriteCounterQueryBuilder) request.source().query()).asyncRewriteCount, equalTo(1));
final ShardSearchContextId contextId = context.id();
assertTrue(service.freeReaderContext(contextId));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,12 @@ protected void doExecute(
try (CcrRestoreSourceService.SessionReader sessionReader = restoreSourceService.getSessionReader(sessionUUID)) {
long offsetAfterRead = sessionReader.readFileBytes(fileName, reference);
long offsetBeforeRead = offsetAfterRead - reference.length();
listener.onResponse(new GetCcrRestoreFileChunkResponse(offsetBeforeRead, reference));
var chunk = new GetCcrRestoreFileChunkResponse(offsetBeforeRead, reference);
try {
listener.onResponse(chunk);
} finally {
chunk.decRef();
}
}
} catch (IOException e) {
listener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ public void testRequestedShardIdMustBeConsistentWithSessionShardId() {
final PlainActionFuture<GetCcrRestoreFileChunkAction.GetCcrRestoreFileChunkResponse> future1 = new PlainActionFuture<>();
action.doExecute(mock(Task.class), request1, future1);
// The actual response content does not matter as long as the future executes without any error
future1.actionGet().decRef();
future1.actionGet();

// 2. Inconsistent requested ShardId
final var request2 = new GetCcrRestoreFileChunkRequest(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.compute;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportResponse;

/**
* Wraps a {@link ChannelActionListener} and takes ownership of responses passed to
* {@link org.elasticsearch.action.ActionListener#onResponse(Object)}; the reference count will be decreased once sending is done.
*
* Deprecated: use {@link ChannelActionListener} instead and ensure responses sent to it are properly closed after.
*/
@Deprecated
public final class OwningChannelActionListener<Response extends TransportResponse> implements ActionListener<Response> {
private final ChannelActionListener<Response> listener;

public OwningChannelActionListener(TransportChannel channel) {
this.listener = new ChannelActionListener<>(channel);
}

@Override
public void onResponse(Response response) {
try {
listener.onResponse(response);
} finally {
response.decRef();
}
}

@Override
public void onFailure(Exception e) {
listener.onFailure(e);
}

@Override
public String toString() {
return "OwningChannelActionListener{" + listener + "}";
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.CompositeIndicesRequest;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.compute.OwningChannelActionListener;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
Expand Down Expand Up @@ -117,7 +117,7 @@ public Status getStatus() {
private record DriverRequestHandler(TransportService transportService) implements TransportRequestHandler<DriverRequest> {
@Override
public void messageReceived(DriverRequest request, TransportChannel channel, Task task) {
var listener = new ChannelActionListener<TransportResponse.Empty>(channel);
var listener = new OwningChannelActionListener<TransportResponse.Empty>(channel);
Driver.start(
transportService.getThreadPool().getThreadContext(),
request.executor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import java.util.Objects;

public final class ExchangeResponse extends TransportResponse implements Releasable {
private final RefCounted counted = AbstractRefCounted.of(this::close);
private final RefCounted counted = AbstractRefCounted.of(this::closeInternal);
private final Page page;
private final boolean finished;
private boolean pageTaken;
Expand Down Expand Up @@ -98,6 +98,10 @@ public boolean hasReferences() {

@Override
public void close() {
counted.decRef();
}

private void closeInternal() {
if (pageTaken == false && page != null) {
page.releaseBlocks();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionListenerResponseHandler;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.component.Lifecycle;
Expand All @@ -22,6 +21,7 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.AbstractAsyncTask;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.compute.OwningChannelActionListener;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockStreamInput;
import org.elasticsearch.core.TimeValue;
Expand Down Expand Up @@ -193,7 +193,7 @@ private class ExchangeTransportAction implements TransportRequestHandler<Exchang
@Override
public void messageReceived(ExchangeRequest request, TransportChannel channel, Task task) {
final String exchangeId = request.exchangeId();
ActionListener<ExchangeResponse> listener = new ChannelActionListener<>(channel);
ActionListener<ExchangeResponse> listener = new OwningChannelActionListener<>(channel);
final ExchangeSinkHandler sinkHandler = sinks.get(exchangeId);
if (sinkHandler == null) {
listener.onResponse(new ExchangeResponse(null, true));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,10 +412,6 @@ public void sendResponse(TransportResponse transportResponse) throws IOException
}
ExchangeResponse newResp = new ExchangeResponse(page, origResp.finished());
origResp.decRef();
while (origResp.hasReferences()) {
newResp.incRef();
origResp.decRef();
}
super.sendResponse(newResp);
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.elasticsearch.action.ActionListenerResponseHandler;
import org.elasticsearch.action.IndicesRequest;
import org.elasticsearch.action.UnavailableShardsException;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.cluster.ClusterState;
Expand All @@ -25,6 +24,7 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.compute.OwningChannelActionListener;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockStreamInput;
Expand Down Expand Up @@ -350,7 +350,7 @@ private class TransportHandler implements TransportRequestHandler<LookupRequest>
@Override
public void messageReceived(LookupRequest request, TransportChannel channel, Task task) {
request.incRef();
ActionListener<LookupResponse> listener = ActionListener.runBefore(new ChannelActionListener<>(channel), request::decRef);
ActionListener<LookupResponse> listener = ActionListener.runBefore(new OwningChannelActionListener<>(channel), request::decRef);
doLookup(
request.sessionId,
(CancellableTask) task,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionListenerResponseHandler;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.compute.OwningChannelActionListener;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportChannel;
Expand Down Expand Up @@ -110,7 +110,7 @@ public void messageReceived(ResolveRequest request, TransportChannel channel, Ta
String policyName = request.policyName;
EnrichPolicy policy = policies().get(policyName);
ThreadContext threadContext = threadPool.getThreadContext();
ActionListener<ResolveResponse> listener = new ChannelActionListener<>(channel);
ActionListener<ResolveResponse> listener = new OwningChannelActionListener<>(channel);
listener = ContextPreservingActionListener.wrapPreservingContext(listener, threadContext);
try (ThreadContext.StoredContext ignored = threadContext.stashWithOrigin(ClientHelper.ENRICH_ORIGIN)) {
indexResolver.resolveAsMergedMapping(
Expand Down
Loading