Skip to content

Commit

Permalink
Refactor InlineStats to prefer composition
Browse files Browse the repository at this point in the history
Use a nested Aggregate to simplify InlineStats definition and resolution

Relates elastic#112266
  • Loading branch information
costin committed Aug 28, 2024
1 parent fff5c8f commit 076425d
Show file tree
Hide file tree
Showing 13 changed files with 92 additions and 175 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer;
import org.elasticsearch.xpack.esql.plan.TableIdentifier;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.Drop;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
Expand All @@ -78,7 +79,6 @@
import org.elasticsearch.xpack.esql.plan.logical.MvExpand;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.Rename;
import org.elasticsearch.xpack.esql.plan.logical.Stats;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject;
import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation;
Expand Down Expand Up @@ -117,7 +117,6 @@
import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD;
import static org.elasticsearch.xpack.esql.core.type.DataType.LONG;
import static org.elasticsearch.xpack.esql.core.type.DataType.TEXT;
import static org.elasticsearch.xpack.esql.core.type.DataType.UNSIGNED_LONG;
import static org.elasticsearch.xpack.esql.core.type.DataType.VERSION;
import static org.elasticsearch.xpack.esql.core.type.DataType.isTemporalAmount;
import static org.elasticsearch.xpack.esql.stats.FeatureMetric.LIMIT;
Expand Down Expand Up @@ -406,8 +405,8 @@ protected LogicalPlan doRule(LogicalPlan plan) {
childrenOutput.addAll(output);
}

