Skip to content

Commit

Permalink
Introduce incremental reduction of TopDocs (#23946)
Browse files Browse the repository at this point in the history
This commit adds support for incremental top N reduction if the number of
expected shards in the search request is high enough. The changes here
also clean up more code in SearchPhaseController to make the separation
between values that are the same on each search result and values that
are per response. The reduced search phase result no longer holds on to an
arbitrary shard result to obtain values like `from`, `size` or sort values;
these are now cleanly encapsulated.
  • Loading branch information
s1monw authored Apr 10, 2017
1 parent b636ca7 commit 1f40f8a
Show file tree
Hide file tree
Showing 11 changed files with 490 additions and 227 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,27 +98,26 @@ private void innerRun() throws IOException {
final int numShards = context.getNumShards();
final boolean isScrollSearch = context.getRequest().scroll() != null;
List<SearchPhaseResult> phaseResults = queryResults.asList();
ScoreDoc[] sortedShardDocs = searchPhaseController.sortDocs(isScrollSearch, phaseResults);
String scrollId = isScrollSearch ? TransportSearchHelper.buildScrollId(queryResults) : null;
final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = resultConsumer.reduce();
final boolean queryAndFetchOptimization = queryResults.length() == 1;
final Runnable finishPhase = ()
-> moveToNextPhase(searchPhaseController, sortedShardDocs, scrollId, reducedQueryPhase, queryAndFetchOptimization ?
-> moveToNextPhase(searchPhaseController, scrollId, reducedQueryPhase, queryAndFetchOptimization ?
queryResults : fetchResults);
if (queryAndFetchOptimization) {
assert phaseResults.isEmpty() || phaseResults.get(0).fetchResult() != null;
// query AND fetch optimization
finishPhase.run();
} else {
final IntArrayList[] docIdsToLoad = searchPhaseController.fillDocIdsToLoad(numShards, sortedShardDocs);
if (sortedShardDocs.length == 0) { // no docs to fetch -- sidestep everything and return
final IntArrayList[] docIdsToLoad = searchPhaseController.fillDocIdsToLoad(numShards, reducedQueryPhase.scoreDocs);
if (reducedQueryPhase.scoreDocs.length == 0) { // no docs to fetch -- sidestep everything and return
phaseResults.stream()
.map(e -> e.queryResult())
.forEach(this::releaseIrrelevantSearchContext); // we have to release contexts here to free up resources
finishPhase.run();
} else {
final ScoreDoc[] lastEmittedDocPerShard = isScrollSearch ?
searchPhaseController.getLastEmittedDocPerShard(reducedQueryPhase, sortedShardDocs, numShards)
searchPhaseController.getLastEmittedDocPerShard(reducedQueryPhase, numShards)
: null;
final CountedCollector<FetchSearchResult> counter = new CountedCollector<>(r -> fetchResults.set(r.getShardIndex(), r),
docIdsToLoad.length, // we count down every shard in the result no matter if we got any results or not
Expand Down Expand Up @@ -188,7 +187,7 @@ public void onFailure(Exception e) {
private void releaseIrrelevantSearchContext(QuerySearchResult queryResult) {
// we only release search context that we did not fetch from if we are not scrolling
// and if it has at least one hit that didn't make it to the global topDocs
if (context.getRequest().scroll() == null && queryResult.hasHits()) {
if (context.getRequest().scroll() == null && queryResult.hasSearchContext()) {
try {
Transport.Connection connection = context.getConnection(queryResult.getSearchShardTarget().getNodeId());
context.sendReleaseSearchContext(queryResult.getRequestId(), connection);
Expand All @@ -198,11 +197,11 @@ private void releaseIrrelevantSearchContext(QuerySearchResult queryResult) {
}
}

private void moveToNextPhase(SearchPhaseController searchPhaseController, ScoreDoc[] sortedDocs,
private void moveToNextPhase(SearchPhaseController searchPhaseController,
String scrollId, SearchPhaseController.ReducedQueryPhase reducedQueryPhase,
AtomicArray<? extends SearchPhaseResult> fetchResultsArr) {
final InternalSearchResponse internalResponse = searchPhaseController.merge(context.getRequest().scroll() != null,
sortedDocs, reducedQueryPhase, fetchResultsArr.asList(), fetchResultsArr::get);
reducedQueryPhase, fetchResultsArr.asList(), fetchResultsArr::get);
context.executeNextPhase(this, nextPhaseFactory.apply(context.buildSearchResponse(internalResponse, scrollId)));
}

Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,8 @@ private void finishHim() {

private void innerFinishHim() throws Exception {
List<QueryFetchSearchResult> queryFetchSearchResults = queryFetchResults.asList();
ScoreDoc[] sortedShardDocs = searchPhaseController.sortDocs(true, queryFetchResults.asList());
final InternalSearchResponse internalResponse = searchPhaseController.merge(true, sortedShardDocs,
searchPhaseController.reducedQueryPhase(queryFetchSearchResults), queryFetchSearchResults, queryFetchResults::get);
final InternalSearchResponse internalResponse = searchPhaseController.merge(true,
searchPhaseController.reducedQueryPhase(queryFetchSearchResults, true), queryFetchSearchResults, queryFetchResults::get);
String scrollId = null;
if (request.scroll() != null) {
scrollId = request.scrollId();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ final class SearchScrollQueryThenFetchAsyncAction extends AbstractAsyncAction {
private volatile AtomicArray<ShardSearchFailure> shardFailures;
final AtomicArray<QuerySearchResult> queryResults;
final AtomicArray<FetchSearchResult> fetchResults;
private volatile ScoreDoc[] sortedShardDocs;
private final AtomicInteger successfulOps;

SearchScrollQueryThenFetchAsyncAction(Logger logger, ClusterService clusterService, SearchTransportService searchTransportService,
Expand Down Expand Up @@ -171,16 +170,15 @@ void onQueryPhaseFailure(final int shardIndex, final CountDown counter, final lo
}

private void executeFetchPhase() throws Exception {
sortedShardDocs = searchPhaseController.sortDocs(true, queryResults.asList());
if (sortedShardDocs.length == 0) {
finishHim(searchPhaseController.reducedQueryPhase(queryResults.asList()));
final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedQueryPhase(queryResults.asList(),
true);
if (reducedQueryPhase.scoreDocs.length == 0) {
finishHim(reducedQueryPhase);
return;
}

final IntArrayList[] docIdsToLoad = searchPhaseController.fillDocIdsToLoad(queryResults.length(), sortedShardDocs);
SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedQueryPhase(queryResults.asList());
final ScoreDoc[] lastEmittedDocPerShard = searchPhaseController.getLastEmittedDocPerShard(reducedQueryPhase, sortedShardDocs,
queryResults.length());
final IntArrayList[] docIdsToLoad = searchPhaseController.fillDocIdsToLoad(queryResults.length(), reducedQueryPhase.scoreDocs);
final ScoreDoc[] lastEmittedDocPerShard = searchPhaseController.getLastEmittedDocPerShard(reducedQueryPhase, queryResults.length());
final CountDown counter = new CountDown(docIdsToLoad.length);
for (int i = 0; i < docIdsToLoad.length; i++) {
final int index = i;
Expand Down Expand Up @@ -222,8 +220,8 @@ public void onFailure(Exception t) {

private void finishHim(SearchPhaseController.ReducedQueryPhase queryPhase) {
try {
final InternalSearchResponse internalResponse = searchPhaseController.merge(true, sortedShardDocs, queryPhase,
fetchResults.asList(), fetchResults::get);
final InternalSearchResponse internalResponse = searchPhaseController.merge(true, queryPhase, fetchResults.asList(),
fetchResults::get);
String scrollId = null;
if (request.scroll() != null) {
scrollId = request.scrollId();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,6 @@
* to get the concrete values as a list using {@link #asList()}.
*/
public class AtomicArray<E> {

private static final AtomicArray EMPTY = new AtomicArray(0);

@SuppressWarnings("unchecked")
public static <E> E empty() {
return (E) EMPTY;
}

private final AtomicReferenceArray<E> array;
private volatile List<E> nonNullList;

Expand All @@ -53,7 +45,6 @@ public int length() {
return array.length();
}


/**
* Sets the element at position {@code i} to the given value.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ public SearchPhaseResult executeQueryPhase(ShardSearchRequest request, SearchTas

loadOrExecuteQueryPhase(request, context);

if (context.queryResult().hasHits() == false && context.scrollContext() == null) {
if (context.queryResult().hasSearchContext() == false && context.scrollContext() == null) {
freeContext(context.id());
} else {
contextProcessedSuccessfully(context);
Expand Down Expand Up @@ -341,7 +341,7 @@ public QuerySearchResult executeQueryPhase(QuerySearchRequest request, SearchTas
operationListener.onPreQueryPhase(context);
long time = System.nanoTime();
queryPhase.execute(context);
if (context.queryResult().hasHits() == false && context.scrollContext() == null) {
if (context.queryResult().hasSearchContext() == false && context.scrollContext() == null) {
// no hits, we can release the context since there will be no fetch phase
freeContext(context.id());
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ public void execute(SearchContext context) {
fetchSubPhase.hitsExecute(context, hits);
}

context.fetchResult().hits(new SearchHits(hits, context.queryResult().topDocs().totalHits, context.queryResult().topDocs().getMaxScore()));
context.fetchResult().hits(new SearchHits(hits, context.queryResult().getTotalHits(), context.queryResult().getMaxScore()));
}

private int findRootDocumentIfNested(SearchContext context, LeafReaderContext subReaderContext, int subDocId) throws IOException {
Expand Down
59 changes: 25 additions & 34 deletions core/src/main/java/org/elasticsearch/search/query/QueryPhase.java
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ static boolean execute(SearchContext searchContext, final IndexSearcher searcher
queryResult.searchTimedOut(false);

final boolean doProfile = searchContext.getProfilers() != null;
final SearchType searchType = searchContext.searchType();
boolean rescore = false;
try {
queryResult.from(searchContext.from());
Expand All @@ -165,12 +164,7 @@ static boolean execute(SearchContext searchContext, final IndexSearcher searcher
if (searchContext.getProfilers() != null) {
collector = new InternalProfileCollector(collector, CollectorResult.REASON_SEARCH_COUNT, Collections.emptyList());
}
topDocsCallable = new Callable<TopDocs>() {
@Override
public TopDocs call() throws Exception {
return new TopDocs(totalHitCountCollector.getTotalHits(), Lucene.EMPTY_SCORE_DOCS, 0);
}
};
topDocsCallable = () -> new TopDocs(totalHitCountCollector.getTotalHits(), Lucene.EMPTY_SCORE_DOCS, 0);
} else {
// Perhaps have a dedicated scroll phase?
final ScrollContext scrollContext = searchContext.scrollContext();
Expand Down Expand Up @@ -238,38 +232,35 @@ public TopDocs call() throws Exception {
if (doProfile) {
collector = new InternalProfileCollector(collector, CollectorResult.REASON_SEARCH_TOP_HITS, Collections.emptyList());
}
topDocsCallable = new Callable<TopDocs>() {
@Override
public TopDocs call() throws Exception {
final TopDocs topDocs;
if (topDocsCollector instanceof TopDocsCollector) {
topDocs = ((TopDocsCollector<?>) topDocsCollector).topDocs();
} else if (topDocsCollector instanceof CollapsingTopDocsCollector) {
topDocs = ((CollapsingTopDocsCollector) topDocsCollector).getTopDocs();
topDocsCallable = () -> {
final TopDocs topDocs;
if (topDocsCollector instanceof TopDocsCollector) {
topDocs = ((TopDocsCollector<?>) topDocsCollector).topDocs();
} else if (topDocsCollector instanceof CollapsingTopDocsCollector) {
topDocs = ((CollapsingTopDocsCollector) topDocsCollector).getTopDocs();
} else {
throw new IllegalStateException("Unknown top docs collector " + topDocsCollector.getClass().getName());
}
if (scrollContext != null) {
if (scrollContext.totalHits == -1) {
// first round
scrollContext.totalHits = topDocs.totalHits;
scrollContext.maxScore = topDocs.getMaxScore();
} else {
throw new IllegalStateException("Unknown top docs collector " + topDocsCollector.getClass().getName());
// subsequent round: the total number of hits and
// the maximum score were computed on the first round
topDocs.totalHits = scrollContext.totalHits;
topDocs.setMaxScore(scrollContext.maxScore);
}
if (scrollContext != null) {
if (scrollContext.totalHits == -1) {
// first round
scrollContext.totalHits = topDocs.totalHits;
scrollContext.maxScore = topDocs.getMaxScore();
} else {
// subsequent round: the total number of hits and
// the maximum score were computed on the first round
topDocs.totalHits = scrollContext.totalHits;
topDocs.setMaxScore(scrollContext.maxScore);
}
if (searchContext.request().numberOfShards() == 1) {
// if we fetch the document in the same roundtrip, we already know the last emitted doc
if (topDocs.scoreDocs.length > 0) {
// set the last emitted doc
scrollContext.lastEmittedDoc = topDocs.scoreDocs[topDocs.scoreDocs.length - 1];
}
if (searchContext.request().numberOfShards() == 1) {
// if we fetch the document in the same roundtrip, we already know the last emitted doc
if (topDocs.scoreDocs.length > 0) {
// set the last emitted doc
scrollContext.lastEmittedDoc = topDocs.scoreDocs[topDocs.scoreDocs.length - 1];
}
}
return topDocs;
}
return topDocs;
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ public final class QuerySearchResult extends SearchPhaseResult {
private Boolean terminatedEarly = null;
private ProfileShardResult profileShardResults;
private boolean hasProfileResults;
private boolean hasScoreDocs;
private int totalHits;
private float maxScore;

public QuerySearchResult() {
}
Expand Down Expand Up @@ -87,11 +90,34 @@ public Boolean terminatedEarly() {
}

/**
 * Returns the top docs of this search result.
 *
 * @throws IllegalStateException if the top docs have already been consumed
 *         via {@code consumeTopDocs()} (the field is nulled out on consumption)
 */
public TopDocs topDocs() {
if (topDocs == null) {
throw new IllegalStateException("topDocs already consumed");
}
return topDocs;
}

/**
 * Whether the top docs of this result were already handed out through
 * {@code consumeTopDocs()}; once consumed the internal reference is null.
 */
public boolean hasConsumedTopDocs() {
    return this.topDocs == null;
}

/**
 * Hands out the top docs exactly once and releases the internal reference,
 * allowing the memory to be reclaimed after the docs have been processed.
 *
 * @return the top docs held by this result
 * @throws IllegalStateException if the top docs were already consumed
 */
public TopDocs consumeTopDocs() {
    final TopDocs consumed = this.topDocs;
    if (consumed == null) {
        throw new IllegalStateException("topDocs already consumed");
    }
    this.topDocs = null;
    return consumed;
}

public void topDocs(TopDocs topDocs, DocValueFormat[] sortValueFormats) {
this.topDocs = topDocs;
setTopDocs(topDocs);
if (topDocs.scoreDocs.length > 0 && topDocs.scoreDocs[0] instanceof FieldDoc) {
int numFields = ((FieldDoc) topDocs.scoreDocs[0]).fields.length;
if (numFields != sortValueFormats.length) {
Expand All @@ -102,12 +128,19 @@ public void topDocs(TopDocs topDocs, DocValueFormat[] sortValueFormats) {
this.sortValueFormats = sortValueFormats;
}

// Stores the top docs and caches the scalar values (hit presence, total hits,
// max score) so they remain available after consumeTopDocs() nulls the reference.
private void setTopDocs(TopDocs topDocs) {
this.topDocs = topDocs;
hasScoreDocs = topDocs.scoreDocs.length > 0;
this.totalHits = topDocs.totalHits;
this.maxScore = topDocs.getMaxScore();
}

public DocValueFormat[] sortValueFormats() {
return sortValueFormats;
}

/**
* Retruns <code>true</code> if this query result has unconsumed aggregations
* Returns <code>true</code> if this query result has unconsumed aggregations
*/
public boolean hasAggs() {
return hasAggs;
Expand Down Expand Up @@ -195,10 +228,15 @@ public QuerySearchResult size(int size) {
return this;
}

/** Returns true iff the result has hits */
public boolean hasHits() {
return (topDocs != null && topDocs.scoreDocs.length > 0) ||
(suggest != null && suggest.hasScoreDocs());
/**
 * Whether this result carries any suggest score docs.
 */
public boolean hasSuggestHits() {
    return suggest != null && suggest.hasScoreDocs();
}

/**
 * Returns <code>true</code> if this result has any score docs or suggest hits,
 * i.e. a search context may still be needed for a subsequent fetch phase.
 */
public boolean hasSearchContext() {
return hasScoreDocs || hasSuggestHits();
}

public static QuerySearchResult readQuerySearchResult(StreamInput in) throws IOException {
Expand Down Expand Up @@ -227,7 +265,7 @@ public void readFromWithId(long id, StreamInput in) throws IOException {
sortValueFormats[i] = in.readNamedWriteable(DocValueFormat.class);
}
}
topDocs = readTopDocs(in);
setTopDocs(readTopDocs(in));
if (hasAggs = in.readBoolean()) {
aggregations = InternalAggregations.readAggregations(in);
}
Expand Down Expand Up @@ -278,4 +316,12 @@ public void writeToNoId(StreamOutput out) throws IOException {
out.writeOptionalBoolean(terminatedEarly);
out.writeOptionalWriteable(profileShardResults);
}

/**
 * Returns the total hit count recorded when the top docs were set;
 * remains valid even after the top docs have been consumed.
 */
public int getTotalHits() {
return totalHits;
}

/**
 * Returns the maximum score recorded when the top docs were set;
 * remains valid even after the top docs have been consumed.
 */
public float getMaxScore() {
return maxScore;
}
}
Loading

0 comments on commit 1f40f8a

Please sign in to comment.