diff --git a/docs/reference/search/profile.asciidoc b/docs/reference/search/profile.asciidoc index b244453515378..e2df59ad3f4a3 100644 --- a/docs/reference/search/profile.asciidoc +++ b/docs/reference/search/profile.asciidoc @@ -596,7 +596,7 @@ And the response: ] }, { - "name": "BucketCollector: [[my_scoped_agg, my_global_agg]]", + "name": "MultiBucketCollector: [[my_scoped_agg, my_global_agg]]", "reason": "aggregation", "time_in_nanos": 8273 } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/AggregationPhase.java b/server/src/main/java/org/elasticsearch/search/aggregations/AggregationPhase.java index 4dc765d0db1de..75ef3c9199a25 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/AggregationPhase.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/AggregationPhase.java @@ -60,7 +60,7 @@ public void preProcess(SearchContext context) { } context.aggregations().aggregators(aggregators); if (!collectors.isEmpty()) { - Collector collector = BucketCollector.wrap(collectors); + Collector collector = MultiBucketCollector.wrap(collectors); ((BucketCollector)collector).preCollection(); if (context.getProfilers() != null) { collector = new InternalProfileCollector(collector, CollectorResult.REASON_AGGREGATION, @@ -97,7 +97,7 @@ public void execute(SearchContext context) { // optimize the global collector based execution if (!globals.isEmpty()) { - BucketCollector globalsCollector = BucketCollector.wrap(globals); + BucketCollector globalsCollector = MultiBucketCollector.wrap(globals); Query query = context.buildFilteredQuery(Queries.newMatchAllQuery()); try { diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/AggregatorBase.java b/server/src/main/java/org/elasticsearch/search/aggregations/AggregatorBase.java index 1e2e7332ab7b4..2ad76d8a2b49c 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/AggregatorBase.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/AggregatorBase.java @@ -183,7 +183,7 @@ protected void doPreCollection() throws IOException { @Override public final void preCollection() throws IOException { List collectors = Arrays.asList(subAggregators); - collectableSubAggregators = BucketCollector.wrap(collectors); + collectableSubAggregators = MultiBucketCollector.wrap(collectors); doPreCollection(); collectableSubAggregators.preCollection(); } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/BucketCollector.java b/server/src/main/java/org/elasticsearch/search/aggregations/BucketCollector.java index 40e66bd964539..f2c8bf5e16e44 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/BucketCollector.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/BucketCollector.java @@ -24,10 +24,6 @@ import org.apache.lucene.search.Collector; import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.stream.StreamSupport; /** * A Collector that can collect data in separate buckets. @@ -54,61 +50,6 @@ public boolean needsScores() { } }; - /** - * Wrap the given collectors into a single instance. - */ - public static BucketCollector wrap(Iterable collectorList) { - final BucketCollector[] collectors = - StreamSupport.stream(collectorList.spliterator(), false).toArray(size -> new BucketCollector[size]); - switch (collectors.length) { - case 0: - return NO_OP_COLLECTOR; - case 1: - return collectors[0]; - default: - return new BucketCollector() { - - @Override - public LeafBucketCollector getLeafCollector(LeafReaderContext ctx) throws IOException { - List leafCollectors = new ArrayList<>(collectors.length); - for (BucketCollector c : collectors) { - leafCollectors.add(c.getLeafCollector(ctx)); - } - return LeafBucketCollector.wrap(leafCollectors); - } - - @Override - public void preCollection() throws IOException { - for (BucketCollector collector : collectors) { - collector.preCollection(); - } - } - - @Override - public void postCollection() throws IOException { - for (BucketCollector collector : collectors) { - collector.postCollection(); - } - } - - @Override - public boolean needsScores() { - for (BucketCollector collector : collectors) { - if (collector.needsScores()) { - return true; - } - } - return false; - } - - @Override - public String toString() { - return Arrays.toString(collectors); - } - }; - } - } - @Override public abstract LeafBucketCollector getLeafCollector(LeafReaderContext ctx) throws IOException; diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/MultiBucketCollector.java b/server/src/main/java/org/elasticsearch/search/aggregations/MultiBucketCollector.java new file mode 100644 index 0000000000000..a8a015ab5453b --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/aggregations/MultiBucketCollector.java @@ -0,0 +1,207 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.aggregations; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.CollectionTerminatedException; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.MultiCollector; +import org.apache.lucene.search.ScoreCachingWrappingScorer; +import org.apache.lucene.search.Scorer; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * A {@link BucketCollector} which allows running a bucket collection with several + * {@link BucketCollector}s. It is similar to the {@link MultiCollector} except that the + * {@link #wrap} method filters out the {@link BucketCollector#NO_OP_COLLECTOR}s and not + * the null ones. + */ +public class MultiBucketCollector extends BucketCollector { + + /** See {@link #wrap(Iterable)}. */ + public static BucketCollector wrap(BucketCollector... collectors) { + return wrap(Arrays.asList(collectors)); + } + + /** + * Wraps a list of {@link BucketCollector}s with a {@link MultiBucketCollector}. This + * method works as follows: + * + */ + public static BucketCollector wrap(Iterable collectors) { + // For the user's convenience, we allow NO_OP collectors to be passed. + // However, to improve performance, these null collectors are found + // and dropped from the array we save for actual collection time. + int n = 0; + for (BucketCollector c : collectors) { + if (c != NO_OP_COLLECTOR) { + n++; + } + } + + if (n == 0) { + return NO_OP_COLLECTOR; + } else if (n == 1) { + // only 1 Collector - return it. + BucketCollector col = null; + for (BucketCollector c : collectors) { + if (c != null) { + col = c; + break; + } + } + return col; + } else { + BucketCollector[] colls = new BucketCollector[n]; + n = 0; + for (BucketCollector c : collectors) { + if (c != null) { + colls[n++] = c; + } + } + return new MultiBucketCollector(colls); + } + } + + private final boolean cacheScores; + private final BucketCollector[] collectors; + + private MultiBucketCollector(BucketCollector... collectors) { + this.collectors = collectors; + int numNeedsScores = 0; + for (BucketCollector collector : collectors) { + if (collector.needsScores()) { + numNeedsScores += 1; + } + } + this.cacheScores = numNeedsScores >= 2; + } + + @Override + public void preCollection() throws IOException { + for (BucketCollector collector : collectors) { + collector.preCollection(); + } + } + + @Override + public void postCollection() throws IOException { + for (BucketCollector collector : collectors) { + collector.postCollection(); + } + } + + @Override + public boolean needsScores() { + for (BucketCollector collector : collectors) { + if (collector.needsScores()) { + return true; + } + } + return false; + } + + @Override + public String toString() { + return Arrays.toString(collectors); + } + + @Override + public LeafBucketCollector getLeafCollector(LeafReaderContext context) throws IOException { + final List leafCollectors = new ArrayList<>(); + for (BucketCollector collector : collectors) { + final LeafBucketCollector leafCollector; + try { + leafCollector = collector.getLeafCollector(context); + } catch (CollectionTerminatedException e) { + // this leaf collector does not need this segment + continue; + } + leafCollectors.add(leafCollector); + } + switch (leafCollectors.size()) { + case 0: + throw new CollectionTerminatedException(); + case 1: + return leafCollectors.get(0); + default: + return new MultiLeafBucketCollector(leafCollectors, cacheScores); + } + } + + private static class MultiLeafBucketCollector extends LeafBucketCollector { + + private final boolean cacheScores; + private final LeafBucketCollector[] collectors; + private int numCollectors; + + private MultiLeafBucketCollector(List collectors, boolean cacheScores) { + this.collectors = collectors.toArray(new LeafBucketCollector[collectors.size()]); + this.cacheScores = cacheScores; + this.numCollectors = this.collectors.length; + } + + @Override + public void setScorer(Scorer scorer) throws IOException { + if (cacheScores) { + scorer = new ScoreCachingWrappingScorer(scorer); + } + for (int i = 0; i < numCollectors; ++i) { + final LeafCollector c = collectors[i]; + c.setScorer(scorer); + } + } + + private void removeCollector(int i) { + System.arraycopy(collectors, i + 1, collectors, i, numCollectors - i - 1); + --numCollectors; + collectors[numCollectors] = null; + } + + @Override + public void collect(int doc, long bucket) throws IOException { + final LeafBucketCollector[] collectors = this.collectors; + int numCollectors = this.numCollectors; + for (int i = 0; i < numCollectors; ) { + final LeafBucketCollector collector = collectors[i]; + try { + collector.collect(doc, bucket); + ++i; + } catch (CollectionTerminatedException e) { + removeCollector(i); + numCollectors = this.numCollectors; + if (numCollectors == 0) { + throw new CollectionTerminatedException(); + } + } + } + } + } +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/BestBucketsDeferringCollector.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/BestBucketsDeferringCollector.java index d6be0f5786644..6ebf9e3c41c40 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/BestBucketsDeferringCollector.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/BestBucketsDeferringCollector.java @@ -33,6 +33,7 @@ import org.elasticsearch.search.aggregations.BucketCollector; import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.LeafBucketCollector; +import org.elasticsearch.search.aggregations.MultiBucketCollector; import org.elasticsearch.search.internal.SearchContext; import java.io.IOException; @@ -90,7 +91,7 @@ public boolean needsScores() { /** Set the deferred collectors. */ @Override public void setDeferredCollector(Iterable deferredCollectors) { - this.collector = BucketCollector.wrap(deferredCollectors); + this.collector = MultiBucketCollector.wrap(deferredCollectors); } private void finishLeaf() { diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/DeferableBucketAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/DeferableBucketAggregator.java index 0ff5ea12b97be..b4e2243f17a76 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/DeferableBucketAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/DeferableBucketAggregator.java @@ -22,6 +22,7 @@ import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.aggregations.BucketCollector; +import org.elasticsearch.search.aggregations.MultiBucketCollector; import org.elasticsearch.search.aggregations.bucket.global.GlobalAggregator; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; import org.elasticsearch.search.internal.SearchContext; @@ -59,7 +60,7 @@ protected void doPreCollection() throws IOException { recordingWrapper.setDeferredCollector(deferredCollectors); collectors.add(recordingWrapper); } - collectableSubAggregators = BucketCollector.wrap(collectors); + collectableSubAggregators = MultiBucketCollector.wrap(collectors); } public static boolean descendsFromGlobalAggregator(Aggregator parent) { diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/MergingBucketsDeferringCollector.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/MergingBucketsDeferringCollector.java index f357e9d286f54..5653bc58f2a6c 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/MergingBucketsDeferringCollector.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/MergingBucketsDeferringCollector.java @@ -31,6 +31,7 @@ import org.elasticsearch.search.aggregations.BucketCollector; import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.LeafBucketCollector; +import org.elasticsearch.search.aggregations.MultiBucketCollector; import org.elasticsearch.search.internal.SearchContext; import java.io.IOException; @@ -61,7 +62,7 @@ public MergingBucketsDeferringCollector(SearchContext context) { @Override public void setDeferredCollector(Iterable deferredCollectors) { - this.collector = BucketCollector.wrap(deferredCollectors); + this.collector = MultiBucketCollector.wrap(deferredCollectors); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/composite/CompositeAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/composite/CompositeAggregator.java index 9df33691c7faa..b02f06b8cf46c 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/composite/CompositeAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/composite/CompositeAggregator.java @@ -38,6 +38,7 @@ import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.InternalAggregations; import org.elasticsearch.search.aggregations.LeafBucketCollector; +import org.elasticsearch.search.aggregations.MultiBucketCollector; import org.elasticsearch.search.aggregations.bucket.BucketsAggregator; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; import org.elasticsearch.search.aggregations.support.ValuesSource; @@ -93,7 +94,7 @@ protected void doClose() { @Override protected void doPreCollection() throws IOException { List collectors = Arrays.asList(subAggregators); - deferredCollectors = BucketCollector.wrap(collectors); + deferredCollectors = MultiBucketCollector.wrap(collectors); collectableSubAggregators = BucketCollector.NO_OP_COLLECTOR; } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/BestDocsDeferringCollector.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/BestDocsDeferringCollector.java index 05d9402230d03..bb89173e76791 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/BestDocsDeferringCollector.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/BestDocsDeferringCollector.java @@ -33,6 +33,7 @@ import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.search.aggregations.BucketCollector; import org.elasticsearch.search.aggregations.LeafBucketCollector; +import org.elasticsearch.search.aggregations.MultiBucketCollector; import org.elasticsearch.search.aggregations.bucket.DeferringBucketCollector; import java.io.IOException; @@ -76,7 +77,7 @@ public boolean needsScores() { /** Set the deferred collectors. */ @Override public void setDeferredCollector(Iterable deferredCollectors) { - this.deferred = BucketCollector.wrap(deferredCollectors); + this.deferred = MultiBucketCollector.wrap(deferredCollectors); } @Override diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/MultiBucketCollectorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/MultiBucketCollectorTests.java new file mode 100644 index 0000000000000..f9abdeed50f82 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/aggregations/MultiBucketCollectorTests.java @@ -0,0 +1,262 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.aggregations; + +import org.apache.lucene.document.Document; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.search.CollectionTerminatedException; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.apache.lucene.store.Directory; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + +public class MultiBucketCollectorTests extends ESTestCase { + private static class FakeScorer extends Scorer { + float score; + int doc = -1; + + FakeScorer() { + super(null); + } + + @Override + public int docID() { + return doc; + } + + @Override + public float score() { + return score; + } + + @Override + public DocIdSetIterator iterator() { + throw new UnsupportedOperationException(); + } + + @Override + public Weight getWeight() { + throw new UnsupportedOperationException(); + } + + @Override + public Collection getChildren() { + throw new UnsupportedOperationException(); + } + } + + private static class TerminateAfterBucketCollector extends BucketCollector { + + private int count = 0; + private final int terminateAfter; + private final BucketCollector in; + + TerminateAfterBucketCollector(BucketCollector in, int terminateAfter) { + this.in = in; + this.terminateAfter = terminateAfter; + } + + @Override + public LeafBucketCollector getLeafCollector(LeafReaderContext context) throws IOException { + if (count >= terminateAfter) { + throw new CollectionTerminatedException(); + } + final LeafBucketCollector leafCollector = in.getLeafCollector(context); + return new LeafBucketCollectorBase(leafCollector, null) { + @Override + public void collect(int doc, long bucket) throws IOException { + if (count >= terminateAfter) { + throw new CollectionTerminatedException(); + } + super.collect(doc, bucket); + count++; + } + }; + } + + @Override + public boolean needsScores() { + return false; + } + + @Override + public void preCollection() {} + + @Override + public void postCollection() {} + } + + private static class TotalHitCountBucketCollector extends BucketCollector { + + private int count = 0; + + TotalHitCountBucketCollector() { + } + + @Override + public LeafBucketCollector getLeafCollector(LeafReaderContext context) { + return new LeafBucketCollector() { + @Override + public void collect(int doc, long bucket) throws IOException { + count++; + } + }; + } + + @Override + public boolean needsScores() { + return false; + } + + @Override + public void preCollection() {} + + @Override + public void postCollection() {} + + int getTotalHits() { + return count; + } + } + + private static class SetScorerBucketCollector extends BucketCollector { + private final BucketCollector in; + private final AtomicBoolean setScorerCalled; + + SetScorerBucketCollector(BucketCollector in, AtomicBoolean setScorerCalled) { + this.in = in; + this.setScorerCalled = setScorerCalled; + } + + @Override + public LeafBucketCollector getLeafCollector(LeafReaderContext context) throws IOException { + final LeafBucketCollector leafCollector = in.getLeafCollector(context); + return new LeafBucketCollectorBase(leafCollector, null) { + @Override + public void setScorer(Scorer scorer) throws IOException { + super.setScorer(scorer); + setScorerCalled.set(true); + } + }; + } + + @Override + public boolean needsScores() { + return false; + } + + @Override + public void preCollection() {} + + @Override + public void postCollection() {} + } + + public void testCollectionTerminatedExceptionHandling() throws IOException { + final int iters = atLeast(3); + for (int iter = 0; iter < iters; ++iter) { + Directory dir = newDirectory(); + RandomIndexWriter w = new RandomIndexWriter(random(), dir); + final int numDocs = randomIntBetween(100, 1000); + final Document doc = new Document(); + for (int i = 0; i < numDocs; ++i) { + w.addDocument(doc); + } + final IndexReader reader = w.getReader(); + w.close(); + final IndexSearcher searcher = newSearcher(reader); + Map expectedCounts = new HashMap<>(); + List collectors = new ArrayList<>(); + final int numCollectors = randomIntBetween(1, 5); + for (int i = 0; i < numCollectors; ++i) { + final int terminateAfter = random().nextInt(numDocs + 10); + final int expectedCount = terminateAfter > numDocs ? numDocs : terminateAfter; + TotalHitCountBucketCollector collector = new TotalHitCountBucketCollector(); + expectedCounts.put(collector, expectedCount); + collectors.add(new TerminateAfterBucketCollector(collector, terminateAfter)); + } + searcher.search(new MatchAllDocsQuery(), MultiBucketCollector.wrap(collectors)); + for (Map.Entry expectedCount : expectedCounts.entrySet()) { + assertEquals(expectedCount.getValue().intValue(), expectedCount.getKey().getTotalHits()); + } + reader.close(); + dir.close(); + } + } + + public void testSetScorerAfterCollectionTerminated() throws IOException { + BucketCollector collector1 = new TotalHitCountBucketCollector(); + BucketCollector collector2 = new TotalHitCountBucketCollector(); + + AtomicBoolean setScorerCalled1 = new AtomicBoolean(); + collector1 = new SetScorerBucketCollector(collector1, setScorerCalled1); + + AtomicBoolean setScorerCalled2 = new AtomicBoolean(); + collector2 = new SetScorerBucketCollector(collector2, setScorerCalled2); + + collector1 = new TerminateAfterBucketCollector(collector1, 1); + collector2 = new TerminateAfterBucketCollector(collector2, 2); + + Scorer scorer = new FakeScorer(); + + List collectors = Arrays.asList(collector1, collector2); + Collections.shuffle(collectors, random()); + BucketCollector collector = MultiBucketCollector.wrap(collectors); + + LeafBucketCollector leafCollector = collector.getLeafCollector(null); + leafCollector.setScorer(scorer); + assertTrue(setScorerCalled1.get()); + assertTrue(setScorerCalled2.get()); + + leafCollector.collect(0); + leafCollector.collect(1); + + setScorerCalled1.set(false); + setScorerCalled2.set(false); + leafCollector.setScorer(scorer); + assertFalse(setScorerCalled1.get()); + assertTrue(setScorerCalled2.get()); + + expectThrows(CollectionTerminatedException.class, () -> { + leafCollector.collect(1); + }); + + setScorerCalled1.set(false); + setScorerCalled2.set(false); + leafCollector.setScorer(scorer); + assertFalse(setScorerCalled1.get()); + assertFalse(setScorerCalled2.get()); + } +}