Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First step towards incremental reduction of query responses #23253

Merged
merged 19 commits into from
Feb 21, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@

import java.util.List;
import java.util.Map;
import java.util.StringJoiner;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
Expand All @@ -61,7 +60,7 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
**/
private final Function<String, Transport.Connection> nodeIdToConnection;
private final SearchTask task;
private final AtomicArray<Result> results;
private final SearchPhaseResults<Result> results;
private final long clusterStateVersion;
private final Map<String, AliasFilter> aliasFilter;
private final Map<String, Float> concreteIndexBoosts;
Expand All @@ -76,7 +75,7 @@ protected AbstractSearchAsyncAction(String name, Logger logger, SearchTransportS
Map<String, AliasFilter> aliasFilter, Map<String, Float> concreteIndexBoosts,
Executor executor, SearchRequest request,
ActionListener<SearchResponse> listener, GroupShardsIterator shardsIts, long startTime,
long clusterStateVersion, SearchTask task) {
long clusterStateVersion, SearchTask task, SearchPhaseResults<Result> resultConsumer) {
super(name, request, shardsIts, logger);
this.startTime = startTime;
this.logger = logger;
Expand All @@ -87,9 +86,9 @@ protected AbstractSearchAsyncAction(String name, Logger logger, SearchTransportS
this.listener = listener;
this.nodeIdToConnection = nodeIdToConnection;
this.clusterStateVersion = clusterStateVersion;
results = new AtomicArray<>(shardsIts.size());
this.concreteIndexBoosts = concreteIndexBoosts;
this.aliasFilter = aliasFilter;
this.results = resultConsumer;
}

/**
Expand All @@ -105,7 +104,7 @@ private long buildTookInMillis() {
* This is the main entry point for a search. This method starts the search execution of the initial phase.
*/
public final void start() {
if (results.length() == 0) {
if (getNumShards() == 0) {
//no search shards to search on, bail with empty response
//(it happens with search across _all with no indices around and consistent with broadcast operations)
listener.onResponse(new SearchResponse(InternalSearchResponse.empty(), null, 0, 0, buildTookInMillis(),
Expand All @@ -130,8 +129,8 @@ public final void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPha
onPhaseFailure(currentPhase, "all shards failed", null);
} else {
if (logger.isTraceEnabled()) {
final String resultsFrom = results.asList().stream()
.map(r -> r.value.shardTarget().toString()).collect(Collectors.joining(","));
final String resultsFrom = results.getSuccessfulResults()
.map(r -> r.shardTarget().toString()).collect(Collectors.joining(","));
logger.trace("[{}] Moving to next phase: [{}], based on results from: {} (cluster state version: {})",
currentPhase.getName(), nextPhase.getName(), resultsFrom, clusterStateVersion);
}
Expand Down Expand Up @@ -178,7 +177,7 @@ public final void onShardFailure(final int shardIndex, @Nullable SearchShardTarg
synchronized (shardFailuresMutex) {
shardFailures = this.shardFailures.get(); // read again otherwise somebody else has created it?
if (shardFailures == null) { // still null so we are the first and create a new instance
shardFailures = new AtomicArray<>(results.length());
shardFailures = new AtomicArray<>(getNumShards());
this.shardFailures.set(shardFailures);
}
}
Expand All @@ -194,7 +193,7 @@ public final void onShardFailure(final int shardIndex, @Nullable SearchShardTarg
}
}

if (results.get(shardIndex) != null) {
if (results.hasResult(shardIndex)) {
assert failure == null : "shard failed before but shouldn't: " + failure;
successfulOps.decrementAndGet(); // if this shard was successful before (initial phase) we have to adjust the counter
}
Expand All @@ -207,22 +206,22 @@ public final void onShardFailure(final int shardIndex, @Nullable SearchShardTarg
* @param exception the exception explaining or causing the phase failure
*/
private void raisePhaseFailure(SearchPhaseExecutionException exception) {
for (AtomicArray.Entry<Result> entry : results.asList()) {
results.getSuccessfulResults().forEach((entry) -> {
try {
Transport.Connection connection = nodeIdToConnection.apply(entry.value.shardTarget().getNodeId());
sendReleaseSearchContext(entry.value.id(), connection);
Transport.Connection connection = nodeIdToConnection.apply(entry.shardTarget().getNodeId());
sendReleaseSearchContext(entry.id(), connection);
} catch (Exception inner) {
inner.addSuppressed(exception);
logger.trace("failed to release context", inner);
}
}
});
listener.onFailure(exception);
}

@Override
public final void onShardSuccess(int shardIndex, Result result) {
successfulOps.incrementAndGet();
results.set(shardIndex, result);
results.consumeResult(shardIndex, result);
if (logger.isTraceEnabled()) {
logger.trace("got first-phase result from {}", result != null ? result.shardTarget() : null);
}
Expand All @@ -242,7 +241,7 @@ public final void onPhaseDone() {

@Override
public final int getNumShards() {
return results.length();
return results.getNumShards();
}

@Override
Expand All @@ -262,7 +261,7 @@ public final SearchRequest getRequest() {

@Override
public final SearchResponse buildSearchResponse(InternalSearchResponse internalSearchResponse, String scrollId) {
return new SearchResponse(internalSearchResponse, scrollId, results.length(), successfulOps.get(),
return new SearchResponse(internalSearchResponse, scrollId, getNumShards(), successfulOps.get(),
buildTookInMillis(), buildShardFailures());
}

Expand Down Expand Up @@ -310,6 +309,5 @@ public final ShardSearchTransportRequest buildShardSearchRequest(ShardIterator s
* executed shard request
* @param context the search context for the next phase
*/
protected abstract SearchPhase getNextPhase(AtomicArray<Result> results, SearchPhaseContext context);

protected abstract SearchPhase getNextPhase(SearchPhaseResults<Result> results, SearchPhaseContext context);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
package org.elasticsearch.action.search;

import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.common.util.concurrent.CountDown;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
Expand All @@ -30,17 +29,13 @@
* where the given index is used to set the result on the array.
*/
final class CountedCollector<R extends SearchPhaseResult> {
private final AtomicArray<R> resultArray;
private final ResultConsumer<R> resultConsumer;
private final CountDown counter;
private final Runnable onFinish;
private final SearchPhaseContext context;

CountedCollector(AtomicArray<R> resultArray, int expectedOps, Runnable onFinish, SearchPhaseContext context) {
if (expectedOps > resultArray.length()) {
throw new IllegalStateException("unexpected number of operations. got: " + expectedOps + " but array size is: "
+ resultArray.length());
}
this.resultArray = resultArray;
CountedCollector(ResultConsumer<R> resultConsumer, int expectedOps, Runnable onFinish, SearchPhaseContext context) {
this.resultConsumer = resultConsumer;
this.counter = new CountDown(expectedOps);
this.onFinish = onFinish;
this.context = context;
Expand All @@ -63,7 +58,7 @@ void countDown() {
void onResult(int index, R result, SearchShardTarget target) {
try {
result.shardTarget(target);
resultArray.set(index, result);
resultConsumer.consume(index, result);
} finally {
countDown();
}
Expand All @@ -80,4 +75,12 @@ void onFailure(final int shardIndex, @Nullable SearchShardTarget shardTarget, Ex
countDown();
}
}

/**
 * A functional interface to plug in shard result consumers to this collector
 */
@FunctionalInterface
public interface ResultConsumer<R extends SearchPhaseResult> {
    /**
     * Consumes a single shard-level result.
     *
     * @param shardIndex the 0-based index identifying the shard the result belongs to
     * @param result     the result produced for that shard
     */
    void consume(int shardIndex, R result);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,19 @@
* @see CountedCollector#onFailure(int, SearchShardTarget, Exception)
*/
final class DfsQueryPhase extends SearchPhase {
private final AtomicArray<QuerySearchResultProvider> queryResult;
private final InitialSearchPhase.SearchPhaseResults<QuerySearchResultProvider> queryResult;
private final SearchPhaseController searchPhaseController;
private final AtomicArray<DfsSearchResult> dfsSearchResults;
private final Function<AtomicArray<QuerySearchResultProvider>, SearchPhase> nextPhaseFactory;
private final Function<InitialSearchPhase.SearchPhaseResults<QuerySearchResultProvider>, SearchPhase> nextPhaseFactory;
private final SearchPhaseContext context;
private final SearchTransportService searchTransportService;

DfsQueryPhase(AtomicArray<DfsSearchResult> dfsSearchResults,
SearchPhaseController searchPhaseController,
Function<AtomicArray<QuerySearchResultProvider>, SearchPhase> nextPhaseFactory, SearchPhaseContext context) {
Function<InitialSearchPhase.SearchPhaseResults<QuerySearchResultProvider>, SearchPhase> nextPhaseFactory,
SearchPhaseContext context) {
super("dfs_query");
this.queryResult = new AtomicArray<>(dfsSearchResults.length());
this.queryResult = searchPhaseController.newSearchPhaseResults(context.getRequest(), context.getNumShards());
this.searchPhaseController = searchPhaseController;
this.dfsSearchResults = dfsSearchResults;
this.nextPhaseFactory = nextPhaseFactory;
Expand All @@ -64,7 +65,8 @@ public void run() throws IOException {
// TODO we can potentially also consume the actual per shard results from the initial phase here in the aggregateDfs
// to free up memory early
final AggregatedDfs dfs = searchPhaseController.aggregateDfs(dfsSearchResults);
final CountedCollector<QuerySearchResultProvider> counter = new CountedCollector<>(queryResult, dfsSearchResults.asList().size(),
final CountedCollector<QuerySearchResultProvider> counter = new CountedCollector<>(queryResult::consumeResult,
dfsSearchResults.asList().size(),
() -> {
context.executeNextPhase(this, nextPhaseFactory.apply(queryResult));
}, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,29 +49,31 @@ final class FetchSearchPhase extends SearchPhase {
private final Function<SearchResponse, SearchPhase> nextPhaseFactory;
private final SearchPhaseContext context;
private final Logger logger;
private final InitialSearchPhase.SearchPhaseResults<QuerySearchResultProvider> resultConsumer;

FetchSearchPhase(AtomicArray<QuerySearchResultProvider> queryResults,
FetchSearchPhase(InitialSearchPhase.SearchPhaseResults<QuerySearchResultProvider> resultConsumer,
SearchPhaseController searchPhaseController,
SearchPhaseContext context) {
this(queryResults, searchPhaseController, context,
this(resultConsumer, searchPhaseController, context,
(response) -> new ExpandSearchPhase(context, response, // collapse only happens if the request has inner hits
(finalResponse) -> sendResponsePhase(finalResponse, context)));
}

FetchSearchPhase(AtomicArray<QuerySearchResultProvider> queryResults,
FetchSearchPhase(InitialSearchPhase.SearchPhaseResults<QuerySearchResultProvider> resultConsumer,
SearchPhaseController searchPhaseController,
SearchPhaseContext context, Function<SearchResponse, SearchPhase> nextPhaseFactory) {
super("fetch");
if (context.getNumShards() != queryResults.length()) {
if (context.getNumShards() != resultConsumer.getNumShards()) {
throw new IllegalStateException("number of shards must match the length of the query results but doesn't:"
+ context.getNumShards() + "!=" + queryResults.length());
+ context.getNumShards() + "!=" + resultConsumer.getNumShards());
}
this.fetchResults = new AtomicArray<>(queryResults.length());
this.fetchResults = new AtomicArray<>(resultConsumer.getNumShards());
this.searchPhaseController = searchPhaseController;
this.queryResults = queryResults;
this.queryResults = resultConsumer.results;
this.nextPhaseFactory = nextPhaseFactory;
this.context = context;
this.logger = context.getLogger();
this.resultConsumer = resultConsumer;

}

Expand Down Expand Up @@ -99,7 +101,7 @@ private void innerRun() throws IOException {
ScoreDoc[] sortedShardDocs = searchPhaseController.sortDocs(isScrollSearch, queryResults);
String scrollId = isScrollSearch ? TransportSearchHelper.buildScrollId(queryResults) : null;
List<AtomicArray.Entry<QuerySearchResultProvider>> queryResultsAsList = queryResults.asList();
final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedQueryPhase(queryResultsAsList);
final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = resultConsumer.reduce();
final boolean queryAndFetchOptimization = queryResults.length() == 1;
final Runnable finishPhase = ()
-> moveToNextPhase(searchPhaseController, sortedShardDocs, scrollId, reducedQueryPhase, queryAndFetchOptimization ?
Expand All @@ -119,7 +121,7 @@ private void innerRun() throws IOException {
final ScoreDoc[] lastEmittedDocPerShard = isScrollSearch ?
searchPhaseController.getLastEmittedDocPerShard(reducedQueryPhase, sortedShardDocs, numShards)
: null;
final CountedCollector<FetchSearchResult> counter = new CountedCollector<>(fetchResults,
final CountedCollector<FetchSearchResult> counter = new CountedCollector<>(fetchResults::set,
docIdsToLoad.length, // we count down every shard in the result no matter if we got any results or not
finishPhase, context);
for (int i = 0; i < docIdsToLoad.length; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@
import org.elasticsearch.cluster.routing.ShardIterator;
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.transport.ConnectTransportException;

import java.io.IOException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;

/**
* This is an abstract base class that encapsulates the logic to fan out to all shards in provided {@link GroupShardsIterator}
Expand Down Expand Up @@ -213,4 +215,53 @@ private void onShardResult(int shardIndex, String nodeId, FirstResult result, Sh
* @param listener the listener to notify on response
*/
protected abstract void executePhaseOnShard(ShardIterator shardIt, ShardRouting shard, ActionListener<FirstResult> listener);

/**
 * Basic collector for per-shard search phase results. Subclasses may override
 * {@link #consumeResult(int, SearchPhaseResult)} and {@link #reduce()} to perform
 * on-the-fly (incremental) reduction instead of buffering every shard result.
 */
static class SearchPhaseResults<Result extends SearchPhaseResult> {
    // Backing storage; one slot per searched shard, addressed by shard index.
    final AtomicArray<Result> results;

    SearchPhaseResults(int size) {
        this.results = new AtomicArray<>(size);
    }

    /**
     * Returns how many shard results this collector expects in total.
     */
    final int getNumShards() {
        return results.length();
    }

    /**
     * Streams the successful shard results collected so far.
     * NOTE(review): relies on {@code AtomicArray#asList()} exposing only the
     * populated (non-null) entries — confirm against that class.
     */
    final Stream<Result> getSuccessfulResults() {
        return results.asList().stream().map(entry -> entry.value);
    }

    /**
     * Records the result for a single shard.
     *
     * @param shardIndex 0-based id establishing a 1:1 mapping to the searched shards
     * @param result     the shard's result
     */
    void consumeResult(int shardIndex, Result result) {
        // each slot must be written at most once
        assert hasResult(shardIndex) == false : "shardIndex: " + shardIndex + " is already set";
        results.set(shardIndex, result);
    }

    /**
     * Returns <code>true</code> iff a result if present for the given shard ID.
     */
    final boolean hasResult(int shardIndex) {
        return results.get(shardIndex) != null;
    }

    /**
     * Reduces the collected results; this base implementation does not support
     * reduction — subclasses that reduce must override it.
     */
    SearchPhaseController.ReducedQueryPhase reduce() {
        throw new UnsupportedOperationException("reduce is not supported");
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.elasticsearch.cluster.routing.GroupShardsIterator;
import org.elasticsearch.cluster.routing.ShardIterator;
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.search.dfs.DfsSearchResult;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.transport.Transport;
Expand All @@ -43,7 +42,7 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
ActionListener<SearchResponse> listener, GroupShardsIterator shardsIts, long startTime,
long clusterStateVersion, SearchTask task) {
super("dfs", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, executor,
request, listener, shardsIts, startTime, clusterStateVersion, task);
request, listener, shardsIts, startTime, clusterStateVersion, task, new SearchPhaseResults<>(shardsIts.size()));
this.searchPhaseController = searchPhaseController;
}

Expand All @@ -54,8 +53,8 @@ protected void executePhaseOnShard(ShardIterator shardIt, ShardRouting shard, Ac
}

@Override
protected SearchPhase getNextPhase(AtomicArray<DfsSearchResult> results, SearchPhaseContext context) {
return new DfsQueryPhase(results, searchPhaseController,
protected SearchPhase getNextPhase(SearchPhaseResults<DfsSearchResult> results, SearchPhaseContext context) {
return new DfsQueryPhase(results.results, searchPhaseController,
(queryResults) -> new FetchSearchPhase(queryResults, searchPhaseController, context), context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,5 @@ default void sendReleaseSearchContext(long contextId, Transport.Connection conne
* a response is returned to the user indicating that all shards have failed.
*/
void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPhase);

}
Loading