diff --git a/src/main/java/org/elasticsearch/common/lucene/all/AllEntries.java b/src/main/java/org/elasticsearch/common/lucene/all/AllEntries.java index 25237202323df..131ed5fda0922 100644 --- a/src/main/java/org/elasticsearch/common/lucene/all/AllEntries.java +++ b/src/main/java/org/elasticsearch/common/lucene/all/AllEntries.java @@ -40,14 +40,20 @@ public class AllEntries extends Reader { public static class Entry { private final String name; private final FastStringReader reader; + private final int startOffset; private final float boost; - public Entry(String name, FastStringReader reader, float boost) { + public Entry(String name, FastStringReader reader, int startOffset, float boost) { this.name = name; this.reader = reader; + this.startOffset = startOffset; this.boost = boost; } + public int startOffset() { + return startOffset; + } + public String name() { return this.name; } @@ -75,7 +81,15 @@ public void addText(String name, String text, float boost) { if (boost != 1.0f) { customBoost = true; } - Entry entry = new Entry(name, new FastStringReader(text), boost); + final int lastStartOffset; + if (entries.isEmpty()) { + lastStartOffset = -1; + } else { + final Entry last = entries.get(entries.size() - 1); + lastStartOffset = last.startOffset() + last.reader().length(); + } + final int startOffset = lastStartOffset + 1; // +1 because we insert a space between tokens + Entry entry = new Entry(name, new FastStringReader(text), startOffset, boost); entries.add(entry); } @@ -129,8 +143,22 @@ public Set fields() { return fields; } - public Entry current() { - return this.current; + // compute the boost for a token with the given startOffset + public float boost(int startOffset) { + int lo = 0, hi = entries.size() - 1; + while (lo <= hi) { + final int mid = (lo + hi) >>> 1; + final int midOffset = entries.get(mid).startOffset(); + if (startOffset < midOffset) { + hi = mid - 1; + } else { + lo = mid + 1; + } + } + final int index = Math.max(0, hi); // protection against broken token streams + assert entries.get(index).startOffset() <= startOffset; + assert index == entries.size() - 1 || entries.get(index + 1).startOffset() > startOffset; + return entries.get(index).boost(); } @Override @@ -186,7 +214,7 @@ public int read(char[] cbuf, int off, int len) throws IOException { @Override public void close() { if (current != null) { - current.reader().close(); + // no need to close, these are readers on strings current = null; } } diff --git a/src/main/java/org/elasticsearch/common/lucene/all/AllTokenStream.java b/src/main/java/org/elasticsearch/common/lucene/all/AllTokenStream.java index 04e5e77da9da8..3c461eceabb18 100644 --- a/src/main/java/org/elasticsearch/common/lucene/all/AllTokenStream.java +++ b/src/main/java/org/elasticsearch/common/lucene/all/AllTokenStream.java @@ -22,6 +22,7 @@ import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.TokenFilter; import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.tokenattributes.OffsetAttribute; import org.apache.lucene.analysis.tokenattributes.PayloadAttribute; import org.apache.lucene.util.BytesRef; @@ -42,11 +43,13 @@ public static TokenStream allTokenStream(String allFieldName, AllEntries allEntr private final AllEntries allEntries; + private final OffsetAttribute offsetAttribute; private final PayloadAttribute payloadAttribute; AllTokenStream(TokenStream input, AllEntries allEntries) { super(input); this.allEntries = allEntries; + offsetAttribute = addAttribute(OffsetAttribute.class); payloadAttribute = addAttribute(PayloadAttribute.class); } @@ -59,14 +62,12 @@ public final boolean incrementToken() throws IOException { if (!input.incrementToken()) { return false; } - if (allEntries.current() != null) { - float boost = allEntries.current().boost(); - if (boost != 1.0f) { - encodeFloat(boost, payloadSpare.bytes, payloadSpare.offset); - payloadAttribute.setPayload(payloadSpare); - } else { - payloadAttribute.setPayload(null); - } + final float boost = allEntries.boost(offsetAttribute.startOffset()); + if (boost != 1.0f) { + encodeFloat(boost, payloadSpare.bytes, payloadSpare.offset); + payloadAttribute.setPayload(payloadSpare); + } else { + payloadAttribute.setPayload(null); } return true; } diff --git a/src/test/java/org/elasticsearch/common/lucene/all/SimpleAllTests.java b/src/test/java/org/elasticsearch/common/lucene/all/SimpleAllTests.java index 08e2a874f490e..57de6dbbe2739 100644 --- a/src/test/java/org/elasticsearch/common/lucene/all/SimpleAllTests.java +++ b/src/test/java/org/elasticsearch/common/lucene/all/SimpleAllTests.java @@ -19,6 +19,11 @@ package org.elasticsearch.common.lucene.all; +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.core.WhitespaceAnalyzer; +import org.apache.lucene.analysis.payloads.PayloadHelper; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; +import org.apache.lucene.analysis.tokenattributes.PayloadAttribute; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.StoredField; @@ -27,6 +32,7 @@ import org.apache.lucene.search.*; import org.apache.lucene.store.Directory; import org.apache.lucene.store.RAMDirectory; +import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.test.ElasticsearchTestCase; import org.junit.Test; @@ -40,6 +46,52 @@ */ public class SimpleAllTests extends ElasticsearchTestCase { + @Test + // https://github.com/elasticsearch/elasticsearch/issues/4315 + public void testBoostOnEagerTokenizer() throws Exception { + AllEntries allEntries = new AllEntries(); + allEntries.addText("field1", "all", 2.0f); + allEntries.addText("field2", "your", 1.0f); + allEntries.addText("field1", "boosts", 0.5f); + allEntries.reset(); + // whitespace analyzer's tokenizer reads characters eagerly on the contrary to the standard tokenizer + final TokenStream ts = AllTokenStream.allTokenStream("any", allEntries, new WhitespaceAnalyzer(Lucene.VERSION)); + final CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class); + final PayloadAttribute payloadAtt = ts.addAttribute(PayloadAttribute.class); + ts.reset(); + for (int i = 0; i < 3; ++i) { + assertTrue(ts.incrementToken()); + final String term; + final float boost; + switch (i) { + case 0: + term = "all"; + boost = 2; + break; + case 1: + term = "your"; + boost = 1; + break; + case 2: + term = "boosts"; + boost = 0.5f; + break; + default: + throw new AssertionError(); + } + assertEquals(term, termAtt.toString()); + final BytesRef payload = payloadAtt.getPayload(); + if (payload == null || payload.length == 0) { + assertEquals(boost, 1f, 0.001f); + } else { + assertEquals(4, payload.length); + final float b = PayloadHelper.decodeFloat(payload.bytes, payload.offset); + assertEquals(boost, b, 0.001f); + } + } + assertFalse(ts.incrementToken()); + } + @Test public void testAllEntriesRead() throws Exception { AllEntries allEntries = new AllEntries();