Rationalize ref-counting around ChannelActionListener #102551

Merged
Changes from 4 commits
Original file line number Diff line number Diff line change
@@ -28,6 +28,7 @@ public ChannelActionListener(TransportChannel channel) {

@Override
public void onResponse(Response response) {
response.incRef(); // acquire reference that will be released by channel.sendResponse below
Contributor

On reflection I think it'd be nicer to push this down a level or two and make TransportChannel#sendResponse refcount-neutral. In particular it's kinda weird that OutboundHandler#sendRequest is already refcount-neutral whilst OutboundHandler#sendResponse is not.

Member Author

Yeah, it's not great, but could we push that out to another PR? I just tried it and it turns out to be a rather big change in practice: there are a number of spots outside of this listener where we call channel.sendResponse() directly, and I'd have to adjust all of those. That goes beyond the scope of what I'm trying to fix here. Note that this change at least gets us closer to fixing the OutboundHandler :) OK with you?

Contributor

DaveCTurner, Nov 24, 2023

Yeah, OK, makes sense. I opened #102600 to make sure we don't forget.

Contributor

Suggested change
- response.incRef(); // acquire reference that will be released by channel.sendResponse below
+ response.mustIncRef(); // acquire reference that will be released by channel.sendResponse below

ActionListener.run(this, l -> l.channel.sendResponse(response));
}
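
The thread above contrasts a sendResponse that consumes one reference with a refcount-neutral one. The following is a minimal sketch of that difference, not Elasticsearch code: `Response`, `consumingSend`, `neutralSend` and `onResponseWithConsumingSender` are hypothetical names invented for illustration, and the counter is a simplified stand-in for the real ref-counting classes.

```java
import java.util.concurrent.atomic.AtomicInteger;

public class SendResponseSketch {

    /** Hypothetical ref-counted response; the creator holds the initial reference. */
    static final class Response {
        private final AtomicInteger refs = new AtomicInteger(1);

        void incRef() {
            // Only increment while at least one reference is still held.
            if (refs.getAndUpdate(n -> n > 0 ? n + 1 : n) <= 0) {
                throw new IllegalStateException("already released");
            }
        }

        void decRef() {
            if (refs.decrementAndGet() < 0) {
                throw new IllegalStateException("released too many times");
            }
        }

        int refCount() {
            return refs.get();
        }
    }

    /** Consuming style: the sender releases one reference, so the caller must acquire one first. */
    static void consumingSend(Response response) {
        try {
            // ... serialize and write the response (omitted in this sketch) ...
        } finally {
            response.decRef();
        }
    }

    /** Neutral style: the sender acquires and releases its own reference; the caller's count is unchanged. */
    static void neutralSend(Response response) {
        response.incRef();
        try {
            // ... serialize and write the response (omitted in this sketch) ...
        } finally {
            response.decRef();
        }
    }

    /** A listener in front of a consuming sender acquires the reference that the sender will release. */
    static void onResponseWithConsumingSender(Response response) {
        response.incRef();
        consumingSend(response);
    }

    public static void main(String[] args) {
        Response response = new Response();
        onResponseWithConsumingSender(response);
        System.out.println("after listener + consuming send: " + response.refCount()); // 1
        neutralSend(response);
        System.out.println("after neutral send: " + response.refCount()); // 1
        response.decRef(); // the creator releases its own reference
        System.out.println("after creator release: " + response.refCount()); // 0
    }
}
```

With a neutral sender the extra incRef in the listener would no longer be needed, which is the follow-up the reviewers agreed to track separately.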

30 changes: 27 additions & 3 deletions server/src/main/java/org/elasticsearch/search/SearchService.java
@@ -41,7 +41,9 @@
import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.TimeValue;
@@ -544,7 +546,7 @@ public void executeQueryPhase(ShardSearchRequest request, SearchShardTask task,
}));
}

private <T> void ensureAfterSeqNoRefreshed(
private <T extends RefCounted> void ensureAfterSeqNoRefreshed(
IndexShard shard,
ShardSearchRequest request,
CheckedSupplier<T, Exception> executable,
@@ -648,8 +650,27 @@ private IndexShard getShard(ShardSearchRequest request) {
return indicesService.indexServiceSafe(request.shardId().getIndex()).getShard(request.shardId().id());
}

private static <T> void runAsync(Executor executor, CheckedSupplier<T, Exception> executable, ActionListener<T> listener) {
executor.execute(ActionRunnable.supply(listener, executable));
Member Author

This was the problem: we implicitly assumed that the listener would always count down the reference here. That held for all current implementations, which were passed a ChannelActionListener, but it broke once I tried to use this with the multi-search response, which can go either to a ChannelActionListener or to some other listener. Those other listeners would then have to know that they need to count down by one, which would make the code needlessly brittle.

private static <T extends RefCounted> void runAsync(
Executor executor,
CheckedSupplier<T, Exception> executable,
ActionListener<T> listener
) {
executor.execute(ActionRunnable.wrap(listener, new CheckedConsumer<>() {
@Override
public void accept(ActionListener<T> l) throws Exception {
var res = executable.get();
try {
Member Author

Also, following this pattern (instantiate -> try -> finally -> decRef) as closely as possible and not implicitly requiring listeners to ever count down a ref makes the code a lot easier to follow and extend (as can be seen in the test simplifications ... mod the futures :P).

l.onResponse(res);
} finally {
res.decRef();
}
}

@Override
public String toString() {
return executable.toString();
}
}));
}
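
The "instantiate -> try -> finally -> decRef" pattern described in the comments above can be sketched in isolation as follows. This is a simplified illustration under stated assumptions, not the SearchService code: `CountedResult` and `produceAndNotify` are hypothetical names, and plain `Supplier`/`Consumer` stand in for `CheckedSupplier` and `ActionListener`.

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.Supplier;

public class RunAsyncSketch {

    /** Hypothetical stand-in for a ref-counted result; it starts with one reference. */
    static final class CountedResult {
        private final AtomicInteger refs = new AtomicInteger(1);

        void incRef() {
            // Refuse to revive an object whose last reference is already gone.
            if (refs.getAndUpdate(n -> n > 0 ? n + 1 : n) <= 0) {
                throw new IllegalStateException("already released");
            }
        }

        void decRef() {
            if (refs.decrementAndGet() < 0) {
                throw new IllegalStateException("released too many times");
            }
        }

        int refCount() {
            return refs.get();
        }
    }

    /** The producer creates the result, lends it to the listener, and always releases its own reference. */
    static void produceAndNotify(Supplier<CountedResult> producer, Consumer<CountedResult> listener) {
        CountedResult result = producer.get();
        try {
            listener.accept(result);
        } finally {
            result.decRef();
        }
    }

    public static void main(String[] args) {
        CountedResult[] seen = new CountedResult[1];

        // A listener that only reads the result does no ref-counting at all.
        produceAndNotify(CountedResult::new, r -> seen[0] = r);
        System.out.println("after read-only listener: " + seen[0].refCount()); // 0

        // A listener that keeps the result beyond the callback takes its own reference.
        produceAndNotify(CountedResult::new, r -> {
            r.incRef();
            seen[0] = r;
        });
        System.out.println("after retaining listener: " + seen[0].refCount()); // 1
        seen[0].decRef(); // the retaining listener releases when it is done
    }
}
```

The design choice is that ownership stays with whoever created the object: a listener only touches the count when it wants to keep the result past the callback, so listeners that merely read it cannot leak or double-release it.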

/**
@@ -686,6 +707,7 @@ private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, SearchSh
final RescoreDocIds rescoreDocIds = context.rescoreDocIds();
context.queryResult().setRescoreDocIds(rescoreDocIds);
readerContext.setRescoreDocIds(rescoreDocIds);
// inc-ref query result because we close the SearchContext that references it in this try-with-resources block
context.queryResult().incRef();
Contributor

Suggested change
- context.queryResult().incRef();
+ context.queryResult().mustIncRef();

return context.queryResult();
}
@@ -783,6 +805,7 @@ public void executeQueryPhase(QuerySearchRequest request, SearchShardTask task,
final RescoreDocIds rescoreDocIds = searchContext.rescoreDocIds();
queryResult.setRescoreDocIds(rescoreDocIds);
readerContext.setRescoreDocIds(rescoreDocIds);
// inc-ref query result because we close the SearchContext that references it in this try-with-resources block
queryResult.incRef();
Contributor

Suggested change
- queryResult.incRef();
+ queryResult.mustIncRef();

return queryResult;
} catch (Exception e) {
@@ -866,6 +889,7 @@ public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, A
executor.success();
}
var fetchResult = searchContext.fetchResult();
// inc-ref fetch result because we close the SearchContext that references it in this try-with-resources block
fetchResult.incRef();
Contributor

Suggested change
- fetchResult.incRef();
+ fetchResult.mustIncRef();

return fetchResult;
} catch (Exception e) {
Original file line number Diff line number Diff line change
@@ -378,7 +378,10 @@ public void onFailure(Exception e) {
null
),
new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()),
result
result.delegateFailure((l, r) -> {
r.incRef();
Member Author

This kind of hack is fine; there's no need to do anything fancy, since we never use this transport action with a future in prod.

Contributor

Suggested change
- r.incRef();
+ r.mustIncRef();

l.onResponse(r);
})
);
final SearchPhaseResult searchPhaseResult = result.get();
try {
@@ -391,7 +394,7 @@ public void onFailure(Exception e) {
);
PlainActionFuture<FetchSearchResult> listener = new PlainActionFuture<>();
service.executeFetchPhase(req, new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()), listener);
listener.get().decRef();
listener.get();
Member Author

Same here: this doesn't need to return anything with a ref count > 0; we're just using the future to wait on the action.

if (useScroll) {
// have to free context since this test does not remove the index from IndicesService.
service.freeReaderContext(searchPhaseResult.getContextId());
@@ -1048,7 +1051,6 @@ public void onResponse(SearchPhaseResult searchPhaseResult) {
// make sure that the wrapper is called when the query is actually executed
assertEquals(6, numWrapInvocations.get());
} finally {
searchPhaseResult.decRef();
latch.countDown();
}
}
@@ -1348,7 +1350,6 @@ public void onResponse(SearchPhaseResult result) {
assertNotNull(result.queryResult().topDocs());
assertNotNull(result.queryResult().aggregations());
} finally {
result.decRef();
latch.countDown();
}
}
@@ -1379,7 +1380,6 @@ public void onResponse(SearchPhaseResult result) {
assertNotNull(result.queryResult().topDocs());
assertNotNull(result.queryResult().aggregations());
} finally {
result.decRef();
latch.countDown();
}
}
@@ -1408,7 +1408,6 @@ public void onResponse(SearchPhaseResult result) {
assertThat(result, instanceOf(QuerySearchResult.class));
assertTrue(result.queryResult().isNull());
} finally {
result.decRef();
latch.countDown();
}
}
@@ -1549,7 +1548,6 @@ public void testCancelQueryPhaseEarly() throws Exception {
@Override
public void onResponse(SearchPhaseResult searchPhaseResult) {
service.freeReaderContext(searchPhaseResult.getContextId());
searchPhaseResult.decRef();
latch1.countDown();
}

@@ -1691,7 +1689,7 @@ public void onFailure(Exception e) {
client().clearScroll(clearScrollRequest);
}

public void testWaitOnRefresh() {
public void testWaitOnRefresh() throws ExecutionException, InterruptedException {
createIndex("index");
final SearchService service = getInstanceFromNode(SearchService.class);
final IndicesService indicesService = getInstanceFromNode(IndicesService.class);
@@ -1705,7 +1703,6 @@ public void testWaitOnRefresh() {
assertEquals(RestStatus.CREATED, response.status());

SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap());
PlainActionFuture<SearchPhaseResult> future = new PlainActionFuture<>();
ShardSearchRequest request = new ShardSearchRequest(
OriginalIndices.NONE,
searchRequest,
@@ -1719,13 +1716,12 @@ public void testWaitOnRefresh() {
null,
null
);
service.executeQueryPhase(request, task, future);
SearchPhaseResult searchPhaseResult = future.actionGet();
try {
assertEquals(1, searchPhaseResult.queryResult().getTotalHits().value);
} finally {
searchPhaseResult.decRef();
}
PlainActionFuture<Void> future = new PlainActionFuture<>();
service.executeQueryPhase(request, task, future.delegateFailure((l, r) -> {
assertEquals(1, r.queryResult().getTotalHits().value);
l.onResponse(null);
}));
future.get();
}
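
The rewritten test above checks the result inside the listener, while the service still holds its reference, and completes the future without handing out the ref-counted object. A minimal sketch of that idea follows, assuming `CompletableFuture` as a stand-in for `PlainActionFuture`; `Result`, `execute` and `readHitsInsideListener` are hypothetical names, and the sketch completes the future with a copied plain value where the test above uses null.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

public class AssertInsideListenerSketch {

    /** Hypothetical ref-counted result carrying one plain value. */
    static final class Result {
        final long totalHits;
        private final AtomicInteger refs = new AtomicInteger(1);

        Result(long totalHits) {
            this.totalHits = totalHits;
        }

        void decRef() {
            refs.decrementAndGet();
        }

        boolean hasReferences() {
            return refs.get() > 0;
        }
    }

    /** Hypothetical service call: lends the result to the listener, then releases its reference. */
    static void execute(Consumer<Result> onResponse, Consumer<Exception> onFailure) {
        Result result = new Result(1);
        try {
            onResponse.accept(result);
        } catch (Exception e) {
            onFailure.accept(e);
        } finally {
            result.decRef();
        }
    }

    /** Inspects the result inside the callback and completes the future with a plain value only. */
    static long readHitsInsideListener() {
        CompletableFuture<Long> done = new CompletableFuture<>();
        execute(r -> {
            if (!r.hasReferences()) {
                throw new IllegalStateException("result released before the listener ran");
            }
            done.complete(r.totalHits); // copy the value out; do not retain the ref-counted object
        }, done::completeExceptionally);
        return done.join();
    }

    public static void main(String[] args) {
        System.out.println(readHitsInsideListener()); // prints 1
    }
}
```

Because nothing ref-counted escapes the callback, the caller has no decRef to remember.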

public void testWaitOnRefreshFailsWithRefreshesDisabled() {
@@ -1889,7 +1885,6 @@ public void testDfsQueryPhaseRewrite() {
-1,
null
);
PlainActionFuture<QuerySearchResult> plainActionFuture = new PlainActionFuture<>();
final Engine.SearcherSupplier reader = indexShard.acquireSearcherSupplier();
ReaderContext context = service.createAndPutReaderContext(
request,
@@ -1898,13 +1893,14 @@ public void testDfsQueryPhaseRewrite() {
reader,
SearchService.KEEPALIVE_INTERVAL_SETTING.get(Settings.EMPTY).millis()
);
PlainActionFuture<QuerySearchResult> plainActionFuture = new PlainActionFuture<>();
service.executeQueryPhase(
new QuerySearchRequest(null, context.id(), request, new AggregatedDfs(Map.of(), Map.of(), 10)),
new SearchShardTask(42L, "", "", "", null, Collections.emptyMap()),
plainActionFuture
);

plainActionFuture.actionGet().decRef();
plainActionFuture.actionGet();
assertThat(((TestRewriteCounterQueryBuilder) request.source().query()).asyncRewriteCount, equalTo(1));
final ShardSearchContextId contextId = context.id();
assertTrue(service.freeReaderContext(contextId));
Original file line number Diff line number Diff line change
@@ -85,7 +85,12 @@ protected void doExecute(
try (CcrRestoreSourceService.SessionReader sessionReader = restoreSourceService.getSessionReader(sessionUUID)) {
long offsetAfterRead = sessionReader.readFileBytes(fileName, reference);
long offsetBeforeRead = offsetAfterRead - reference.length();
listener.onResponse(new GetCcrRestoreFileChunkResponse(offsetBeforeRead, reference));
var chunk = new GetCcrRestoreFileChunkResponse(offsetBeforeRead, reference);
try {
listener.onResponse(chunk);
} finally {
chunk.decRef();
}
}
} catch (IOException e) {
listener.onFailure(e);
Original file line number Diff line number Diff line change
@@ -126,7 +126,7 @@ public void testRequestedShardIdMustBeConsistentWithSessionShardId() {
final PlainActionFuture<GetCcrRestoreFileChunkAction.GetCcrRestoreFileChunkResponse> future1 = new PlainActionFuture<>();
action.doExecute(mock(Task.class), request1, future1);
// The actual response content does not matter as long as the future executes without any error
future1.actionGet().decRef();
future1.actionGet();

// 2. Inconsistent requested ShardId
final var request2 = new GetCcrRestoreFileChunkRequest(