From 22ca7ca708b2b6b4c86de42482c4fce8c20e19ee Mon Sep 17 00:00:00 2001
From: Jim Ferenczi
Date: Thu, 8 Jun 2017 22:39:25 +0200
Subject: [PATCH] Speed up sorted scroll when the index sort matches the search sort

Sorted scroll search can use early termination when the index sort matches
the scroll search sort. The optimization can be done after the first query
(which still needs to collect all documents) by applying a query that only
matches documents that are greater than the last doc retrieved in the
previous request. Since the index is sorted, retrieving the list of
documents that are greater than the last doc only requires a binary search
on each segment.
This change introduces this new query called `SearchAfterSortedDocQuery`
and applies it when possible. Scrolls with this optimization will search all
documents on the first request and will then early terminate each segment
after ${size} documents on any subsequent request.

Relates #6720
---
 .../apache/lucene/queries/MinDocQuery.java       |  78 +++++----
 .../queries/SearchAfterSortedDocQuery.java       | 160 ++++++++++++++++++
 .../search/query/QueryPhase.java                 |  36 ++--
 .../SearchAfterSortedDocQueryTests.java          | 130 ++++++++++++++
 .../search/query/QueryPhaseTests.java            |  73 +++++++-
 5 files changed, 427 insertions(+), 50 deletions(-)
 create mode 100644 core/src/main/java/org/apache/lucene/queries/SearchAfterSortedDocQuery.java
 create mode 100644 core/src/test/java/org/apache/lucene/queries/SearchAfterSortedDocQueryTests.java

diff --git a/core/src/main/java/org/apache/lucene/queries/MinDocQuery.java b/core/src/main/java/org/apache/lucene/queries/MinDocQuery.java
index d4f9ab729736c..65c5c0f707c5b 100644
--- a/core/src/main/java/org/apache/lucene/queries/MinDocQuery.java
+++ b/core/src/main/java/org/apache/lucene/queries/MinDocQuery.java
@@ -66,46 +66,54 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
                     return null;
                 }
                 final int segmentMinDoc = Math.max(0, minDoc - context.docBase);
-                final DocIdSetIterator disi = new DocIdSetIterator() {
-
-                    int doc = -1;
-
-                    @Override
-                    public int docID() {
-                        return doc;
-                    }
-
-                    @Override
-                    public int nextDoc() throws IOException {
-                        return advance(doc + 1);
-                    }
-
-                    @Override
-                    public int advance(int target) throws IOException {
-                        assert target > doc;
-                        if (doc == -1) {
-                            // skip directly to minDoc
-                            doc = Math.max(target, segmentMinDoc);
-                        } else {
-                            doc = target;
-                        }
-                        if (doc >= maxDoc) {
-                            doc = NO_MORE_DOCS;
-                        }
-                        return doc;
-                    }
-
-                    @Override
-                    public long cost() {
-                        return maxDoc - segmentMinDoc;
-                    }
-
-                };
+                final DocIdSetIterator disi = new MinDocIterator(segmentMinDoc, maxDoc);
                 return new ConstantScoreScorer(this, score(), disi);
             }
         };
     }
 
+    static class MinDocIterator extends DocIdSetIterator {
+        final int segmentMinDoc;
+        final int maxDoc;
+        int doc = -1;
+
+        MinDocIterator(int segmentMinDoc, int maxDoc) {
+            this.segmentMinDoc = segmentMinDoc;
+            this.maxDoc = maxDoc;
+        }
+
+        @Override
+        public int docID() {
+            return doc;
+        }
+
+        @Override
+        public int nextDoc() throws IOException {
+            return advance(doc + 1);
+        }
+
+        @Override
+        public int advance(int target) throws IOException {
+            assert target > doc;
+            if (doc == -1) {
+                // skip directly to minDoc
+                doc = Math.max(target, segmentMinDoc);
+            } else {
+                doc = target;
+            }
+            if (doc >= maxDoc) {
+                doc = NO_MORE_DOCS;
+            }
+            return doc;
+        }
+
+        @Override
+        public long cost() {
+            return maxDoc - segmentMinDoc;
+        }
+    }
+
+
     @Override
     public String toString(String field) {
         return "MinDocQuery(minDoc=" + minDoc + ")";
diff --git a/core/src/main/java/org/apache/lucene/queries/SearchAfterSortedDocQuery.java b/core/src/main/java/org/apache/lucene/queries/SearchAfterSortedDocQuery.java
new file mode 100644
index 0000000000000..4df045ad0bfca
--- /dev/null
+++ b/core/src/main/java/org/apache/lucene/queries/SearchAfterSortedDocQuery.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.lucene.queries;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.ConstantScoreScorer;
+import org.apache.lucene.search.ConstantScoreWeight;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.EarlyTerminatingSortingCollector;
+import org.apache.lucene.search.FieldComparator;
+import org.apache.lucene.search.FieldDoc;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.LeafFieldComparator;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.Sort;
+import org.apache.lucene.search.Weight;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * A {@link Query} that only matches documents that are greater than the provided {@link FieldDoc}.
+ * This works only if the index is sorted according to the given search {@link Sort}.
+ */
+public class SearchAfterSortedDocQuery extends Query {
+    private final Sort sort;
+    private final FieldDoc after;
+    private final List<FieldComparator<?>> fieldComparators;
+
+    public SearchAfterSortedDocQuery(Sort sort, FieldDoc after) {
+        if (sort.getSort().length != after.fields.length) {
+            throw new IllegalArgumentException("after doc has " + after.fields.length + " value(s) but sort has "
+                + sort.getSort().length + ".");
+        }
+        this.sort = sort;
+        this.after = after;
+        this.fieldComparators = new ArrayList<>();
+        for (int i = 0; i < sort.getSort().length; i++) {
+            FieldComparator<?> fieldComparator = sort.getSort()[i].getComparator(1, i);
+            @SuppressWarnings("unchecked")
+            FieldComparator<Object> comparator = (FieldComparator<Object>) fieldComparator;
+            comparator.setTopValue(after.fields[i]);
+            fieldComparators.add(fieldComparator);
+        }
+    }
+
+    @Override
+    public Weight createWeight(IndexSearcher searcher, boolean needsScores, float boost) throws IOException {
+        return new ConstantScoreWeight(this, 1.0f) {
+            @Override
+            public Scorer scorer(LeafReaderContext context) throws IOException {
+                Sort segmentSort = context.reader().getMetaData().getSort();
+                if (EarlyTerminatingSortingCollector.canEarlyTerminate(sort, segmentSort) == false) {
+                    throw new IOException("wrong sort");
+                }
+                TopComparator comparator = getTopComparator(fieldComparators, context, after.doc);
+                final int maxDoc = context.reader().maxDoc();
+                final int firstDoc = searchAfterDoc(comparator, 0, context.reader().maxDoc());
+                if (firstDoc >= maxDoc) {
+                    return null;
+                }
+                final DocIdSetIterator disi = new MinDocQuery.MinDocIterator(firstDoc, maxDoc);
+                return new ConstantScoreScorer(this, score(), disi);
+            }
+        };
+    }
+
+    @Override
+    public String toString(String field) {
+        return "SearchAfterSortedDocQuery(sort=" + sort + ", afterDoc=" + after.toString() + ")";
+    }
+
+    @Override
+    public boolean equals(Object other) {
+        return sameClassAs(other) &&
+            equalsTo(getClass().cast(other));
+    }
+
+    private boolean equalsTo(SearchAfterSortedDocQuery other) {
+        return sort.equals(other.sort) &&
+            after.doc == other.after.doc &&
+            Double.compare(after.score, other.after.score) == 0 &&
+            Arrays.equals(after.fields, other.after.fields);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(classHash(), sort, after.doc, after.score, Arrays.hashCode(after.fields));
+    }
+
+    interface TopComparator {
+        boolean lessThanTop(int doc) throws IOException;
+    }
+
+    static TopComparator getTopComparator(List<FieldComparator<?>> fieldComparators, LeafReaderContext leafReaderContext, int topDoc) {
+        return doc -> {
+            // DVs use forward iterators so we recreate the iterator for each sort field
+            // every time we need to compare a document with the after doc.
+            // We could reuse the iterators when the comparison goes forward but
+            // this should only be called a few times per segment (binary search).
+            for (int i = 0; i < fieldComparators.size(); i++) {
+                LeafFieldComparator comparator = fieldComparators.get(i).getLeafComparator(leafReaderContext);
+                int value = comparator.compareTop(doc);
+                if (value != 0) {
+                    return value < 0;
+                }
+            }
+            if (topDoc < leafReaderContext.docBase) {
+                return false;
+            } else {
+                if (topDoc < leafReaderContext.docBase + leafReaderContext.reader().maxDoc()) {
+                    if (topDoc <= doc + leafReaderContext.docBase) {
+                        return false;
+                    }
+                }
+                return true;
+            }
+        };
+    }
+
+    /**
+     * Returns the first doc id greater than the provided after doc.
+     */
+    static int searchAfterDoc(TopComparator comparator, int from, int to) throws IOException {
+        int low = from;
+        int high = to - 1;
+
+        while (low <= high) {
+            int mid = (low + high) >>> 1;
+            if (comparator.lessThanTop(mid)) {
+                high = mid - 1;
+            } else {
+                low = mid + 1;
+            }
+        }
+        return low;
+    }
+
+}
diff --git a/core/src/main/java/org/elasticsearch/search/query/QueryPhase.java b/core/src/main/java/org/elasticsearch/search/query/QueryPhase.java
index 10c180f687ee6..82e572a180e1d 100644
--- a/core/src/main/java/org/elasticsearch/search/query/QueryPhase.java
+++ b/core/src/main/java/org/elasticsearch/search/query/QueryPhase.java
@@ -21,11 +21,13 @@
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.queries.MinDocQuery;
+import org.apache.lucene.queries.SearchAfterSortedDocQuery;
 import org.apache.lucene.search.BooleanClause;
 import org.apache.lucene.search.BooleanQuery;
 import org.apache.lucene.search.Collector;
 import org.apache.lucene.search.ConstantScoreQuery;
 import org.apache.lucene.search.EarlyTerminatingSortingCollector;
+import org.apache.lucene.search.FieldDoc;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.MatchAllDocsQuery;
 import org.apache.lucene.search.Query;
@@ -50,7 +52,6 @@
 import org.elasticsearch.search.suggest.SuggestPhase;
 
 import java.util.LinkedList;
-import java.util.List;
 
 import static org.elasticsearch.search.query.QueryCollectorContext.createCancellableCollectorContext;
 import static org.elasticsearch.search.query.QueryCollectorContext.createEarlySortingTerminationCollectorContext;
@@ -130,16 +131,17 @@ static boolean execute(SearchContext searchContext, final IndexSearcher searcher
 
         final ScrollContext scrollContext = searchContext.scrollContext();
         if (scrollContext != null) {
-            if (returnsDocsInOrder(query, searchContext.sort())) {
-                if (scrollContext.totalHits == -1) {
-                    // first round
-                    assert scrollContext.lastEmittedDoc == null;
-                    // there is not much that we can optimize here since we want to collect all
-                    // documents in order to get the total number of hits
-                } else {
+            if (scrollContext.totalHits == -1) {
+                // first round
+                assert scrollContext.lastEmittedDoc == null;
+                // there is not much that we can optimize here since we want to collect all
+                // documents in order to get the total number of hits
+
+            } else {
+                final ScoreDoc after = scrollContext.lastEmittedDoc;
+                if (returnsDocsInOrder(query, searchContext.sort())) {
                     // now this gets interesting: since we sort in index-order, we can directly
                     // skip to the desired doc
-                    final ScoreDoc after = scrollContext.lastEmittedDoc;
                     if (after != null) {
                         BooleanQuery bq = new BooleanQuery.Builder()
                             .add(query, BooleanClause.Occur.MUST)
@@ -150,6 +152,17 @@ static boolean execute(SearchContext searchContext, final IndexSearcher searcher
                     // ... and stop collecting after ${size} matches
                     searchContext.terminateAfter(searchContext.size());
                     searchContext.trackTotalHits(false);
+                } else if (canEarlyTerminate(indexSort, searchContext)) {
+                    // now this gets interesting: since the index sort matches the search sort, we can directly
+                    // skip to the desired doc
+                    if (after != null) {
+                        BooleanQuery bq = new BooleanQuery.Builder()
+                            .add(query, BooleanClause.Occur.MUST)
+                            .add(new SearchAfterSortedDocQuery(indexSort, (FieldDoc) after), BooleanClause.Occur.FILTER)
+                            .build();
+                        query = bq;
+                    }
+                    searchContext.trackTotalHits(false);
                 }
             }
         }
@@ -189,7 +202,10 @@ static boolean execute(SearchContext searchContext, final IndexSearcher searcher
             final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext, reader,
                 collectors.stream().anyMatch(QueryCollectorContext::shouldCollect));
             final boolean shouldCollect = topDocsFactory.shouldCollect();
-            if (scrollContext == null && topDocsFactory.numHits() > 0 && canEarlyTerminate(indexSort, searchContext)) {
+
+            if (topDocsFactory.numHits() > 0 &&
+                    (scrollContext == null || scrollContext.totalHits != -1) &&
+                    canEarlyTerminate(indexSort, searchContext)) {
                 // top docs collection can be early terminated based on index sort
                 // add the collector context first so we don't early terminate aggs but only top docs
                 collectors.addFirst(createEarlySortingTerminationCollectorContext(reader, searchContext.query(), indexSort,
diff --git a/core/src/test/java/org/apache/lucene/queries/SearchAfterSortedDocQueryTests.java b/core/src/test/java/org/apache/lucene/queries/SearchAfterSortedDocQueryTests.java
new file mode 100644
index 0000000000000..36abfea06ba4b
--- /dev/null
+++ b/core/src/test/java/org/apache/lucene/queries/SearchAfterSortedDocQueryTests.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.lucene.queries;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.SortedDocValuesField;
+import org.apache.lucene.document.SortedNumericDocValuesField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.RandomIndexWriter;
+import org.apache.lucene.index.ReaderUtil;
+import org.apache.lucene.search.FieldDoc;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.QueryUtils;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.Sort;
+import org.apache.lucene.search.SortField;
+import org.apache.lucene.search.SortedNumericSortField;
+import org.apache.lucene.search.SortedSetSortField;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.FixedBitSet;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class SearchAfterSortedDocQueryTests extends ESTestCase {
+
+    public void testBasics() {
+        Sort sort1 = new Sort(
+            new SortedNumericSortField("field1", SortField.Type.INT),
+            new SortedSetSortField("field2", false)
+        );
+        Sort sort2 = new Sort(
+            new SortedNumericSortField("field1", SortField.Type.INT),
+            new SortedSetSortField("field3", false)
+        );
+        FieldDoc fieldDoc1 = new FieldDoc(0, 0f, new Object[]{5, new BytesRef("foo")});
+        FieldDoc fieldDoc2 = new FieldDoc(0, 0f, new Object[]{5, new BytesRef("foo")});
+
+        SearchAfterSortedDocQuery query1 = new SearchAfterSortedDocQuery(sort1, fieldDoc1);
+        SearchAfterSortedDocQuery query2 = new SearchAfterSortedDocQuery(sort1, fieldDoc2);
+        SearchAfterSortedDocQuery query3 = new SearchAfterSortedDocQuery(sort2, fieldDoc2);
+        QueryUtils.check(query1);
+        QueryUtils.checkEqual(query1, query2);
+        QueryUtils.checkUnequal(query1, query3);
+    }
+
+    public void testInvalidSort() {
+        Sort sort = new Sort(new SortedNumericSortField("field1", SortField.Type.INT));
+        FieldDoc fieldDoc = new FieldDoc(0, 0f, new Object[] {4, 5});
+        IllegalArgumentException ex =
+            expectThrows(IllegalArgumentException.class, () -> new SearchAfterSortedDocQuery(sort, fieldDoc));
+        assertThat(ex.getMessage(), equalTo("after doc has 2 value(s) but sort has 1."));
+    }
+
+    public void testRandom() throws IOException {
+        final int numDocs = randomIntBetween(100, 200);
+        final Document doc = new Document();
+        final Directory dir = newDirectory();
+        Sort sort = new Sort(
+            new SortedNumericSortField("number1", SortField.Type.INT),
+            new SortField("string", SortField.Type.STRING)
+        );
+        final IndexWriterConfig config = new IndexWriterConfig();
+        config.setIndexSort(sort);
+        final RandomIndexWriter w = new RandomIndexWriter(random(), dir, config);
+        for (int i = 0; i < numDocs; ++i) {
+            int rand = randomIntBetween(0, 10);
+            doc.add(new SortedNumericDocValuesField("number1", rand));
+            doc.add(new SortedDocValuesField("string", new BytesRef(randomAlphaOfLength(randomIntBetween(5, 50)))));
+            w.addDocument(doc);
+            doc.clear();
+            if (rarely()) {
+                w.commit();
+            }
+        }
+        final IndexReader reader = w.getReader();
+        final IndexSearcher searcher = newSearcher(reader);
+
+        int step = randomIntBetween(1, 10);
+        FixedBitSet bitSet = new FixedBitSet(numDocs);
+        TopDocs topDocs = null;
+        for (int i = 0; i < numDocs;) {
+            if (topDocs != null) {
+                FieldDoc after = (FieldDoc) topDocs.scoreDocs[topDocs.scoreDocs.length - 1];
+                topDocs = searcher.search(new SearchAfterSortedDocQuery(sort, after), step, sort);
+            } else {
+                topDocs = searcher.search(new MatchAllDocsQuery(), step, sort);
+            }
+            i += step;
+            for (ScoreDoc topDoc : topDocs.scoreDocs) {
+                int readerIndex = ReaderUtil.subIndex(topDoc.doc, reader.leaves());
+                final LeafReaderContext leafReaderContext = reader.leaves().get(readerIndex);
+                int docRebase = topDoc.doc - leafReaderContext.docBase;
+                if (leafReaderContext.reader().hasDeletions()) {
+                    assertTrue(leafReaderContext.reader().getLiveDocs().get(docRebase));
+                }
+                assertFalse(bitSet.get(topDoc.doc));
+                bitSet.set(topDoc.doc);
+            }
+        }
+        assertThat(bitSet.cardinality(), equalTo(reader.numDocs()));
+        w.close();
+        reader.close();
+        dir.close();
+    }
+}
diff --git a/core/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java b/core/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java
index 2633cb706e0cc..b05c6dff04b6e 100644
--- a/core/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java
+++ b/core/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java
@@ -36,12 +36,12 @@
 import org.apache.lucene.search.BooleanQuery;
 import org.apache.lucene.search.Collector;
 import org.apache.lucene.search.ConstantScoreQuery;
+import org.apache.lucene.search.FieldComparator;
 import org.apache.lucene.search.FieldDoc;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.MatchAllDocsQuery;
 import org.apache.lucene.search.MatchNoDocsQuery;
 import org.apache.lucene.search.Query;
-import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.Sort;
 import org.apache.lucene.search.SortField;
 import org.apache.lucene.search.TermQuery;
@@ -50,10 +50,8 @@
 import org.apache.lucene.store.Directory;
 import org.elasticsearch.action.search.SearchTask;
 import org.elasticsearch.index.query.ParsedQuery;
-import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.internal.ScrollContext;
-import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.search.sort.SortAndFormats;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.TestSearchContext;
@@ -64,11 +62,9 @@
 import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.equalTo;
-import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.lessThan;
-import static org.hamcrest.Matchers.nullValue;
 
 public class QueryPhaseTests extends ESTestCase {
 
@@ -440,4 +436,71 @@ protected void search(List<LeafReaderContext> leaves, Weight weight, Collector c
         reader.close();
         dir.close();
     }
+
+    public void testIndexSortScrollOptimization() throws Exception {
+        Directory dir = newDirectory();
+        final Sort sort = new Sort(
+            new SortField("rank", SortField.Type.INT),
+            new SortField("tiebreaker", SortField.Type.INT)
+        );
+        IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort);
+        RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc);
+        final int numDocs = scaledRandomIntBetween(100, 200);
+        for (int i = 0; i < numDocs; ++i) {
+            Document doc = new Document();
+            doc.add(new NumericDocValuesField("rank", random().nextInt()));
+            doc.add(new NumericDocValuesField("tiebreaker", i));
+            w.addDocument(doc);
+        }
+        w.close();
+
+        TestSearchContext context = new TestSearchContext(null);
+        context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery()));
+        ScrollContext scrollContext = new ScrollContext();
+        scrollContext.lastEmittedDoc = null;
+        scrollContext.maxScore = Float.NaN;
+        scrollContext.totalHits = -1;
+        context.scrollContext(scrollContext);
+        context.setTask(new SearchTask(123L, "", "", "", null));
+        context.setSize(10);
+        context.sort(new SortAndFormats(sort, new DocValueFormat[] {DocValueFormat.RAW, DocValueFormat.RAW}));
+
+        final AtomicBoolean collected = new AtomicBoolean();
+        final IndexReader reader = DirectoryReader.open(dir);
+        IndexSearcher contextSearcher = new IndexSearcher(reader) {
+            protected void search(List<LeafReaderContext> leaves, Weight weight, Collector collector) throws IOException {
+                collected.set(true);
+                super.search(leaves, weight, collector);
+            }
+        };
+
+        QueryPhase.execute(context, contextSearcher, sort);
+        assertThat(context.queryResult().topDocs().totalHits, equalTo(numDocs));
+        assertTrue(collected.get());
+        assertNull(context.queryResult().terminatedEarly());
+        assertThat(context.terminateAfter(), equalTo(0));
+        assertThat(context.queryResult().getTotalHits(), equalTo(numDocs));
+        int sizeMinus1 = context.queryResult().topDocs().scoreDocs.length - 1;
+        FieldDoc lastDoc = (FieldDoc) context.queryResult().topDocs().scoreDocs[sizeMinus1];
+
+        QueryPhase.execute(context, contextSearcher, sort);
+        assertThat(context.queryResult().topDocs().totalHits, equalTo(numDocs));
+        assertTrue(collected.get());
+        assertTrue(context.queryResult().terminatedEarly());
+        assertThat(context.terminateAfter(), equalTo(0));
+        assertThat(context.queryResult().getTotalHits(), equalTo(numDocs));
+        FieldDoc firstDoc = (FieldDoc) context.queryResult().topDocs().scoreDocs[0];
+        for (int i = 0; i < sort.getSort().length; i++) {
+            @SuppressWarnings("unchecked")
+            FieldComparator<Object> comparator = (FieldComparator<Object>) sort.getSort()[i].getComparator(1, i);
+            int cmp = comparator.compareValues(firstDoc.fields[i], lastDoc.fields[i]);
+            if (cmp == 0) {
+                continue;
+            }
+            assertThat(cmp, equalTo(1));
+            break;
+        }
+        reader.close();
+        dir.close();
+    }
 }
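
A minimal standalone sketch of how `SearchAfterSortedDocQuery` is intended to compose with an arbitrary query when resuming a sorted scroll, mirroring the `QueryPhase` change above. The class and method names in the sketch (`SortedScrollSketch`, `nextPage`) are illustrative only and are not part of the patch; the Lucene calls (`IndexSearcher#search(Query, int, Sort)`, `BooleanQuery.Builder`) and the new query itself are the APIs used in the change.

import org.apache.lucene.queries.SearchAfterSortedDocQuery;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopFieldDocs;

import java.io.IOException;

class SortedScrollSketch {
    /** Fetches the next page of a scroll on an index whose index sort equals {@code indexSort}. */
    static TopFieldDocs nextPage(IndexSearcher searcher, Query userQuery, Sort indexSort,
                                 FieldDoc lastEmittedDoc, int size) throws IOException {
        if (lastEmittedDoc == null) {
            // first round: collect every match so the total hit count is exact
            return searcher.search(userQuery, size, indexSort);
        }
        // subsequent rounds: restrict the query to documents that sort after the last
        // emitted doc; each segment is entered at the right offset via a binary search
        Query resume = new BooleanQuery.Builder()
            .add(userQuery, BooleanClause.Occur.MUST)
            .add(new SearchAfterSortedDocQuery(indexSort, lastEmittedDoc), BooleanClause.Occur.FILTER)
            .build();
        return searcher.search(resume, size, indexSort);
    }
}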