
ES|QL categorize v1 #112860

Merged (6 commits) on Sep 17, 2024
3 changes: 2 additions & 1 deletion x-pack/plugin/esql/build.gradle
@@ -11,7 +11,7 @@ esplugin {
name 'x-pack-esql'
description 'The plugin that powers ESQL for Elasticsearch'
classname 'org.elasticsearch.xpack.esql.plugin.EsqlPlugin'
extendedPlugins = ['x-pack-esql-core', 'lang-painless']
extendedPlugins = ['x-pack-esql-core', 'lang-painless', 'x-pack-ml']
}

base {
@@ -22,6 +22,7 @@ dependencies {
compileOnly project(path: xpackModule('core'))
compileOnly project(':modules:lang-painless:spi')
compileOnly project(xpackModule('esql-core'))
compileOnly project(xpackModule('ml'))
implementation project('compute')
implementation project('compute:ann')
implementation project(':libs:elasticsearch-dissect')
3 changes: 3 additions & 0 deletions x-pack/plugin/esql/compute/build.gradle
@@ -11,11 +11,14 @@ base {
dependencies {
compileOnly project(':server')
compileOnly project('ann')
compileOnly project(xpackModule('ml'))
annotationProcessor project('gen')
implementation 'com.carrotsearch:hppc:0.8.1'

testImplementation project(':test:framework')
testImplementation(project(xpackModule('esql-core')))
testImplementation(project(xpackModule('core')))
testImplementation(project(xpackModule('ml')))
}

def projectDirectory = project.layout.projectDirectory
2 changes: 2 additions & 0 deletions x-pack/plugin/esql/compute/src/main/java/module-info.java
@@ -7,6 +7,7 @@

module org.elasticsearch.compute {

requires org.apache.lucene.analysis.common;
requires org.apache.lucene.core;
requires org.elasticsearch.base;
requires org.elasticsearch.server;
@@ -15,6 +16,7 @@
// required due to dependency on org.elasticsearch.common.util.concurrent.AbstractAsyncTask
requires org.apache.logging.log4j;
requires org.elasticsearch.logging;
requires org.elasticsearch.ml;
requires org.elasticsearch.tdigest;
requires org.elasticsearch.geo;
requires hppc;
@@ -0,0 +1,13 @@
categorize
required_capability: categorize

FROM sample_data
| STATS count=COUNT(), values=VALUES(message) BY category=CATEGORIZE(message)
| SORT count DESC, category ASC
;

count:long | values:keyword | category:integer
3 | [Connected to 10.1.0.1, Connected to 10.1.0.2, Connected to 10.1.0.3] | 0
3 | [Connection error] | 1
1 | [Disconnected] | 2
;

Some generated files are not rendered by default.

@@ -33,6 +33,7 @@
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import org.elasticsearch.xpack.esql.expression.function.aggregate.WeightedAvg;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least;
@@ -383,7 +384,10 @@ private FunctionDefinition[][] functions() {
}

private static FunctionDefinition[][] snapshotFunctions() {
return new FunctionDefinition[][] { new FunctionDefinition[] { def(Rate.class, Rate::withUnresolvedTimestamp, "rate") } };
return new FunctionDefinition[][] {
new FunctionDefinition[] {
def(Categorize.class, Categorize::new, "categorize"),
def(Rate.class, Rate::withUnresolvedTimestamp, "rate") } };
}

public EsqlFunctionRegistry snapshotRegistry() {
@@ -0,0 +1,159 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function.grouping;

import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.core.WhitespaceTokenizer;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.BytesRefHash;
import org.elasticsearch.compute.ann.Evaluator;
import org.elasticsearch.compute.ann.Fixed;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.index.analysis.CharFilterFactory;
import org.elasticsearch.index.analysis.CustomAnalyzer;
import org.elasticsearch.index.analysis.TokenFilterFactory;
import org.elasticsearch.index.analysis.TokenizerFactory;
import org.elasticsearch.xpack.esql.capabilities.Validatable;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;

import java.io.IOException;
import java.util.List;
import java.util.function.Function;

import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString;

/**
* Categorizes text messages.
*
* This implementation is incomplete and comes with the following caveats:
* - it only works correctly on a single node;
* - when running on multiple nodes, category IDs from different nodes are
* aggregated even though the same ID can correspond to a totally different
* category (e.g., ID 0 could mean "Connected to ..." on one node and
* "Connection error" on another);
* - the output consists of category IDs, which should be replaced by category
* regexes or keys.
*
* TODO(jan, nik): fix this
*/
public class Categorize extends GroupingFunction implements Validatable {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
Expression.class,
"Categorize",
Categorize::new
);

private final Expression field;

@FunctionInfo(returnType = { "integer" }, description = "Categorizes text messages")
public Categorize(
Source source,
@Param(name = "field", type = { "text", "keyword" }, description = "Expression to categorize") Expression field
) {
super(source, List.of(field));
this.field = field;
}

private Categorize(StreamInput in) throws IOException {
this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class));
}

@Override
public void writeTo(StreamOutput out) throws IOException {
source().writeTo(out);
out.writeNamedWriteable(field);
}

@Override
public String getWriteableName() {
return ENTRY.name;
}

@Override
public boolean foldable() {
return field.foldable();
}

@Evaluator
static int process(
BytesRef v,
@Fixed(includeInToString = false, build = true) CategorizationAnalyzer analyzer,
@Fixed(includeInToString = false, build = true) TokenListCategorizer.CloseableTokenListCategorizer categorizer
) {
String s = v.utf8ToString();
try (TokenStream ts = analyzer.tokenStream("text", s)) {
Member: It's super expensive to do all this. But such is life.

I spent a little while looking and am pretty sure there's a nice way to make a Reader that works on the BytesRef directly, so you don't need to make a String here. I couldn't find anything easy to just plug in, so I think it can wait.

Contributor Author: Thanks for pointing that out. I'll add it to the to-do list.
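For illustration, a minimal sketch of the suggested approach, assuming the analyzer exposes Lucene's Reader-based tokenStream overload; only standard JDK classes (java.io, java.nio.charset) are used, and none of this is part of the PR:

// Hypothetical sketch: decode the BytesRef's UTF-8 bytes through a Reader
// instead of materializing a String.
Reader reader = new InputStreamReader(
    new ByteArrayInputStream(v.bytes, v.offset, v.length),
    StandardCharsets.UTF_8
);
try (TokenStream ts = analyzer.tokenStream("text", reader)) {
    // Caveat: computeCategory() also wants the text length in chars, which
    // is only known after decoding, so this is not a drop-in replacement.
}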

return categorizer.computeCategory(ts, s.length(), 1).getId();
} catch (IOException e) {
throw new RuntimeException(e);
}
Member: We'd talked about allowing IOException from a lot of things in ESQL - we do read from files and such. But we never got around to it, and I'm not sure we'd do it for process anyway. And this one isn't real anyway - at least, sort of not. It's thrown if the token stream fails to do token stream things, but that's not real IO. Bleh. This is totally fine.
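(For what it's worth, a hedged alternative if one wanted to avoid the bare RuntimeException is the JDK's java.io.UncheckedIOException; this is not what the PR does:)

catch (IOException e) {
    // UncheckedIOException is unchecked but preserves the IOException as its cause.
    throw new UncheckedIOException(e);
}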

}

@Override
public ExpressionEvaluator.Factory toEvaluator(Function<Expression, ExpressionEvaluator.Factory> toEvaluator) {
return new CategorizeEvaluator.Factory(
source(),
toEvaluator.apply(field),
context -> new CategorizationAnalyzer(
// TODO(jan): get the correct analyzer in here, see CategorizationAnalyzerConfig::buildStandardCategorizationAnalyzer
new CustomAnalyzer(
TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
new CharFilterFactory[0],
new TokenFilterFactory[0]
),
true
),
context -> new TokenListCategorizer.CloseableTokenListCategorizer(
new CategorizationBytesRefHash(new BytesRefHash(2048, context.bigArrays())),
CategorizationPartOfSpeechDictionary.getInstance(),
0.70f
)
);
}

@Override
protected TypeResolution resolveType() {
return isString(field(), sourceText(), DEFAULT);
}

@Override
public DataType dataType() {
return DataType.INTEGER;
}

@Override
public Expression replaceChildren(List<Expression> newChildren) {
return new Categorize(source(), newChildren.get(0));
}

@Override
protected NodeInfo<? extends Expression> info() {
return NodeInfo.create(this, Categorize::new, field);
}

public Expression field() {
return field;
}

@Override
public String toString() {
return "Categorize{field=" + field + "}";
}
}
@@ -16,6 +16,7 @@
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least;
@@ -73,6 +74,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
entries.add(Atan2.ENTRY);
entries.add(Bucket.ENTRY);
entries.add(Case.ENTRY);
entries.add(Categorize.ENTRY);
entries.add(CIDRMatch.ENTRY);
entries.add(Coalesce.ENTRY);
entries.add(Concat.ENTRY);