if (plan instanceof Stats stats) {
return resolveStats(stats, childrenOutput);
if (plan instanceof Aggregate aggregate) {
return resolveAggregate(aggregate, childrenOutput);
}

if (plan instanceof Drop d) {
Expand Down Expand Up @@ -441,12 +440,12 @@ protected LogicalPlan doRule(LogicalPlan plan) {
return plan.transformExpressionsOnly(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
}

private LogicalPlan resolveStats(Stats stats, List<Attribute> childrenOutput) {
private Aggregate resolveAggregate(Aggregate aggregate, List<Attribute> childrenOutput) {
// if the grouping is resolved but the aggs are not, use the former to resolve the latter
// e.g. STATS a ... GROUP BY a = x + 1
Holder<Boolean> changed = new Holder<>(false);
List<Expression> groupings = stats.groupings();
List<? extends NamedExpression> aggregates = stats.aggregates();
List<Expression> groupings = aggregate.groupings();
List<? extends NamedExpression> aggregates = aggregate.aggregates();
// first resolve groupings since the aggs might refer to them
// trying to globally resolve unresolved attributes will lead to some being marked as unresolvable
if (Resolvables.resolved(groupings) == false) {
Expand All @@ -460,7 +459,7 @@ private LogicalPlan resolveStats(Stats stats, List<Attribute> childrenOutput) {
}
groupings = newGroupings;
if (changed.get()) {
stats = stats.with(stats.child(), newGroupings, stats.aggregates());
aggregate = aggregate.with(aggregate.child(), newGroupings, aggregate.aggregates());
changed.set(false);
}
}
Expand All @@ -476,8 +475,8 @@ private LogicalPlan resolveStats(Stats stats, List<Attribute> childrenOutput) {
List<Attribute> resolvedList = NamedExpressions.mergeOutputAttributes(resolved, childrenOutput);

List<NamedExpression> newAggregates = new ArrayList<>();
for (NamedExpression aggregate : stats.aggregates()) {
var agg = (NamedExpression) aggregate.transformUp(UnresolvedAttribute.class, ua -> {
for (NamedExpression ag : aggregate.aggregates()) {
var agg = (NamedExpression) ag.transformUp(UnresolvedAttribute.class, ua -> {
Expression ne = ua;
Attribute maybeResolved = maybeResolveAttribute(ua, resolvedList);
if (maybeResolved != null) {
Expand All @@ -489,10 +488,10 @@ private LogicalPlan resolveStats(Stats stats, List<Attribute> childrenOutput) {
newAggregates.add(agg);
}

stats = changed.get() ? stats.with(stats.child(), groupings, newAggregates) : stats;
aggregate = changed.get() ? aggregate.with(aggregate.child(), groupings, newAggregates) : aggregate;
}

return (LogicalPlan) stats;
return aggregate;
}

private LogicalPlan resolveMvExpand(MvExpand p, List<Attribute> childrenOutput) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@
import org.elasticsearch.xpack.esql.optimizer.rules.PushDownEval;
import org.elasticsearch.xpack.esql.optimizer.rules.PushDownRegexExtract;
import org.elasticsearch.xpack.esql.optimizer.rules.RemoveStatsOverride;
import org.elasticsearch.xpack.esql.optimizer.rules.ReplaceAggregateAggExpressionWithEval;
import org.elasticsearch.xpack.esql.optimizer.rules.ReplaceAggregateNestedExpressionWithEval;
import org.elasticsearch.xpack.esql.optimizer.rules.ReplaceAliasingEvalWithProject;
import org.elasticsearch.xpack.esql.optimizer.rules.ReplaceLimitAndSortAsTopN;
import org.elasticsearch.xpack.esql.optimizer.rules.ReplaceLookupWithJoin;
import org.elasticsearch.xpack.esql.optimizer.rules.ReplaceOrderByExpressionWithEval;
import org.elasticsearch.xpack.esql.optimizer.rules.ReplaceRegexMatch;
import org.elasticsearch.xpack.esql.optimizer.rules.ReplaceStatsAggExpressionWithEval;
import org.elasticsearch.xpack.esql.optimizer.rules.ReplaceStatsNestedExpressionWithEval;
import org.elasticsearch.xpack.esql.optimizer.rules.ReplaceTrivialTypeConversions;
import org.elasticsearch.xpack.esql.optimizer.rules.SetAsOptimized;
import org.elasticsearch.xpack.esql.optimizer.rules.SimplifyComparisonsArithmetics;
Expand Down Expand Up @@ -171,14 +171,14 @@ protected static Batch<LogicalPlan> substitutions() {
new ReplaceLookupWithJoin(),
new RemoveStatsOverride(),
// first extract nested expressions inside aggs
new ReplaceStatsNestedExpressionWithEval(),
new ReplaceAggregateNestedExpressionWithEval(),
// then extract nested aggs top-level
new ReplaceStatsAggExpressionWithEval(),
new ReplaceAggregateAggExpressionWithEval(),
// lastly replace surrogate functions
new SubstituteSurrogates(),
// translate metric aggregates after surrogate substitution and replace nested expressions with eval (again)
new TranslateMetricsAggregate(),
new ReplaceStatsNestedExpressionWithEval(),
new ReplaceAggregateNestedExpressionWithEval(),
new ReplaceRegexMatch(),
new ReplaceTrivialTypeConversions(),
new ReplaceAliasingEvalWithProject(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ public CombineProjections() {
}

@Override
@SuppressWarnings("unchecked")
protected LogicalPlan rule(UnaryPlan plan) {
LogicalPlan child = plan.child();

Expand Down Expand Up @@ -67,7 +66,7 @@ protected LogicalPlan rule(UnaryPlan plan) {
if (grouping instanceof Attribute attribute) {
groupingAttrs.add(attribute);
} else {
// After applying ReplaceStatsNestedExpressionWithEval, groupings can only contain attributes.
// After applying ReplaceAggregateNestedExpressionWithEval, groupings can only contain attributes.
throw new EsqlIllegalArgumentException("Expected an Attribute, got {}", grouping);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,16 @@
package org.elasticsearch.xpack.esql.optimizer.rules;

import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.esql.analysis.AnalyzerRules;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.Stats;

import java.util.ArrayList;
import java.util.List;

/**
* Removes {@link Stats} overrides in grouping, aggregates and across them inside.
* Removes {@link Aggregate} overrides in grouping, aggregates and across them inside.
* The overrides appear when the same alias is used multiple times in aggregations
* and/or groupings:
* {@code STATS x = COUNT(*), x = MIN(a) BY x = b + 1, x = c + 10}
Expand All @@ -34,26 +33,11 @@
* becomes
* {@code STATS max($x + 1) BY $x = a + b}
*/
public final class RemoveStatsOverride extends AnalyzerRules.AnalyzerRule<LogicalPlan> {
public final class RemoveStatsOverride extends OptimizerRules.OptimizerRule<Aggregate> {

@Override
protected boolean skipResolved() {
return false;
}

@Override
protected LogicalPlan rule(LogicalPlan p) {
if (p.resolved() == false) {
return p;
}
if (p instanceof Stats stats) {
return (LogicalPlan) stats.with(
stats.child(),
removeDuplicateNames(stats.groupings()),
removeDuplicateNames(stats.aggregates())
);
}
return p;
protected LogicalPlan rule(Aggregate aggregate) {
return aggregate.with(removeDuplicateNames(aggregate.groupings()), removeDuplicateNames(aggregate.aggregates()));
}

private static <T extends Expression> List<T> removeDuplicateNames(List<T> list) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@
* becomes
* stats a = min(x), c = count(*) by g | eval b = a, d = c | keep a, b, c, d, g
*/
public final class ReplaceStatsAggExpressionWithEval extends OptimizerRules.OptimizerRule<Aggregate> {
public ReplaceStatsAggExpressionWithEval() {
public final class ReplaceAggregateAggExpressionWithEval extends OptimizerRules.OptimizerRule<Aggregate> {
public ReplaceAggregateAggExpressionWithEval() {
super(OptimizerRules.TransformDirection.UP);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.Stats;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
* Replace nested expressions inside a {@link Stats} with synthetic eval.
* Replace nested expressions inside a {@link Aggregate} with synthetic eval.
* {@code STATS SUM(a + 1) BY x % 2}
* becomes
* {@code EVAL `a + 1` = a + 1, `x % 2` = x % 2 | STATS SUM(`a+1`_ref) BY `x % 2`_ref}
Expand All @@ -34,17 +34,10 @@
* becomes
* {@code EVAL `a + 1` = a + 1, `x % 2` = x % 2 | INLINESTATS SUM(`a+1`_ref) BY `x % 2`_ref}
*/
public final class ReplaceStatsNestedExpressionWithEval extends OptimizerRules.OptimizerRule<LogicalPlan> {
public final class ReplaceAggregateNestedExpressionWithEval extends OptimizerRules.OptimizerRule<Aggregate> {

@Override
protected LogicalPlan rule(LogicalPlan p) {
if (p instanceof Stats stats) {
return rule(stats);
}
return p;
}

private LogicalPlan rule(Stats aggregate) {
protected LogicalPlan rule(Aggregate aggregate) {
List<Alias> evals = new ArrayList<>();
Map<String, Attribute> evalNames = new HashMap<>();
Map<GroupingFunction, Attribute> groupingAttributes = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,19 @@ public PlanFactory visitStatsCommand(EsqlBaseParser.StatsCommandContext ctx) {
return input -> new Aggregate(source(ctx), input, Aggregate.AggregateType.STANDARD, stats.groupings, stats.aggregates);
}

@Override
public PlanFactory visitInlinestatsCommand(EsqlBaseParser.InlinestatsCommandContext ctx) {
if (false == EsqlPlugin.INLINESTATS_FEATURE_FLAG.isEnabled()) {
throw new ParsingException(source(ctx), "INLINESTATS command currently requires a snapshot build");
}

final Stats stats = stats(source(ctx), ctx.grouping, ctx.stats);
return input -> new InlineStats(
source(ctx),
new Aggregate(source(ctx), input, Aggregate.AggregateType.STANDARD, stats.groupings, stats.aggregates)
);
}

private record Stats(List<Expression> groupings, List<? extends NamedExpression> aggregates) {

}
Expand Down Expand Up @@ -336,17 +349,6 @@ private void fail(Expression exp, String message, Object... args) {
throw new VerificationException(Collections.singletonList(Failure.fail(exp, message, args)));
}

@Override
public PlanFactory visitInlinestatsCommand(EsqlBaseParser.InlinestatsCommandContext ctx) {
if (false == EsqlPlugin.INLINESTATS_FEATURE_FLAG.isEnabled()) {
throw new ParsingException(source(ctx), "INLINESTATS command currently requires a snapshot build");
}
List<NamedExpression> aggregates = new ArrayList<>(visitFields(ctx.stats));
List<NamedExpression> groupings = visitGrouping(ctx.grouping);
aggregates.addAll(groupings);
return input -> new InlineStats(source(ctx), input, new ArrayList<>(groupings), aggregates);
}

@Override
public PlanFactory visitWhereCommand(EsqlBaseParser.WhereCommandContext ctx) {
Expression expression = expression(ctx.booleanExpression());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import static java.util.Collections.emptyList;
import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;

public class Aggregate extends UnaryPlan implements Stats {
public class Aggregate extends UnaryPlan {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
LogicalPlan.class,
"Aggregate",
Expand Down Expand Up @@ -107,7 +107,10 @@ public Aggregate replaceChild(LogicalPlan newChild) {
return new Aggregate(source(), newChild, aggregateType, groupings, aggregates);
}

@Override
public Aggregate with(List<Expression> newGroupings, List<? extends NamedExpression> newAggregates) {
return with(child(), newGroupings, newAggregates);
}

public Aggregate with(LogicalPlan child, List<Expression> newGroupings, List<? extends NamedExpression> newAggregates) {
return new Aggregate(source(), child, aggregateType(), newGroupings, newAggregates);
}
Expand Down
Loading

0 comments on commit 076425d

Please sign in to comment.