Skip to content

Commit

Permalink
Rationalize ref-counting around ChannelActionListener (elastic#102551)
Browse files Browse the repository at this point in the history
This listener was behaving differently from other listeners in that it
would effectively result in decRef by 1 when invoked. This worked out so
far by accounting for this fact in calling code but is not maintainable
now that more and more ref-counted things are going to be passed into
it => make it behave like any other listener and be neutral in respect
to ref count by acquiring the reference that the channel will release.

Co-authored-by: Alexander Spies <alexander.spies@elastic.co>
  • Loading branch information
original-brownbear and alex-spies authored Nov 24, 2023
1 parent 7def2e0 commit 6f72a1c
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public ChannelActionListener(TransportChannel channel) {

@Override
public void onResponse(Response response) {
response.incRef(); // acquire reference that will be released by channel.sendResponse below
ActionListener.run(this, l -> l.channel.sendResponse(response));
}

Expand Down
30 changes: 27 additions & 3 deletions server/src/main/java/org/elasticsearch/search/SearchService.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@
import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.TimeValue;
Expand Down Expand Up @@ -544,7 +546,7 @@ public void executeQueryPhase(ShardSearchRequest request, SearchShardTask task,
}));
}

private <T> void ensureAfterSeqNoRefreshed(
private <T extends RefCounted> void ensureAfterSeqNoRefreshed(
IndexShard shard,
ShardSearchRequest request,
CheckedSupplier<T, Exception> executable,
Expand Down Expand Up @@ -648,8 +650,27 @@ private IndexShard getShard(ShardSearchRequest request) {
return indicesService.indexServiceSafe(request.shardId().getIndex()).getShard(request.shardId().id());
}

private static <T> void runAsync(Executor executor, CheckedSupplier<T, Exception> executable, ActionListener<T> listener) {
executor.execute(ActionRunnable.supply(listener, executable));
private static <T extends RefCounted> void runAsync(
Executor executor,
CheckedSupplier<T, Exception> executable,
ActionListener<T> listener
) {
executor.execute(ActionRunnable.wrap(listener, new CheckedConsumer<>() {
@Override
public void accept(ActionListener<T> l) throws Exception {
var res = executable.get();
try {
l.onResponse(res);
} finally {
res.decRef();
}
}

@Override
public String toString() {
return executable.toString();
}
}));
}

/**
Expand Down Expand Up @@ -686,6 +707,7 @@ private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, SearchSh
final RescoreDocIds rescoreDocIds = context.rescoreDocIds();
context.queryResult().setRescoreDocIds(rescoreDocIds);
readerContext.setRescoreDocIds(rescoreDocIds);
// inc-ref query result because we close the SearchContext that references it in this try-with-resources block
context.queryResult().incRef();
return context.queryResult();
}
Expand Down Expand Up @@ -783,6 +805,7 @@ public void executeQueryPhase(QuerySearchRequest request, SearchShardTask task,
final RescoreDocIds rescoreDocIds = searchContext.rescoreDocIds();
queryResult.setRescoreDocIds(rescoreDocIds);
readerContext.setRescoreDocIds(rescoreDocIds);
// inc-ref query result because we close the SearchContext that references it in this try-with-resources block
queryResult.incRef();
return queryResult;
} catch (Exception e) {
Expand Down Expand Up @@ -866,6 +889,7 @@ public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, A
executor.success();
}
var fetchResult = searchContext.fetchResult();
// inc-ref fetch result because we close the SearchContext that references it in this try-with-resources block
fetchResult.incRef();
return fetchResult;
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,10 @@ public void onFailure(Exception e) {
null
),
new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()),
result
result.delegateFailure((l, r) -> {
r.incRef();
l.onResponse(r);
})
);
final SearchPhaseResult searchPhaseResult = result.get();
try {
Expand All @@ -391,7 +394,7 @@ public void onFailure(Exception e) {
);
PlainActionFuture<FetchSearchResult> listener = new PlainActionFuture<>();
service.executeFetchPhase(req, new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()), listener);
listener.get().decRef();
listener.get();
if (useScroll) {
// have to free context since this test does not remove the index from IndicesService.
service.freeReaderContext(searchPhaseResult.getContextId());
Expand Down Expand Up @@ -1048,7 +1051,6 @@ public void onResponse(SearchPhaseResult searchPhaseResult) {
// make sure that the wrapper is called when the query is actually executed
assertEquals(6, numWrapInvocations.get());
} finally {
searchPhaseResult.decRef();
latch.countDown();
}
}
Expand Down Expand Up @@ -1348,7 +1350,6 @@ public void onResponse(SearchPhaseResult result) {
assertNotNull(result.queryResult().topDocs());
assertNotNull(result.queryResult().aggregations());
} finally {
result.decRef();
latch.countDown();
}
}
Expand Down Expand Up @@ -1379,7 +1380,6 @@ public void onResponse(SearchPhaseResult result) {
assertNotNull(result.queryResult().topDocs());
assertNotNull(result.queryResult().aggregations());
} finally {
result.decRef();
latch.countDown();
}
}
Expand Down Expand Up @@ -1408,7 +1408,6 @@ public void onResponse(SearchPhaseResult result) {
assertThat(result, instanceOf(QuerySearchResult.class));
assertTrue(result.queryResult().isNull());
} finally {
result.decRef();
latch.countDown();
}
}
Expand Down Expand Up @@ -1549,7 +1548,6 @@ public void testCancelQueryPhaseEarly() throws Exception {
@Override
public void onResponse(SearchPhaseResult searchPhaseResult) {
service.freeReaderContext(searchPhaseResult.getContextId());
searchPhaseResult.decRef();
latch1.countDown();
}

Expand Down Expand Up @@ -1691,7 +1689,7 @@ public void onFailure(Exception e) {
client().clearScroll(clearScrollRequest);
}

public void testWaitOnRefresh() {
public void testWaitOnRefresh() throws ExecutionException, InterruptedException {
createIndex("index");
final SearchService service = getInstanceFromNode(SearchService.class);
final IndicesService indicesService = getInstanceFromNode(IndicesService.class);
Expand All @@ -1705,7 +1703,6 @@ public void testWaitOnRefresh() {
assertEquals(RestStatus.CREATED, response.status());

SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap());
PlainActionFuture<SearchPhaseResult> future = new PlainActionFuture<>();
ShardSearchRequest request = new ShardSearchRequest(
OriginalIndices.NONE,
searchRequest,
Expand All @@ -1719,13 +1716,12 @@ public void testWaitOnRefresh() {
null,
null
);
service.executeQueryPhase(request, task, future);
SearchPhaseResult searchPhaseResult = future.actionGet();
try {
assertEquals(1, searchPhaseResult.queryResult().getTotalHits().value);
} finally {
searchPhaseResult.decRef();
}
PlainActionFuture<Void> future = new PlainActionFuture<>();
service.executeQueryPhase(request, task, future.delegateFailure((l, r) -> {
assertEquals(1, r.queryResult().getTotalHits().value);
l.onResponse(null);
}));
future.get();
}

public void testWaitOnRefreshFailsWithRefreshesDisabled() {
Expand Down Expand Up @@ -1889,7 +1885,6 @@ public void testDfsQueryPhaseRewrite() {
-1,
null
);
PlainActionFuture<QuerySearchResult> plainActionFuture = new PlainActionFuture<>();
final Engine.SearcherSupplier reader = indexShard.acquireSearcherSupplier();
ReaderContext context = service.createAndPutReaderContext(
request,
Expand All @@ -1898,13 +1893,14 @@ public void testDfsQueryPhaseRewrite() {
reader,
SearchService.KEEPALIVE_INTERVAL_SETTING.get(Settings.EMPTY).millis()
);
PlainActionFuture<QuerySearchResult> plainActionFuture = new PlainActionFuture<>();
service.executeQueryPhase(
new QuerySearchRequest(null, context.id(), request, new AggregatedDfs(Map.of(), Map.of(), 10)),
new SearchShardTask(42L, "", "", "", null, Collections.emptyMap()),
plainActionFuture
);

plainActionFuture.actionGet().decRef();
plainActionFuture.actionGet();
assertThat(((TestRewriteCounterQueryBuilder) request.source().query()).asyncRewriteCount, equalTo(1));
final ShardSearchContextId contextId = context.id();
assertTrue(service.freeReaderContext(contextId));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,12 @@ protected void doExecute(
try (CcrRestoreSourceService.SessionReader sessionReader = restoreSourceService.getSessionReader(sessionUUID)) {
long offsetAfterRead = sessionReader.readFileBytes(fileName, reference);
long offsetBeforeRead = offsetAfterRead - reference.length();
listener.onResponse(new GetCcrRestoreFileChunkResponse(offsetBeforeRead, reference));
var chunk = new GetCcrRestoreFileChunkResponse(offsetBeforeRead, reference);
try {
listener.onResponse(chunk);
} finally {
chunk.decRef();
}
}
} catch (IOException e) {
listener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ public void testRequestedShardIdMustBeConsistentWithSessionShardId() {
final PlainActionFuture<GetCcrRestoreFileChunkAction.GetCcrRestoreFileChunkResponse> future1 = new PlainActionFuture<>();
action.doExecute(mock(Task.class), request1, future1);
// The actual response content does not matter as long as the future executes without any error
future1.actionGet().decRef();
future1.actionGet();

// 2. Inconsistent requested ShardId
final var request2 = new GetCcrRestoreFileChunkRequest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import java.util.Objects;

public final class ExchangeResponse extends TransportResponse implements Releasable {
private final RefCounted counted = AbstractRefCounted.of(this::close);
private final RefCounted counted = AbstractRefCounted.of(this::closeInternal);
private final Page page;
private final boolean finished;
private boolean pageTaken;
Expand Down Expand Up @@ -98,6 +98,10 @@ public boolean hasReferences() {

@Override
public void close() {
counted.decRef();
}

private void closeInternal() {
if (pageTaken == false && page != null) {
page.releaseBlocks();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,10 @@ private void notifyListeners() {
} finally {
promised.release();
}
onChanged();
listener.onResponse(response);
try (response) {
onChanged();
listener.onResponse(response);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,10 +412,6 @@ public void sendResponse(TransportResponse transportResponse) throws IOException
}
ExchangeResponse newResp = new ExchangeResponse(page, origResp.finished());
origResp.decRef();
while (origResp.hasReferences()) {
newResp.incRef();
origResp.decRef();
}
super.sendResponse(newResp);
}
};
Expand Down

0 comments on commit 6f72a1c

Please sign in to comment.