From 0ef6ed98945d9698f0703283086883554a28dfb7 Mon Sep 17 00:00:00 2001
From: Adrien Grand
Date: Tue, 3 Dec 2013 16:29:28 +0100
Subject: [PATCH] Fix _all boosting.

_all boosting used to rely on the fact that the TokenStream doesn't
eagerly consume the input java.io.Reader. This fixes the issue by using
binary search in order to find the right boost given a token's start
offset.

Close #4315
---
 .../common/lucene/all/AllEntries.java     | 38 ++++++++++++--
 .../common/lucene/all/AllTokenStream.java | 17 ++++---
 .../common/lucene/all/SimpleAllTests.java | 51 +++++++++++++++++++
 3 files changed, 93 insertions(+), 13 deletions(-)

diff --git a/src/main/java/org/elasticsearch/common/lucene/all/AllEntries.java b/src/main/java/org/elasticsearch/common/lucene/all/AllEntries.java
index 25237202323df..131ed5fda0922 100644
--- a/src/main/java/org/elasticsearch/common/lucene/all/AllEntries.java
+++ b/src/main/java/org/elasticsearch/common/lucene/all/AllEntries.java
@@ -40,14 +40,20 @@ public class AllEntries extends Reader {
     public static class Entry {
         private final String name;
         private final FastStringReader reader;
+        private final int startOffset;
         private final float boost;
 
-        public Entry(String name, FastStringReader reader, float boost) {
+        public Entry(String name, FastStringReader reader, int startOffset, float boost) {
             this.name = name;
             this.reader = reader;
+            this.startOffset = startOffset;
             this.boost = boost;
         }
 
+        public int startOffset() {
+            return startOffset;
+        }
+
         public String name() {
             return this.name;
         }
@@ -75,7 +81,15 @@ public void addText(String name, String text, float boost) {
         if (boost != 1.0f) {
             customBoost = true;
         }
-        Entry entry = new Entry(name, new FastStringReader(text), boost);
+        final int lastStartOffset;
+        if (entries.isEmpty()) {
+            lastStartOffset = -1;
+        } else {
+            final Entry last = entries.get(entries.size() - 1);
+            lastStartOffset = last.startOffset() + last.reader().length();
+        }
+        final int startOffset = lastStartOffset + 1; // +1 because we insert a space between tokens
+        Entry entry = new Entry(name, new FastStringReader(text), startOffset, boost);
         entries.add(entry);
     }
 
@@ -129,8 +143,22 @@ public Set<String> fields() {
         return fields;
     }
 
-    public Entry current() {
-        return this.current;
+    // compute the boost for a token with the given startOffset
+    public float boost(int startOffset) {
+        int lo = 0, hi = entries.size() - 1;
+        while (lo <= hi) {
+            final int mid = (lo + hi) >>> 1;
+            final int midOffset = entries.get(mid).startOffset();
+            if (startOffset < midOffset) {
+                hi = mid - 1;
+            } else {
+                lo = mid + 1;
+            }
+        }
+        final int index = Math.max(0, hi); // protection against broken token streams
+        assert entries.get(index).startOffset() <= startOffset;
+        assert index == entries.size() - 1 || entries.get(index + 1).startOffset() > startOffset;
+        return entries.get(index).boost();
     }
 
     @Override
@@ -186,7 +214,7 @@ public int read(char[] cbuf, int off, int len) throws IOException {
     @Override
     public void close() {
         if (current != null) {
-            current.reader().close();
+            // no need to close, these are readers on strings
            current = null;
         }
     }
diff --git a/src/main/java/org/elasticsearch/common/lucene/all/AllTokenStream.java b/src/main/java/org/elasticsearch/common/lucene/all/AllTokenStream.java
index 04e5e77da9da8..3c461eceabb18 100644
--- a/src/main/java/org/elasticsearch/common/lucene/all/AllTokenStream.java
+++ b/src/main/java/org/elasticsearch/common/lucene/all/AllTokenStream.java
@@ -22,6 +22,7 @@
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.analysis.TokenFilter;
 import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
 import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;
 import org.apache.lucene.util.BytesRef;
 
@@ -42,11 +43,13 @@ public static TokenStream allTokenStream(String allFieldName, AllEntries allEntr
 
     private final AllEntries allEntries;
 
+    private final OffsetAttribute offsetAttribute;
     private final PayloadAttribute payloadAttribute;
 
     AllTokenStream(TokenStream input, AllEntries allEntries) {
         super(input);
         this.allEntries = allEntries;
+        offsetAttribute = addAttribute(OffsetAttribute.class);
         payloadAttribute = addAttribute(PayloadAttribute.class);
     }
 
@@ -59,14 +62,12 @@ public final boolean incrementToken() throws IOException {
         if (!input.incrementToken()) {
             return false;
         }
-        if (allEntries.current() != null) {
-            float boost = allEntries.current().boost();
-            if (boost != 1.0f) {
-                encodeFloat(boost, payloadSpare.bytes, payloadSpare.offset);
-                payloadAttribute.setPayload(payloadSpare);
-            } else {
-                payloadAttribute.setPayload(null);
-            }
+        final float boost = allEntries.boost(offsetAttribute.startOffset());
+        if (boost != 1.0f) {
+            encodeFloat(boost, payloadSpare.bytes, payloadSpare.offset);
+            payloadAttribute.setPayload(payloadSpare);
+        } else {
+            payloadAttribute.setPayload(null);
         }
         return true;
     }
diff --git a/src/test/java/org/elasticsearch/common/lucene/all/SimpleAllTests.java b/src/test/java/org/elasticsearch/common/lucene/all/SimpleAllTests.java
index 08e2a874f490e..0b39b2be2e9b6 100644
--- a/src/test/java/org/elasticsearch/common/lucene/all/SimpleAllTests.java
+++ b/src/test/java/org/elasticsearch/common/lucene/all/SimpleAllTests.java
@@ -19,6 +19,11 @@
 package org.elasticsearch.common.lucene.all;
 
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.core.WhitespaceAnalyzer;
+import org.apache.lucene.analysis.payloads.PayloadHelper;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
 import org.apache.lucene.document.StoredField;
@@ -27,6 +32,7 @@
 import org.apache.lucene.search.*;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.store.RAMDirectory;
+import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.common.lucene.Lucene;
 import org.elasticsearch.test.ElasticsearchTestCase;
 import org.junit.Test;
@@ -40,6 +46,51 @@
  */
 public class SimpleAllTests extends ElasticsearchTestCase {
 
+    @Test
+    public void testBoostOnEagerTokenizer() throws Exception {
+        AllEntries allEntries = new AllEntries();
+        allEntries.addText("field1", "all", 2.0f);
+        allEntries.addText("field2", "your", 1.0f);
+        allEntries.addText("field1", "boosts", 0.5f);
+        allEntries.reset();
+        // the whitespace analyzer's tokenizer reads characters eagerly, unlike the standard tokenizer
+        final TokenStream ts = AllTokenStream.allTokenStream("any", allEntries, new WhitespaceAnalyzer(Lucene.VERSION));
+        final CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
+        final PayloadAttribute payloadAtt = ts.addAttribute(PayloadAttribute.class);
+        ts.reset();
+        for (int i = 0; i < 3; ++i) {
+            assertTrue(ts.incrementToken());
+            final String term;
+            final float boost;
+            switch (i) {
+            case 0:
+                term = "all";
+                boost = 2;
+                break;
+            case 1:
+                term = "your";
+                boost = 1;
+                break;
+            case 2:
+                term = "boosts";
+                boost = 0.5f;
+                break;
+            default:
+                throw new AssertionError();
+            }
+            assertEquals(term, termAtt.toString());
+            final BytesRef payload = payloadAtt.getPayload();
+            if (payload == null || payload.length == 0) {
+                assertEquals(boost, 1f, 0.001f);
+            } else {
+                assertEquals(4, payload.length);
+                final float b = PayloadHelper.decodeFloat(payload.bytes, payload.offset);
+                assertEquals(boost, b, 0.001f);
+            }
+        }
+        assertFalse(ts.incrementToken());
+    }
+
     @Test
     public void testAllEntriesRead() throws Exception {
         AllEntries allEntries = new AllEntries();
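
For readers following the patch, here is a minimal standalone sketch of the technique it introduces: entries are laid out as if their texts were joined by single spaces, each entry records its start offset, and a binary search maps a token's start offset back to the entry whose boost applies. The class and helper names below (BoostLookupSketch, addText, boost) are invented for illustration and mirror, but are not, the patched AllEntries code:

import java.util.ArrayList;
import java.util.List;

public class BoostLookupSketch {

    // Simplified stand-in for AllEntries.Entry: a text, its start offset in
    // the concatenated stream, and the boost applying to tokens from it.
    static final class Entry {
        final String text;
        final int startOffset;
        final float boost;

        Entry(String text, int startOffset, float boost) {
            this.text = text;
            this.startOffset = startOffset;
            this.boost = boost;
        }
    }

    // Same layout rule as the patched AllEntries.addText(): each entry starts
    // one character past the end of the previous one, because a single space
    // is inserted between entries when the texts are concatenated.
    static void addText(List<Entry> entries, String text, float boost) {
        final int lastStartOffset;
        if (entries.isEmpty()) {
            lastStartOffset = -1;
        } else {
            final Entry last = entries.get(entries.size() - 1);
            lastStartOffset = last.startOffset + last.text.length();
        }
        entries.add(new Entry(text, lastStartOffset + 1, boost));
    }

    // Same binary search as the patched AllEntries.boost(int): find the last
    // entry whose startOffset is <= the token's startOffset, return its boost.
    static float boost(List<Entry> entries, int startOffset) {
        int lo = 0, hi = entries.size() - 1;
        while (lo <= hi) {
            final int mid = (lo + hi) >>> 1;
            if (startOffset < entries.get(mid).startOffset) {
                hi = mid - 1;
            } else {
                lo = mid + 1;
            }
        }
        // hi now points at the last entry starting at or before startOffset;
        // clamp at 0 to guard against token streams reporting bogus offsets.
        return entries.get(Math.max(0, hi)).boost;
    }

    public static void main(String[] args) {
        List<Entry> entries = new ArrayList<>();
        // Mirrors the new test: "all" (2.0), "your" (1.0), "boosts" (0.5).
        addText(entries, "all", 2.0f);
        addText(entries, "your", 1.0f);
        addText(entries, "boosts", 0.5f);

        // The concatenated stream is "all your boosts"; a whitespace tokenizer
        // emits tokens with start offsets 0, 4 and 9.
        System.out.println(boost(entries, 0)); // 2.0
        System.out.println(boost(entries, 4)); // 1.0
        System.out.println(boost(entries, 9)); // 0.5
    }
}

The point of the binary search is that the boost no longer depends on when the tokenizer consumes its reader: however eagerly characters are read, a token's reported start offset alone is enough to recover the entry, and thus the boost, it came from.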