Skip to content

Commit

Permalink
Move LeakTracker to production code and use it in assertions around Q…
Browse files Browse the repository at this point in the history
…uerySearchContext (#102179)

Part of the ref counted search hits effort requires us to correctly ref
count `FetchSearchPhase`. This doesn't commit moves us one step in the
direction of doing so by adding testing that ensure that `QuerySearchContext`
is ref counted correctly and fixes one production code spot where it
wasn't (albeit that spot worked out for other reasons anyways).
This is done by moving the leak tracker to production code and making
use of it selectively in case assertions are enabled.
  • Loading branch information
original-brownbear authored Nov 14, 2023
1 parent 29aa74b commit aaa6e78
Show file tree
Hide file tree
Showing 13 changed files with 927 additions and 686 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.internal.ShardSearchContextId;
import org.elasticsearch.search.query.QuerySearchResult;
import org.elasticsearch.transport.LeakTracker;

import java.io.IOException;

Expand All @@ -26,14 +27,8 @@ public final class QueryFetchSearchResult extends SearchPhaseResult {
private final RefCounted refCounted;

public QueryFetchSearchResult(StreamInput in) throws IOException {
super(in);
// These get a ref count of 1 when we create them, so we don't need to incRef here
queryResult = new QuerySearchResult(in);
fetchResult = new FetchSearchResult(in);
refCounted = AbstractRefCounted.of(() -> {
queryResult.decRef();
fetchResult.decRef();
});
this(new QuerySearchResult(in), new FetchSearchResult(in));
}

public QueryFetchSearchResult(QuerySearchResult queryResult, FetchSearchResult fetchResult) {
Expand All @@ -42,10 +37,10 @@ public QueryFetchSearchResult(QuerySearchResult queryResult, FetchSearchResult f
// We're acquiring a copy, we should incRef it
this.queryResult.incRef();
this.fetchResult.incRef();
refCounted = AbstractRefCounted.of(() -> {
refCounted = LeakTracker.wrap(AbstractRefCounted.of(() -> {
queryResult.decRef();
fetchResult.decRef();
});
}));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,14 @@ static void executeRank(SearchContext searchContext) throws QueryPhaseExecutionE
if (searchTimedOut) {
break;
}
RankSearchContext rankSearchContext = new RankSearchContext(searchContext, rankQuery, rankShardContext.windowSize());
QueryPhase.addCollectorsAndSearch(rankSearchContext);
QuerySearchResult rrfQuerySearchResult = rankSearchContext.queryResult();
rrfRankResults.add(rrfQuerySearchResult.topDocs().topDocs);
serviceTimeEWMA += rrfQuerySearchResult.serviceTimeEWMA();
nodeQueueSize = Math.max(nodeQueueSize, rrfQuerySearchResult.nodeQueueSize());
searchTimedOut = rrfQuerySearchResult.searchTimedOut();
try (RankSearchContext rankSearchContext = new RankSearchContext(searchContext, rankQuery, rankShardContext.windowSize())) {
QueryPhase.addCollectorsAndSearch(rankSearchContext);
QuerySearchResult rrfQuerySearchResult = rankSearchContext.queryResult();
rrfRankResults.add(rrfQuerySearchResult.topDocs().topDocs);
serviceTimeEWMA += rrfQuerySearchResult.serviceTimeEWMA();
nodeQueueSize = Math.max(nodeQueueSize, rrfQuerySearchResult.nodeQueueSize());
searchTimedOut = rrfQuerySearchResult.searchTimedOut();
}
}

querySearchResult.setRankShardResult(rankShardContext.combine(rrfRankResults));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.elasticsearch.search.profile.SearchProfileQueryPhaseResult;
import org.elasticsearch.search.rank.RankShardResult;
import org.elasticsearch.search.suggest.Suggest;
import org.elasticsearch.transport.LeakTracker;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -104,8 +105,8 @@ public QuerySearchResult(ShardSearchContextId contextId, SearchShardTarget shard
setSearchShardTarget(shardTarget);
isNull = false;
setShardSearchRequest(shardSearchRequest);
this.refCounted = AbstractRefCounted.of(this::close);
this.toRelease = new ArrayList<>();
this.refCounted = LeakTracker.wrap(AbstractRefCounted.of(() -> Releasables.close(toRelease)));
}

private QuerySearchResult(boolean isNull) {
Expand Down Expand Up @@ -245,10 +246,6 @@ public void releaseAggs() {
}
}

private void close() {
Releasables.close(toRelease);
}

public void addReleasable(Releasable releasable) {
toRelease.add(releasable);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ public RankSearchContext(SearchContext parent, Query rankQuery, int windowSize)
this.rankQuery = parent.buildFilteredQuery(rankQuery);
this.windowSize = windowSize;
this.querySearchResult = new QuerySearchResult(parent.readerContext().id(), parent.shardTarget(), parent.request());
this.addReleasable(querySearchResult::decRef);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.Assertions;
import org.elasticsearch.core.RefCounted;

import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
Expand Down Expand Up @@ -69,6 +71,41 @@ public void reportLeak() {
}
}

public static RefCounted wrap(RefCounted refCounted) {
if (Assertions.ENABLED == false) {
return refCounted;
}
var leak = INSTANCE.track(refCounted);
return new RefCounted() {
@Override
public void incRef() {
leak.record();
refCounted.incRef();
}

@Override
public boolean tryIncRef() {
leak.record();
return refCounted.tryIncRef();
}

@Override
public boolean decRef() {
if (refCounted.decRef()) {
leak.close(refCounted);
return true;
}
leak.record();
return false;
}

@Override
public boolean hasReferences() {
return refCounted.hasReferences();
}
};
}

public static final class Leak<T> extends WeakReference<Object> {

@SuppressWarnings({ "unchecked", "rawtypes" })
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,30 +81,38 @@ public void sendExecuteQuery(
new SearchShardTarget("node1", new ShardId("test", "na", 0), null),
null
);
queryResult.topDocs(
new TopDocsAndMaxScore(
new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(42, 1.0F) }),
2.0F
),
new DocValueFormat[0]
);
queryResult.size(2); // the size of the result set
listener.onResponse(queryResult);
try {
queryResult.topDocs(
new TopDocsAndMaxScore(
new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(42, 1.0F) }),
2.0F
),
new DocValueFormat[0]
);
queryResult.size(2); // the size of the result set
listener.onResponse(queryResult);
} finally {
queryResult.decRef();
}
} else if (request.contextId().getId() == 2) {
QuerySearchResult queryResult = new QuerySearchResult(
new ShardSearchContextId("", 123),
new SearchShardTarget("node2", new ShardId("test", "na", 0), null),
null
);
queryResult.topDocs(
new TopDocsAndMaxScore(
new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(84, 2.0F) }),
2.0F
),
new DocValueFormat[0]
);
queryResult.size(2); // the size of the result set
listener.onResponse(queryResult);
try {
queryResult.topDocs(
new TopDocsAndMaxScore(
new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(84, 2.0F) }),
2.0F
),
new DocValueFormat[0]
);
queryResult.size(2); // the size of the result set
listener.onResponse(queryResult);
} finally {
queryResult.decRef();
}
} else {
fail("no such request ID: " + request.contextId());
}
Expand Down Expand Up @@ -172,15 +180,19 @@ public void sendExecuteQuery(
new SearchShardTarget("node1", new ShardId("test", "na", 0), null),
null
);
queryResult.topDocs(
new TopDocsAndMaxScore(
new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(42, 1.0F) }),
2.0F
),
new DocValueFormat[0]
);
queryResult.size(2); // the size of the result set
listener.onResponse(queryResult);
try {
queryResult.topDocs(
new TopDocsAndMaxScore(
new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(42, 1.0F) }),
2.0F
),
new DocValueFormat[0]
);
queryResult.size(2); // the size of the result set
listener.onResponse(queryResult);
} finally {
queryResult.decRef();
}
} else if (request.contextId().getId() == 2) {
listener.onFailure(new MockDirectoryWrapper.FakeIOException());
} else {
Expand Down Expand Up @@ -252,15 +264,19 @@ public void sendExecuteQuery(
new SearchShardTarget("node1", new ShardId("test", "na", 0), null),
null
);
queryResult.topDocs(
new TopDocsAndMaxScore(
new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(42, 1.0F) }),
2.0F
),
new DocValueFormat[0]
);
queryResult.size(2); // the size of the result set
listener.onResponse(queryResult);
try {
queryResult.topDocs(
new TopDocsAndMaxScore(
new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(42, 1.0F) }),
2.0F
),
new DocValueFormat[0]
);
queryResult.size(2); // the size of the result set
listener.onResponse(queryResult);
} finally {
queryResult.decRef();
}
} else if (request.contextId().getId() == 2) {
throw new UncheckedIOException(new MockDirectoryWrapper.FakeIOException());
} else {
Expand Down
Loading

0 comments on commit aaa6e78

Please sign in to comment.