Skip to content

Commit

Permalink
Add term query for keyword script fields (#59372)
Browse files Browse the repository at this point in the history
This adds what I think is just about the simplest possible `term` query
implementation for `keyword` script fields and wires it into the field
mapper that we build for them.
  • Loading branch information
nik9000 authored Jul 13, 2020
1 parent fa48ccd commit 2659648
Show file tree
Hide file tree
Showing 11 changed files with 291 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
import org.elasticsearch.search.lookup.SearchLookup;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

public abstract class StringScriptFieldScript extends AbstractScriptFieldScript {
public static final ScriptContext<Factory> CONTEXT = new ScriptContext<>("string_script_field", Factory.class);
Expand All @@ -32,14 +32,26 @@ public interface Factory extends ScriptFactory {
}

public interface LeafFactory {
StringScriptFieldScript newInstance(LeafReaderContext ctx, Consumer<String> sync) throws IOException;
StringScriptFieldScript newInstance(LeafReaderContext ctx) throws IOException;
}

private final Consumer<String> sync;
private final List<String> results = new ArrayList<>();

public StringScriptFieldScript(Map<String, Object> params, SearchLookup searchLookup, LeafReaderContext ctx, Consumer<String> sync) {
public StringScriptFieldScript(Map<String, Object> params, SearchLookup searchLookup, LeafReaderContext ctx) {
super(params, searchLookup, ctx);
this.sync = sync;
}

/**
* Execute the script for the provided {@code docId}.
* <p>
* @return a mutable {@link List} that contains the results of the script
* and will be modified the next time you call {@linkplain #resultsForDoc}.
*/
public final List<String> resultsForDoc(int docId) {
results.clear();
setDocument(docId);
execute();
return results;
}

public static class Value {
Expand All @@ -50,7 +62,7 @@ public Value(StringScriptFieldScript script) {
}

public void value(String v) {
script.sync.accept(v);
script.results.add(v);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,26 @@
import org.elasticsearch.index.fielddata.SortingBinaryDocValues;
import org.elasticsearch.xpack.runtimefields.StringScriptFieldScript;

public final class ScriptBinaryDocValues extends SortingBinaryDocValues {
import java.util.List;

public final class ScriptBinaryDocValues extends SortingBinaryDocValues {
private final StringScriptFieldScript script;
private final ScriptBinaryFieldData.ScriptBinaryResult scriptBinaryResult;

ScriptBinaryDocValues(StringScriptFieldScript script, ScriptBinaryFieldData.ScriptBinaryResult scriptBinaryResult) {
ScriptBinaryDocValues(StringScriptFieldScript script) {
this.script = script;
this.scriptBinaryResult = scriptBinaryResult;
}

@Override
public boolean advanceExact(int doc) {
script.setDocument(doc);
script.execute();

count = scriptBinaryResult.getResult().size();
public boolean advanceExact(int docId) {
List<String> results = script.resultsForDoc(docId);
count = results.size();
if (count == 0) {
grow();
return false;
}

grow();
int i = 0;
for (String value : scriptBinaryResult.getResult()) {
grow();
for (String value : results) {
values[i++].copyChars(value);
}
sort();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,7 @@ public ScriptBinaryLeafFieldData load(LeafReaderContext context) {

@Override
public ScriptBinaryLeafFieldData loadDirect(LeafReaderContext context) throws IOException {
ScriptBinaryResult scriptBinaryResult = new ScriptBinaryResult();
return new ScriptBinaryLeafFieldData(
new ScriptBinaryDocValues(leafFactory.get().newInstance(context, scriptBinaryResult::accept), scriptBinaryResult)
);
return new ScriptBinaryLeafFieldData(new ScriptBinaryDocValues(leafFactory.get().newInstance(context)));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,21 @@

import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.ToXContent.Params;
import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.fielddata.IndexFieldData;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.TextSearchInfo;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.script.Script;
import org.elasticsearch.xpack.runtimefields.StringScriptFieldScript;
import org.elasticsearch.xpack.runtimefields.fielddata.ScriptBinaryFieldData;
import org.elasticsearch.xpack.runtimefields.query.StringScriptFieldTermQuery;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

public final class RuntimeKeywordMappedFieldType extends MappedFieldType {

Expand Down Expand Up @@ -57,7 +60,11 @@ public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName) {

@Override
public Query termQuery(Object value, QueryShardContext context) {
return null;
return new StringScriptFieldTermQuery(
scriptFactory.newFactory(script.getParams(), context.lookup()),
name(),
BytesRefs.toString(Objects.requireNonNull(value))
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

package org.elasticsearch.xpack.runtimefields.query;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;
import org.elasticsearch.xpack.runtimefields.StringScriptFieldScript;

import java.io.IOException;
import java.util.Objects;

public class StringScriptFieldTermQuery extends Query {
private final StringScriptFieldScript.LeafFactory leafFactory;
private final String fieldName;
private final String term;

public StringScriptFieldTermQuery(StringScriptFieldScript.LeafFactory leafFactory, String fieldName, String term) {
this.leafFactory = leafFactory;
this.fieldName = fieldName;
this.term = term;
}

@Override
public final Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
return new ConstantScoreWeight(this, boost) {
@Override
public boolean isCacheable(LeafReaderContext ctx) {
return false; // scripts aren't really cacheable at this point
}

@Override
public Scorer scorer(LeafReaderContext ctx) throws IOException {
StringScriptFieldScript script = leafFactory.newInstance(ctx);
DocIdSetIterator approximation = DocIdSetIterator.all(ctx.reader().maxDoc());
TwoPhaseIterator twoPhase = new TwoPhaseIterator(approximation) {
@Override
public boolean matches() throws IOException {
for (String result : script.resultsForDoc(approximation().docID())) {
if (term.equals(result)) {
return true;
}
}
return false;
}

@Override
public float matchCost() {
// TODO we don't have a good way of estimating the complexity of the script so we just go with 9000
return 9000f;
}
};
return new ConstantScoreScorer(this, score(), scoreMode, twoPhase);
}
};
}

@Override
public void visit(QueryVisitor visitor) {
visitor.consumeTerms(this, new Term(fieldName, term));
}

@Override
public final String toString(String field) {
if (fieldName.contentEquals(field)) {
return term;
}
return fieldName + ":" + term;
}

@Override
public int hashCode() {
return Objects.hash(fieldName, term);
}

@Override
public boolean equals(Object obj) {
if (obj == null || getClass() != obj.getClass()) {
return false;
}
StringScriptFieldTermQuery other = (StringScriptFieldTermQuery) obj;
return fieldName.equals(other.fieldName) && term.equals(other.term);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
import org.elasticsearch.search.lookup.SearchLookup;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.IntFunction;

import static org.hamcrest.Matchers.equalTo;

public class DoubleScriptFieldScriptTests extends ScriptFieldScriptTestCase<
DoubleScriptFieldScript,
DoubleScriptFieldScript.Factory,
DoubleScriptFieldScript.LeafFactory,
Double> {
Expand Down Expand Up @@ -112,11 +113,15 @@ protected DoubleScriptFieldScript.LeafFactory newLeafFactory(
}

@Override
protected DoubleScriptFieldScript newInstance(
DoubleScriptFieldScript.LeafFactory leafFactory,
LeafReaderContext context,
List<Double> result
) throws IOException {
return leafFactory.newInstance(context, result::add);
protected IntFunction<List<Double>> newInstance(DoubleScriptFieldScript.LeafFactory leafFactory, LeafReaderContext context)
throws IOException {
List<Double> results = new ArrayList<>();
DoubleScriptFieldScript script = leafFactory.newInstance(context, results::add);
return docId -> {
results.clear();
script.setDocument(docId);
script.execute();
return results;
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
import org.elasticsearch.search.lookup.SearchLookup;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.IntFunction;

import static org.hamcrest.Matchers.equalTo;

public class LongScriptFieldScriptTests extends ScriptFieldScriptTestCase<
LongScriptFieldScript,
LongScriptFieldScript.Factory,
LongScriptFieldScript.LeafFactory,
Long> {
Expand Down Expand Up @@ -98,9 +99,15 @@ protected LongScriptFieldScript.LeafFactory newLeafFactory(
}

@Override
protected LongScriptFieldScript newInstance(LongScriptFieldScript.LeafFactory leafFactory, LeafReaderContext context, List<Long> result)
protected IntFunction<List<Long>> newInstance(LongScriptFieldScript.LeafFactory leafFactory, LeafReaderContext context)
throws IOException {

return leafFactory.newInstance(context, result::add);
List<Long> results = new ArrayList<>();
LongScriptFieldScript script = leafFactory.newInstance(context, results::add);
return docId -> {
results.clear();
script.setDocument(docId);
script.execute();
return results;
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,17 @@
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.IntFunction;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public abstract class ScriptFieldScriptTestCase<S extends AbstractScriptFieldScript, F, LF, R> extends ESTestCase {
public abstract class ScriptFieldScriptTestCase<F, LF, R> extends ESTestCase {
protected abstract ScriptContext<F> scriptContext();

protected abstract LF newLeafFactory(F factory, Map<String, Object> params, SearchLookup searchLookup);

protected abstract S newInstance(LF leafFactory, LeafReaderContext context, List<R> results) throws IOException;
protected abstract IntFunction<List<R>> newInstance(LF leafFactory, LeafReaderContext context) throws IOException;

protected final List<R> execute(CheckedConsumer<RandomIndexWriter, IOException> indexBuilder, String script, MappedFieldType... types)
throws IOException {
Expand Down Expand Up @@ -92,15 +93,14 @@ public ScoreMode scoreMode() {

@Override
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
S compiled = newInstance(leafFactory, context, result);
IntFunction<List<R>> compiled = newInstance(leafFactory, context);
return new LeafCollector() {
@Override
public void setScorer(Scorable scorer) {}

@Override
public void collect(int doc) {
compiled.setDocument(doc);
compiled.execute();
public void collect(int docId) {
result.addAll(compiled.apply(docId));
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.function.IntFunction;

import static org.hamcrest.Matchers.equalTo;

public class StringScriptFieldScriptTests extends ScriptFieldScriptTestCase<
StringScriptFieldScript,
StringScriptFieldScript.Factory,
StringScriptFieldScript.LeafFactory,
String> {
Expand Down Expand Up @@ -104,11 +104,8 @@ protected StringScriptFieldScript.LeafFactory newLeafFactory(
}

@Override
protected StringScriptFieldScript newInstance(
StringScriptFieldScript.LeafFactory leafFactory,
LeafReaderContext context,
List<String> result
) throws IOException {
return leafFactory.newInstance(context, result::add);
protected IntFunction<List<String>> newInstance(StringScriptFieldScript.LeafFactory leafFactory, LeafReaderContext context)
throws IOException {
return leafFactory.newInstance(context)::resultsForDoc;
}
}
Loading

0 comments on commit 2659648

Please sign in to comment.