Skip to content

Commit

Permalink
First step towards incremental reduction of query responses (#23253)
Browse files Browse the repository at this point in the history
Today all query results are buffered up until we received responses of
all shards. This can hold on to a significant amount of memory if the number of
shards is large. This commit adds a first step towards incrementally reducing
aggregations results if a, per search request, configurable amount of responses
are received. If enough query results have been received and buffered all so-far
received aggregation responses will be reduced and released to be GCed.
  • Loading branch information
s1monw authored Feb 21, 2017
1 parent 39ed76c commit f933f80
Show file tree
Hide file tree
Showing 33 changed files with 588 additions and 182 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@

import java.util.List;
import java.util.Map;
import java.util.StringJoiner;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
Expand All @@ -61,7 +60,7 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
**/
private final Function<String, Transport.Connection> nodeIdToConnection;
private final SearchTask task;
private final AtomicArray<Result> results;
private final SearchPhaseResults<Result> results;
private final long clusterStateVersion;
private final Map<String, AliasFilter> aliasFilter;
private final Map<String, Float> concreteIndexBoosts;
Expand All @@ -76,7 +75,7 @@ protected AbstractSearchAsyncAction(String name, Logger logger, SearchTransportS
Map<String, AliasFilter> aliasFilter, Map<String, Float> concreteIndexBoosts,
Executor executor, SearchRequest request,
ActionListener<SearchResponse> listener, GroupShardsIterator shardsIts, long startTime,
long clusterStateVersion, SearchTask task) {
long clusterStateVersion, SearchTask task, SearchPhaseResults<Result> resultConsumer) {
super(name, request, shardsIts, logger);
this.startTime = startTime;
this.logger = logger;
Expand All @@ -87,9 +86,9 @@ protected AbstractSearchAsyncAction(String name, Logger logger, SearchTransportS
this.listener = listener;
this.nodeIdToConnection = nodeIdToConnection;
this.clusterStateVersion = clusterStateVersion;
results = new AtomicArray<>(shardsIts.size());
this.concreteIndexBoosts = concreteIndexBoosts;
this.aliasFilter = aliasFilter;
this.results = resultConsumer;
}

/**
Expand All @@ -105,7 +104,7 @@ private long buildTookInMillis() {
* This is the main entry point for a search. This method starts the search execution of the initial phase.
*/
public final void start() {
if (results.length() == 0) {
if (getNumShards() == 0) {
//no search shards to search on, bail with empty response
//(it happens with search across _all with no indices around and consistent with broadcast operations)
listener.onResponse(new SearchResponse(InternalSearchResponse.empty(), null, 0, 0, buildTookInMillis(),
Expand All @@ -130,8 +129,8 @@ public final void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPha
onPhaseFailure(currentPhase, "all shards failed", null);
} else {
if (logger.isTraceEnabled()) {
final String resultsFrom = results.asList().stream()
.map(r -> r.value.shardTarget().toString()).collect(Collectors.joining(","));
final String resultsFrom = results.getSuccessfulResults()
.map(r -> r.shardTarget().toString()).collect(Collectors.joining(","));
logger.trace("[{}] Moving to next phase: [{}], based on results from: {} (cluster state version: {})",
currentPhase.getName(), nextPhase.getName(), resultsFrom, clusterStateVersion);
}
Expand Down Expand Up @@ -178,7 +177,7 @@ public final void onShardFailure(final int shardIndex, @Nullable SearchShardTarg
synchronized (shardFailuresMutex) {
shardFailures = this.shardFailures.get(); // read again otherwise somebody else has created it?
if (shardFailures == null) { // still null so we are the first and create a new instance
shardFailures = new AtomicArray<>(results.length());
shardFailures = new AtomicArray<>(getNumShards());
this.shardFailures.set(shardFailures);
}
}
Expand All @@ -194,7 +193,7 @@ public final void onShardFailure(final int shardIndex, @Nullable SearchShardTarg
}
}

if (results.get(shardIndex) != null) {
if (results.hasResult(shardIndex)) {
assert failure == null : "shard failed before but shouldn't: " + failure;
successfulOps.decrementAndGet(); // if this shard was successful before (initial phase) we have to adjust the counter
}
Expand All @@ -207,22 +206,22 @@ public final void onShardFailure(final int shardIndex, @Nullable SearchShardTarg
* @param exception the exception explaining or causing the phase failure
*/
private void raisePhaseFailure(SearchPhaseExecutionException exception) {
for (AtomicArray.Entry<Result> entry : results.asList()) {
results.getSuccessfulResults().forEach((entry) -> {
try {
Transport.Connection connection = nodeIdToConnection.apply(entry.value.shardTarget().getNodeId());
sendReleaseSearchContext(entry.value.id(), connection);
Transport.Connection connection = nodeIdToConnection.apply(entry.shardTarget().getNodeId());
sendReleaseSearchContext(entry.id(), connection);
} catch (Exception inner) {
inner.addSuppressed(exception);
logger.trace("failed to release context", inner);
}
}
});
listener.onFailure(exception);
}

@Override
public final void onShardSuccess(int shardIndex, Result result) {
successfulOps.incrementAndGet();
results.set(shardIndex, result);
results.consumeResult(shardIndex, result);
if (logger.isTraceEnabled()) {
logger.trace("got first-phase result from {}", result != null ? result.shardTarget() : null);
}
Expand All @@ -242,7 +241,7 @@ public final void onPhaseDone() {

@Override
public final int getNumShards() {
return results.length();
return results.getNumShards();
}

@Override
Expand All @@ -262,7 +261,7 @@ public final SearchRequest getRequest() {

@Override
public final SearchResponse buildSearchResponse(InternalSearchResponse internalSearchResponse, String scrollId) {
return new SearchResponse(internalSearchResponse, scrollId, results.length(), successfulOps.get(),
return new SearchResponse(internalSearchResponse, scrollId, getNumShards(), successfulOps.get(),
buildTookInMillis(), buildShardFailures());
}

Expand Down Expand Up @@ -310,6 +309,5 @@ public final ShardSearchTransportRequest buildShardSearchRequest(ShardIterator s
* executed shard request
* @param context the search context for the next phase
*/
protected abstract SearchPhase getNextPhase(AtomicArray<Result> results, SearchPhaseContext context);

protected abstract SearchPhase getNextPhase(SearchPhaseResults<Result> results, SearchPhaseContext context);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
package org.elasticsearch.action.search;

import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.common.util.concurrent.CountDown;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
Expand All @@ -30,17 +29,13 @@
* where the given index is used to set the result on the array.
*/
final class CountedCollector<R extends SearchPhaseResult> {
private final AtomicArray<R> resultArray;
private final ResultConsumer<R> resultConsumer;
private final CountDown counter;
private final Runnable onFinish;
private final SearchPhaseContext context;

CountedCollector(AtomicArray<R> resultArray, int expectedOps, Runnable onFinish, SearchPhaseContext context) {
if (expectedOps > resultArray.length()) {
throw new IllegalStateException("unexpected number of operations. got: " + expectedOps + " but array size is: "
+ resultArray.length());
}
this.resultArray = resultArray;
CountedCollector(ResultConsumer<R> resultConsumer, int expectedOps, Runnable onFinish, SearchPhaseContext context) {
this.resultConsumer = resultConsumer;
this.counter = new CountDown(expectedOps);
this.onFinish = onFinish;
this.context = context;
Expand All @@ -63,7 +58,7 @@ void countDown() {
void onResult(int index, R result, SearchShardTarget target) {
try {
result.shardTarget(target);
resultArray.set(index, result);
resultConsumer.consume(index, result);
} finally {
countDown();
}
Expand All @@ -80,4 +75,12 @@ void onFailure(final int shardIndex, @Nullable SearchShardTarget shardTarget, Ex
countDown();
}
}

/**
* A functional interface to plug in shard result consumers to this collector
*/
@FunctionalInterface
public interface ResultConsumer<R extends SearchPhaseResult> {
void consume(int shardIndex, R result);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,19 @@
* @see CountedCollector#onFailure(int, SearchShardTarget, Exception)
*/
final class DfsQueryPhase extends SearchPhase {
private final AtomicArray<QuerySearchResultProvider> queryResult;
private final InitialSearchPhase.SearchPhaseResults<QuerySearchResultProvider> queryResult;
private final SearchPhaseController searchPhaseController;
private final AtomicArray<DfsSearchResult> dfsSearchResults;
private final Function<AtomicArray<QuerySearchResultProvider>, SearchPhase> nextPhaseFactory;
private final Function<InitialSearchPhase.SearchPhaseResults<QuerySearchResultProvider>, SearchPhase> nextPhaseFactory;
private final SearchPhaseContext context;
private final SearchTransportService searchTransportService;

DfsQueryPhase(AtomicArray<DfsSearchResult> dfsSearchResults,
SearchPhaseController searchPhaseController,
Function<AtomicArray<QuerySearchResultProvider>, SearchPhase> nextPhaseFactory, SearchPhaseContext context) {
Function<InitialSearchPhase.SearchPhaseResults<QuerySearchResultProvider>, SearchPhase> nextPhaseFactory,
SearchPhaseContext context) {
super("dfs_query");
this.queryResult = new AtomicArray<>(dfsSearchResults.length());
this.queryResult = searchPhaseController.newSearchPhaseResults(context.getRequest(), context.getNumShards());
this.searchPhaseController = searchPhaseController;
this.dfsSearchResults = dfsSearchResults;
this.nextPhaseFactory = nextPhaseFactory;
Expand All @@ -64,7 +65,8 @@ public void run() throws IOException {
// TODO we can potentially also consume the actual per shard results from the initial phase here in the aggregateDfs
// to free up memory early
final AggregatedDfs dfs = searchPhaseController.aggregateDfs(dfsSearchResults);
final CountedCollector<QuerySearchResultProvider> counter = new CountedCollector<>(queryResult, dfsSearchResults.asList().size(),
final CountedCollector<QuerySearchResultProvider> counter = new CountedCollector<>(queryResult::consumeResult,
dfsSearchResults.asList().size(),
() -> {
context.executeNextPhase(this, nextPhaseFactory.apply(queryResult));
}, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,29 +49,31 @@ final class FetchSearchPhase extends SearchPhase {
private final Function<SearchResponse, SearchPhase> nextPhaseFactory;
private final SearchPhaseContext context;
private final Logger logger;
private final InitialSearchPhase.SearchPhaseResults<QuerySearchResultProvider> resultConsumer;

FetchSearchPhase(AtomicArray<QuerySearchResultProvider> queryResults,
FetchSearchPhase(InitialSearchPhase.SearchPhaseResults<QuerySearchResultProvider> resultConsumer,
SearchPhaseController searchPhaseController,
SearchPhaseContext context) {
this(queryResults, searchPhaseController, context,
this(resultConsumer, searchPhaseController, context,
(response) -> new ExpandSearchPhase(context, response, // collapse only happens if the request has inner hits
(finalResponse) -> sendResponsePhase(finalResponse, context)));
}

FetchSearchPhase(AtomicArray<QuerySearchResultProvider> queryResults,
FetchSearchPhase(InitialSearchPhase.SearchPhaseResults<QuerySearchResultProvider> resultConsumer,
SearchPhaseController searchPhaseController,
SearchPhaseContext context, Function<SearchResponse, SearchPhase> nextPhaseFactory) {
super("fetch");
if (context.getNumShards() != queryResults.length()) {
if (context.getNumShards() != resultConsumer.getNumShards()) {
throw new IllegalStateException("number of shards must match the length of the query results but doesn't:"
+ context.getNumShards() + "!=" + queryResults.length());
+ context.getNumShards() + "!=" + resultConsumer.getNumShards());
}
this.fetchResults = new AtomicArray<>(queryResults.length());
this.fetchResults = new AtomicArray<>(resultConsumer.getNumShards());
this.searchPhaseController = searchPhaseController;
this.queryResults = queryResults;
this.queryResults = resultConsumer.results;
this.nextPhaseFactory = nextPhaseFactory;
this.context = context;
this.logger = context.getLogger();
this.resultConsumer = resultConsumer;

}

Expand Down Expand Up @@ -99,7 +101,7 @@ private void innerRun() throws IOException {
ScoreDoc[] sortedShardDocs = searchPhaseController.sortDocs(isScrollSearch, queryResults);
String scrollId = isScrollSearch ? TransportSearchHelper.buildScrollId(queryResults) : null;
List<AtomicArray.Entry<QuerySearchResultProvider>> queryResultsAsList = queryResults.asList();
final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedQueryPhase(queryResultsAsList);
final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = resultConsumer.reduce();
final boolean queryAndFetchOptimization = queryResults.length() == 1;
final Runnable finishPhase = ()
-> moveToNextPhase(searchPhaseController, sortedShardDocs, scrollId, reducedQueryPhase, queryAndFetchOptimization ?
Expand All @@ -119,7 +121,7 @@ private void innerRun() throws IOException {
final ScoreDoc[] lastEmittedDocPerShard = isScrollSearch ?
searchPhaseController.getLastEmittedDocPerShard(reducedQueryPhase, sortedShardDocs, numShards)
: null;
final CountedCollector<FetchSearchResult> counter = new CountedCollector<>(fetchResults,
final CountedCollector<FetchSearchResult> counter = new CountedCollector<>(fetchResults::set,
docIdsToLoad.length, // we count down every shard in the result no matter if we got any results or not
finishPhase, context);
for (int i = 0; i < docIdsToLoad.length; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@
import org.elasticsearch.cluster.routing.ShardIterator;
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.transport.ConnectTransportException;

import java.io.IOException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;

/**
* This is an abstract base class that encapsulates the logic to fan out to all shards in provided {@link GroupShardsIterator}
Expand Down Expand Up @@ -213,4 +215,53 @@ private void onShardResult(int shardIndex, String nodeId, FirstResult result, Sh
* @param listener the listener to notify on response
*/
protected abstract void executePhaseOnShard(ShardIterator shardIt, ShardRouting shard, ActionListener<FirstResult> listener);

/**
* This class acts as a basic result collection that can be extended to do on-the-fly reduction or result processing
*/
static class SearchPhaseResults<Result extends SearchPhaseResult> {
final AtomicArray<Result> results;

SearchPhaseResults(int size) {
results = new AtomicArray<>(size);
}

/**
* Returns the number of expected results this class should collect
*/
final int getNumShards() {
return results.length();
}

/**
* A stream of all non-null (successful) shard results
*/
final Stream<Result> getSuccessfulResults() {
return results.asList().stream().map(e -> e.value);
}

/**
* Consumes a single shard result
* @param shardIndex the shards index, this is a 0-based id that is used to establish a 1 to 1 mapping to the searched shards
* @param result the shards result
*/
void consumeResult(int shardIndex, Result result) {
assert results.get(shardIndex) == null : "shardIndex: " + shardIndex + " is already set";
results.set(shardIndex, result);
}

/**
* Returns <code>true</code> iff a result if present for the given shard ID.
*/
final boolean hasResult(int shardIndex) {
return results.get(shardIndex) != null;
}

/**
* Reduces the collected results
*/
SearchPhaseController.ReducedQueryPhase reduce() {
throw new UnsupportedOperationException("reduce is not supported");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.elasticsearch.cluster.routing.GroupShardsIterator;
import org.elasticsearch.cluster.routing.ShardIterator;
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.search.dfs.DfsSearchResult;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.transport.Transport;
Expand All @@ -43,7 +42,7 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
ActionListener<SearchResponse> listener, GroupShardsIterator shardsIts, long startTime,
long clusterStateVersion, SearchTask task) {
super("dfs", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, executor,
request, listener, shardsIts, startTime, clusterStateVersion, task);
request, listener, shardsIts, startTime, clusterStateVersion, task, new SearchPhaseResults<>(shardsIts.size()));
this.searchPhaseController = searchPhaseController;
}

Expand All @@ -54,8 +53,8 @@ protected void executePhaseOnShard(ShardIterator shardIt, ShardRouting shard, Ac
}

@Override
protected SearchPhase getNextPhase(AtomicArray<DfsSearchResult> results, SearchPhaseContext context) {
return new DfsQueryPhase(results, searchPhaseController,
protected SearchPhase getNextPhase(SearchPhaseResults<DfsSearchResult> results, SearchPhaseContext context) {
return new DfsQueryPhase(results.results, searchPhaseController,
(queryResults) -> new FetchSearchPhase(queryResults, searchPhaseController, context), context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,5 @@ default void sendReleaseSearchContext(long contextId, Transport.Connection conne
* a response is returned to the user indicating that all shards have failed.
*/
void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPhase);

}
Loading

0 comments on commit f933f80

Please sign in to comment.