Skip to content

Commit

Permalink
ES|QL categorize v1 (elastic#112860)
Browse files Browse the repository at this point in the history
* Prepare TokenListCategorizer for usage in ES|QL

* Expose text categorization from ML module

* Let esql plugin depend on ml plugin

* Fix/suppress this-escape

* Incomplete v1 of ES|QL Categorize

* Refactor / remove CategorizeInternal
  • Loading branch information
jan-elastic authored Sep 17, 2024
1 parent 54435b1 commit 71b30ce
Show file tree
Hide file tree
Showing 18 changed files with 423 additions and 9 deletions.
3 changes: 2 additions & 1 deletion x-pack/plugin/esql/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ esplugin {
name 'x-pack-esql'
description 'The plugin that powers ESQL for Elasticsearch'
classname 'org.elasticsearch.xpack.esql.plugin.EsqlPlugin'
extendedPlugins = ['x-pack-esql-core', 'lang-painless']
extendedPlugins = ['x-pack-esql-core', 'lang-painless', 'x-pack-ml']
}

base {
Expand All @@ -22,6 +22,7 @@ dependencies {
compileOnly project(path: xpackModule('core'))
compileOnly project(':modules:lang-painless:spi')
compileOnly project(xpackModule('esql-core'))
compileOnly project(xpackModule('ml'))
implementation project('compute')
implementation project('compute:ann')
implementation project(':libs:elasticsearch-dissect')
Expand Down
3 changes: 3 additions & 0 deletions x-pack/plugin/esql/compute/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@ base {
dependencies {
compileOnly project(':server')
compileOnly project('ann')
compileOnly project(xpackModule('ml'))
annotationProcessor project('gen')
implementation 'com.carrotsearch:hppc:0.8.1'

testImplementation project(':test:framework')
testImplementation(project(xpackModule('esql-core')))
testImplementation(project(xpackModule('core')))
testImplementation(project(xpackModule('ml')))
}

def projectDirectory = project.layout.projectDirectory
Expand Down
2 changes: 2 additions & 0 deletions x-pack/plugin/esql/compute/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

module org.elasticsearch.compute {

requires org.apache.lucene.analysis.common;
requires org.apache.lucene.core;
requires org.elasticsearch.base;
requires org.elasticsearch.server;
Expand All @@ -15,6 +16,7 @@
// required due to dependency on org.elasticsearch.common.util.concurrent.AbstractAsyncTask
requires org.apache.logging.log4j;
requires org.elasticsearch.logging;
requires org.elasticsearch.ml;
requires org.elasticsearch.tdigest;
requires org.elasticsearch.geo;
requires hppc;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
categorize
required_capability: categorize

FROM sample_data
| STATS count=COUNT(), values=VALUES(message) BY category=CATEGORIZE(message)
| SORT count DESC, category ASC
;

count:long | values:keyword | category:integer
3 | [Connected to 10.1.0.1, Connected to 10.1.0.2, Connected to 10.1.0.3] | 0
3 | [Connection error] | 1
1 | [Disconnected] | 2
;

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import org.elasticsearch.xpack.esql.expression.function.aggregate.WeightedAvg;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least;
Expand Down Expand Up @@ -383,7 +384,10 @@ private FunctionDefinition[][] functions() {
}

private static FunctionDefinition[][] snapshotFunctions() {
return new FunctionDefinition[][] { new FunctionDefinition[] { def(Rate.class, Rate::withUnresolvedTimestamp, "rate") } };
return new FunctionDefinition[][] {
new FunctionDefinition[] {
def(Categorize.class, Categorize::new, "categorize"),
def(Rate.class, Rate::withUnresolvedTimestamp, "rate") } };
}

public EsqlFunctionRegistry snapshotRegistry() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function.grouping;

import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.core.WhitespaceTokenizer;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.BytesRefHash;
import org.elasticsearch.compute.ann.Evaluator;
import org.elasticsearch.compute.ann.Fixed;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.index.analysis.CharFilterFactory;
import org.elasticsearch.index.analysis.CustomAnalyzer;
import org.elasticsearch.index.analysis.TokenFilterFactory;
import org.elasticsearch.index.analysis.TokenizerFactory;
import org.elasticsearch.xpack.esql.capabilities.Validatable;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;

import java.io.IOException;
import java.util.List;
import java.util.function.Function;

import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString;

/**
* Categorizes text messages.
*
* This implementation is incomplete and comes with the following caveats:
* - it only works correctly on a single node.
* - when running on multiple nodes, category IDs of the different nodes are
* aggregated, even though the same ID can correspond to a totally different
* category
* - the output consists of category IDs, which should be replaced by category
* regexes or keys
*
* TODO(jan, nik): fix this
*/
public class Categorize extends GroupingFunction implements Validatable {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
Expression.class,
"Categorize",
Categorize::new
);

private final Expression field;

@FunctionInfo(returnType = { "integer" }, description = "Categorizes text messages")
public Categorize(
Source source,
@Param(name = "field", type = { "text", "keyword" }, description = "Expression to categorize") Expression field
) {
super(source, List.of(field));
this.field = field;
}

private Categorize(StreamInput in) throws IOException {
this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class));
}

@Override
public void writeTo(StreamOutput out) throws IOException {
source().writeTo(out);
out.writeNamedWriteable(field);
}

@Override
public String getWriteableName() {
return ENTRY.name;
}

@Override
public boolean foldable() {
return field.foldable();
}

@Evaluator
static int process(
BytesRef v,
@Fixed(includeInToString = false, build = true) CategorizationAnalyzer analyzer,
@Fixed(includeInToString = false, build = true) TokenListCategorizer.CloseableTokenListCategorizer categorizer
) {
String s = v.utf8ToString();
try (TokenStream ts = analyzer.tokenStream("text", s)) {
return categorizer.computeCategory(ts, s.length(), 1).getId();
} catch (IOException e) {
throw new RuntimeException(e);
}
}

@Override
public ExpressionEvaluator.Factory toEvaluator(Function<Expression, ExpressionEvaluator.Factory> toEvaluator) {
return new CategorizeEvaluator.Factory(
source(),
toEvaluator.apply(field),
context -> new CategorizationAnalyzer(
// TODO(jan): get the correct analyzer in here, see CategorizationAnalyzerConfig::buildStandardCategorizationAnalyzer
new CustomAnalyzer(
TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
new CharFilterFactory[0],
new TokenFilterFactory[0]
),
true
),
context -> new TokenListCategorizer.CloseableTokenListCategorizer(
new CategorizationBytesRefHash(new BytesRefHash(2048, context.bigArrays())),
CategorizationPartOfSpeechDictionary.getInstance(),
0.70f
)
);
}

@Override
protected TypeResolution resolveType() {
return isString(field(), sourceText(), DEFAULT);
}

@Override
public DataType dataType() {
return DataType.INTEGER;
}

@Override
public Expression replaceChildren(List<Expression> newChildren) {
return new Categorize(source(), newChildren.get(0));
}

@Override
protected NodeInfo<? extends Expression> info() {
return NodeInfo.create(this, Categorize::new, field);
}

public Expression field() {
return field;
}

@Override
public String toString() {
return "Categorize{field=" + field + "}";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least;
Expand Down Expand Up @@ -73,6 +74,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
entries.add(Atan2.ENTRY);
entries.add(Bucket.ENTRY);
entries.add(Case.ENTRY);
entries.add(Categorize.ENTRY);
entries.add(CIDRMatch.ENTRY);
entries.add(Coalesce.ENTRY);
entries.add(Concat.ENTRY);
Expand Down
Loading

0 comments on commit 71b30ce

Please sign in to comment.