From c70656907d7d3c021b7c86c0bbe073dfb23c6ffb Mon Sep 17 00:00:00 2001 From: Costin Leau Date: Sat, 21 Sep 2024 20:05:45 -0700 Subject: [PATCH] Refactor design by moving filter away from the Aggregate plan node into the aggregate function. This helps keep the filter embedded in the agg which signals to existing rules that these aggs shouldn't be optimized or confused with those that do not define a filter. --- .../org/elasticsearch/TransportVersions.java | 1 - .../xpack/esql/analysis/Analyzer.java | 23 +--- .../xpack/esql/analysis/Verifier.java | 6 + .../function/aggregate/AggregateFunction.java | 2 + .../aggregate/FilteredAggregation.java | 105 ++++++++++++++++++ .../aggregate/FilteredExpression.java | 91 +++++++++++++++ .../ReplaceStatsAggExpressionWithEval.java | 24 +++- .../xpack/esql/parser/ExpressionBuilder.java | 50 +++------ .../xpack/esql/parser/LogicalPlanBuilder.java | 17 ++- .../xpack/esql/plan/logical/Aggregate.java | 43 ++----- .../esql/plan/physical/AggregateExec.java | 39 +------ .../AbstractPhysicalOperationProviders.java | 25 ++--- .../xpack/esql/planner/AggregateMapper.java | 4 + .../elasticsearch/xpack/esql/CsvTests.java | 3 +- 14 files changed, 277 insertions(+), 156 deletions(-) create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/FilteredAggregation.java create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/FilteredExpression.java diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index f6d6d45df07e5..ad50856c556f7 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -208,7 +208,6 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_AGGREGATE_EXEC_TRACKS_INTERMEDIATE_ATTRS = def(8_738_00_0); public static final TransportVersion 
CCS_TELEMETRY_STATS = def(8_739_00_0); public static final TransportVersion GLOBAL_RETENTION_TELEMETRY = def(8_740_00_0); - public static final TransportVersion ESQL_PER_AGG_FILTER = def(8_741_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index 2d4831291aaa0..5cd8705e86c8e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -65,7 +65,6 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In; import org.elasticsearch.xpack.esql.index.EsIndex; import org.elasticsearch.xpack.esql.plan.TableIdentifier; -import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Drop; import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; @@ -489,28 +488,8 @@ private LogicalPlan resolveStats(Stats stats, List childrenOutput) { newAggregates.add(agg); } - Aggregate aggregatePlan = stats instanceof Aggregate aggPlan ? aggPlan : null; // TODO: remove this when Stats interface is removed - List newFilters = new ArrayList<>(); - if (aggregatePlan != null) { - // use same resolution for filters - allowing them to use both input and groupings - for (Expression filter : aggregatePlan.filters()) { - Expression resolvedFilter = filter.transformUp( - UnresolvedAttribute.class, - ua -> maybeResolveAttribute(ua, resolvedList) - ); - if (resolvedFilter != filter) { - changed.set(true); - } - newFilters.add(resolvedFilter); - } - } - - stats = changed.get() - ? aggregatePlan != null - ? 
new Aggregate(stats.source(), stats.child(), aggregatePlan.aggregateType(), groupings, newAggregates, newFilters) - : stats.with(stats.child(), groupings, newAggregates) - : stats; + stats = changed.get() ? stats.with(stats.child(), groupings, newAggregates) : stats; } return (LogicalPlan) stats; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java index 9714d3fce6d9f..bcecd8fbe892f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java @@ -27,6 +27,7 @@ import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; +import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredExpression; import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate; import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg; @@ -301,6 +302,11 @@ private static void checkInvalidNamedExpressionUsage( Set failures, int level ) { + // unwrap filtered expression + if (e instanceof FilteredExpression fe) { + e = fe.delegate(); + // TODO add verification for filter clause + } // found an aggregate, constant or a group, bail out if (e instanceof AggregateFunction af) { af.field().forEachDown(AggregateFunction.class, f -> { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java index f0acac0e9744e..5c6c8b0ae6946 100644 --- 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java @@ -33,6 +33,7 @@ public static List getNamedWriteables() { Avg.ENTRY, Count.ENTRY, CountDistinct.ENTRY, + FilteredAggregation.ENTRY, Max.ENTRY, Median.ENTRY, MedianAbsoluteDeviation.ENTRY, @@ -71,6 +72,7 @@ protected AggregateFunction(StreamInput in) throws IOException { public void writeTo(StreamOutput out) throws IOException { Source.EMPTY.writeTo(out); out.writeNamedWriteable(field); + // FIXME: the arguments need to be serialized as well } public Expression field() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/FilteredAggregation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/FilteredAggregation.java new file mode 100644 index 0000000000000..4af954cd77d87 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/FilteredAggregation.java @@ -0,0 +1,105 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.FilteredAggregatorFunctionSupplier; +import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.planner.ToAggregator; + +import java.io.IOException; +import java.util.List; + +import static java.util.Collections.singletonList; + +/** + * Basic wrapper for aggregate functions declared with a nested filter (typically in stats). + * The optimizer uses this class to replace the base aggregations from {@code FilteredExpression} with this + * specialized class which encapsulates the underlying aggregation supplier. 
+ */ +public class FilteredAggregation extends AggregateFunction implements ToAggregator { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "FilteredAggregation", + FilteredAggregation::new + ); + + private final AggregateFunction delegate; + private final Expression filter; + + public FilteredAggregation(Source source, AggregateFunction delegate, Expression filter) { + super(source, delegate, singletonList(filter)); + this.delegate = delegate; + this.filter = filter; + } + + public FilteredAggregation(StreamInput in) throws IOException { + this( + Source.readFrom((PlanStreamInput) in), + (AggregateFunction) in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Expression.class) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + Source.EMPTY.writeTo(out); + out.writeNamedWriteable(delegate); + out.writeNamedWriteable(filter); + } + + @Override + public DataType dataType() { + return delegate.dataType(); + } + + public AggregateFunction delegate() { + return delegate; + } + + public Expression filter() { + return filter; + } + + @Override + public FilteredAggregation replaceChildren(List newChildren) { + return new FilteredAggregation(source(), (AggregateFunction) newChildren.get(0), newChildren.get(1)); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, FilteredAggregation::new, delegate, filter); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + public FilteredAggregation with(Expression filter) { + return new FilteredAggregation(source(), delegate, filter); + } + + @Override + public AggregatorFunctionSupplier supplier(List inputChannels) { + if (delegate instanceof ToAggregator toAggregator) { + AggregatorFunctionSupplier aggSupplier = toAggregator.supplier(inputChannels); + return new FilteredAggregatorFunctionSupplier(aggSupplier, null); + } else { + throw new 
EsqlIllegalArgumentException("Cannot create aggregator for " + delegate); + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/FilteredExpression.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/FilteredExpression.java new file mode 100644 index 0000000000000..ab3c251ebc276 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/FilteredExpression.java @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Nullability; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; + +import java.io.IOException; +import java.util.List; + +import static java.util.Arrays.asList; + +/** + * Basic wrapper for expressions declared with a nested filter (typically in stats). 
+ */ +public class FilteredExpression extends Expression { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "FilteredExpression", + FilteredExpression::new + ); + + private final Expression delegate; + private final Expression filter; + + public FilteredExpression(Source source, Expression delegate, Expression filter) { + super(source, asList(delegate, filter)); + this.delegate = delegate; + this.filter = filter; + } + + public FilteredExpression(StreamInput in) throws IOException { + this( + Source.readFrom((PlanStreamInput) in), + in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Expression.class) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + Source.EMPTY.writeTo(out); + out.writeNamedWriteable(delegate); + out.writeNamedWriteable(filter); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + public Expression delegate() { + return delegate; + } + + public Expression filter() { + return filter; + } + + @Override + public DataType dataType() { + return delegate.dataType(); + } + + @Override + public Nullability nullable() { + return delegate.nullable(); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, FilteredExpression::new, delegate, filter); + } + + @Override + public Expression replaceChildren(List newChildren) { + return new FilteredExpression(source(), newChildren.get(0), newChildren.get(1)); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStatsAggExpressionWithEval.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStatsAggExpressionWithEval.java index d74811518624a..0ba8fc75ee821 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStatsAggExpressionWithEval.java +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStatsAggExpressionWithEval.java @@ -16,6 +16,8 @@ import org.elasticsearch.xpack.esql.core.util.CollectionUtils; import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; +import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredAggregation; +import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredExpression; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Eval; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; @@ -69,11 +71,25 @@ protected LogicalPlan rule(Aggregate aggregate) { Holder changed = new Holder<>(false); int[] counter = new int[] { 0 }; - for (NamedExpression agg : aggs) { + for (int i = 0, s = aggs.size(); i < s; i++) { + NamedExpression agg = aggs.get(i); + if (agg instanceof Alias as) { - // if the child a nested expression - Expression child = as.child(); + // use intermediate variable to mark child as final for lambda use + Expression unwrapped = as.child(); + + // unwrap any filtered expression by pushing down the filter into the aggregate function; the rest is returned as is + if (unwrapped instanceof FilteredExpression fe) { + changed.set(true); + Expression filter = fe.filter(); + unwrapped = fe.delegate() + .transformUp(AggregateFunction.class, af -> new FilteredAggregation(af.source(), af, filter)); + as = as.replaceChild(unwrapped); + } + // + final Expression child = unwrapped; + // // common case - handle duplicates if (child instanceof AggregateFunction af) { AggregateFunction canonical = (AggregateFunction) af.canonical(); @@ -130,7 +146,7 @@ protected LogicalPlan rule(Aggregate aggregate) { LogicalPlan plan = aggregate; if (changed.get()) { Source source = aggregate.source(); - plan = new Aggregate(source, aggregate.child(), aggregate.aggregateType(), 
aggregate.groupings(), newAggs); + plan = aggregate.with(aggregate.child(), aggregate.groupings(), newAggs); if (newEvals.size() > 0) { plan = new Eval(source, plan, newEvals); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java index ae9da35a2384c..093cd9e81b76a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java @@ -26,6 +26,7 @@ import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute; import org.elasticsearch.xpack.esql.core.expression.UnresolvedStar; +import org.elasticsearch.xpack.esql.core.expression.function.Function; import org.elasticsearch.xpack.esql.core.expression.predicate.fulltext.MatchQueryPredicate; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not; @@ -44,6 +45,7 @@ import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.expression.function.FunctionResolutionStrategy; import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction; +import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredExpression; import org.elasticsearch.xpack.esql.expression.function.scalar.string.RLike; import org.elasticsearch.xpack.esql.expression.function.scalar.string.WildcardLike; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; @@ -673,56 +675,32 @@ public List visitFields(EsqlBaseParser.FieldsContext ctx) { } @Override - public AggField visitAggField(EsqlBaseParser.AggFieldContext ctx) { + public NamedExpression visitAggField(EsqlBaseParser.AggFieldContext ctx) { + Source 
source = source(ctx); Alias field = visitField(ctx.field()); var filterExpression = ctx.booleanExpression(); Expression condition = filterExpression != null ? expression(filterExpression) : null; if (condition != null) { + Expression child = field.child(); + // basic check as the filter can be specified only on a function (should be an aggregate but we can't determine that yet) + if (field.child().anyMatch(Function.class::isInstance)) { + field = field.replaceChild(new FilteredExpression(source, child, condition)); + } // allow condition only per aggregated function - if (field.child() instanceof UnresolvedFunction uf == false) { - Source source = source(filterExpression); + else { + source = source(filterExpression); throw new ParsingException(source, "WHERE clause allowed only for aggregate functions [{}]", source.text()); } } - return new AggField(field, condition); + return field; } @Override - public AggFields visitAggFields(EsqlBaseParser.AggFieldsContext ctx) { - if (ctx == null) { - return new AggFields(emptyList(), emptyList()); - } - - List aggFields = visitList(this, ctx.aggField(), AggField.class); - - List aggregates = new ArrayList<>(aggFields.size()); - List conditions = new ArrayList<>(aggFields.size()); - - boolean noConditionDeclared = true; - for (AggField aggField : aggFields) { - aggregates.add(aggField.alias()); - Expression condition = aggField.condition(); - if (condition != null) { - noConditionDeclared = false; - } else { - // out of band value to signal no condition was specified - condition = Literal.NULL; - } - conditions.add(condition); - } - - if (noConditionDeclared) { - conditions = emptyList(); - } - - return new AggFields(aggregates, conditions); + public List visitAggFields(EsqlBaseParser.AggFieldsContext ctx) { + return ctx != null ? 
visitList(this, ctx.aggField(), Alias.class) : new ArrayList<>(); } - record AggFields(List aggregates, List conditions) {} - - private record AggField(Alias alias, Expression condition) {} - /** * Similar to {@link #visitFields(EsqlBaseParser.FieldsContext)} however avoids wrapping the expression * into an Alias. diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java index cfea0f6487b6b..58f7dd1e50598 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java @@ -301,17 +301,15 @@ public PlanFactory visitStatsCommand(EsqlBaseParser.StatsCommandContext ctx) { input, Aggregate.AggregateType.STANDARD, stats.groupings, - stats.aggregates, - stats.conditions + stats.aggregates ); } - private record Stats(List groupings, List aggregates, List conditions) {} + private record Stats(List groupings, List aggregates) {} private Stats stats(Source source, EsqlBaseParser.FieldsContext groupingsCtx, EsqlBaseParser.AggFieldsContext aggregatesCtx) { List groupings = visitGrouping(groupingsCtx); - AggFields aggFields = visitAggFields(aggregatesCtx); - List aggregates = aggFields.aggregates(); + List aggregates = new ArrayList<>(visitAggFields(aggregatesCtx)); if (aggregates.isEmpty() && groupings.isEmpty()) { throw new ParsingException(source, "At least one aggregation or grouping expression required in [{}]", source.text()); @@ -333,12 +331,11 @@ private Stats stats(Source source, EsqlBaseParser.FieldsContext groupingsCtx, Es } } } - List conditions = aggFields.conditions(); // since groupings are aliased, add refs to it in the aggregates for (Expression group : groupings) { aggregates.add(Expressions.attribute(group)); } - return new Stats(new ArrayList<>(groupings), aggregates, conditions); 
+ return new Stats(new ArrayList<>(groupings), aggregates); } private void fail(Expression exp, String message, Object... args) { @@ -350,11 +347,11 @@ public PlanFactory visitInlinestatsCommand(EsqlBaseParser.InlinestatsCommandCont if (false == EsqlPlugin.INLINESTATS_FEATURE_FLAG.isEnabled()) { throw new ParsingException(source(ctx), "INLINESTATS command currently requires a snapshot build"); } - AggFields aggFields = visitAggFields(ctx.stats); - List aggregates = aggFields.aggregates(); + List aggFields = visitAggFields(ctx.stats); + List aggregates = new ArrayList<>(aggFields); List groupings = visitGrouping(ctx.grouping); aggregates.addAll(groupings); - // TODO: add support for conditions + // TODO: add support for filters return input -> new InlineStats(source(ctx), input, new ArrayList<>(groupings), aggregates); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Aggregate.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Aggregate.java index cc8e53c2d547a..08232e12d8456 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Aggregate.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Aggregate.java @@ -59,7 +59,6 @@ static AggregateType readType(StreamInput in) throws IOException { private final AggregateType aggregateType; private final List groupings; private final List aggregates; - private final List filters; private List lazyOutput; @@ -69,23 +68,11 @@ public Aggregate( AggregateType aggregateType, List groupings, List aggregates - ) { - this(source, child, aggregateType, groupings, aggregates, emptyList()); - } - - public Aggregate( - Source source, - LogicalPlan child, - AggregateType aggregateType, - List groupings, - List aggregates, - List filters ) { super(source, child); this.aggregateType = aggregateType; this.groupings = groupings; this.aggregates = aggregates; - this.filters = filters; } public 
Aggregate(StreamInput in) throws IOException { @@ -94,10 +81,7 @@ public Aggregate(StreamInput in) throws IOException { in.readNamedWriteable(LogicalPlan.class), AggregateType.readType(in), in.readNamedWriteableCollectionAsList(Expression.class), - in.readNamedWriteableCollectionAsList(NamedExpression.class), - in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGG_FILTER) - ? in.readNamedWriteableCollectionAsList(Expression.class) - : emptyList() + in.readNamedWriteableCollectionAsList(NamedExpression.class) ); } @@ -108,9 +92,6 @@ public void writeTo(StreamOutput out) throws IOException { AggregateType.writeType(out, aggregateType()); out.writeNamedWriteableCollection(groupings); out.writeNamedWriteableCollection(aggregates()); - if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGG_FILTER)) { - out.writeNamedWriteableCollection(filters); - } } @Override @@ -120,17 +101,17 @@ public String getWriteableName() { @Override protected NodeInfo info() { - return NodeInfo.create(this, Aggregate::new, child(), aggregateType, groupings, aggregates, filters); + return NodeInfo.create(this, Aggregate::new, child(), aggregateType, groupings, aggregates); } @Override public Aggregate replaceChild(LogicalPlan newChild) { - return new Aggregate(source(), newChild, aggregateType, groupings, aggregates, filters); + return new Aggregate(source(), newChild, aggregateType, groupings, aggregates); } @Override public Aggregate with(LogicalPlan child, List newGroupings, List newAggregates) { - return new Aggregate(source(), child, aggregateType(), newGroupings, newAggregates, filters); + return new Aggregate(source(), child, aggregateType(), newGroupings, newAggregates); } public AggregateType aggregateType() { @@ -145,10 +126,6 @@ public List aggregates() { return aggregates; } - public List filters() { - return filters; - } - @Override public String commandName() { return switch (aggregateType) { @@ -159,7 +136,7 @@ public String commandName() { @Override 
public boolean expressionsResolved() { - return Resolvables.resolved(groupings) && Resolvables.resolved(aggregates) && Resolvables.resolved(filters); + return Resolvables.resolved(groupings) && Resolvables.resolved(aggregates); } @Override @@ -176,20 +153,19 @@ public static List output(List aggregates) @Override protected AttributeSet computeReferences() { - return computeReferences(aggregates, groupings, filters); + return computeReferences(aggregates, groupings); } public static AttributeSet computeReferences( List aggregates, - List groupings, - List filters + List groupings ) { - return Expressions.references(groupings).combine(Expressions.references(aggregates).combine(Expressions.references(filters))); + return Expressions.references(groupings).combine(Expressions.references(aggregates)); } @Override public int hashCode() { - return Objects.hash(aggregateType, groupings, aggregates, filters, child()); + return Objects.hash(aggregateType, groupings, aggregates, child()); } @Override @@ -206,7 +182,6 @@ public boolean equals(Object obj) { return aggregateType == other.aggregateType && Objects.equals(groupings, other.groupings) && Objects.equals(aggregates, other.aggregates) - && Objects.equals(filters, other.filters) && Objects.equals(child(), other.child()); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/AggregateExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/AggregateExec.java index 71d72b425ca28..87051fae3da13 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/AggregateExec.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/AggregateExec.java @@ -26,8 +26,6 @@ import java.util.List; import java.util.Objects; -import static java.util.Collections.emptyList; - public class AggregateExec extends UnaryExec implements EstimatesRowSize { public static final NamedWriteableRegistry.Entry ENTRY = new 
NamedWriteableRegistry.Entry( PhysicalPlan.class, @@ -37,7 +35,6 @@ public class AggregateExec extends UnaryExec implements EstimatesRowSize { private final List groupings; private final List aggregates; - private final List filters; /** * The output attributes of {@link AggregatorMode#INITIAL} and {@link AggregatorMode#INTERMEDIATE} aggregations, resp. * the input attributes of {@link AggregatorMode#FINAL} and {@link AggregatorMode#INTERMEDIATE} aggregations. @@ -60,24 +57,10 @@ public AggregateExec( AggregatorMode mode, List intermediateAttributes, Integer estimatedRowSize - ) { - this(source, child, groupings, aggregates, emptyList(), mode, intermediateAttributes, estimatedRowSize); - } - - public AggregateExec( - Source source, - PhysicalPlan child, - List groupings, - List aggregates, - List filters, - AggregatorMode mode, - List intermediateAttributes, - Integer estimatedRowSize ) { super(source, child); this.groupings = groupings; this.aggregates = aggregates; - this.filters = filters; this.mode = mode; this.intermediateAttributes = intermediateAttributes; this.estimatedRowSize = estimatedRowSize; @@ -91,10 +74,6 @@ private AggregateExec(StreamInput in) throws IOException { ((PlanStreamInput) in).readPhysicalPlanNode(), in.readNamedWriteableCollectionAsList(Expression.class), in.readNamedWriteableCollectionAsList(NamedExpression.class), - // filters - in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGG_FILTER) - ? 
in.readNamedWriteableCollectionAsList(Expression.class) - : emptyList(), in.readEnum(AggregatorMode.class), in.readNamedWriteableCollectionAsList(Attribute.class), in.readOptionalVInt() @@ -107,10 +86,6 @@ public void writeTo(StreamOutput out) throws IOException { ((PlanStreamOutput) out).writePhysicalPlanNode(child()); out.writeNamedWriteableCollection(groupings()); out.writeNamedWriteableCollection(aggregates()); - // filters - if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGG_FILTER)) { - out.writeNamedWriteableCollection(filters); - } if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_AGGREGATE_EXEC_TRACKS_INTERMEDIATE_ATTRS)) { out.writeEnum(getMode()); out.writeNamedWriteableCollection(intermediateAttributes()); @@ -133,7 +108,6 @@ protected NodeInfo info() { child(), groupings, aggregates, - filters, mode, intermediateAttributes, estimatedRowSize @@ -142,7 +116,7 @@ protected NodeInfo info() { @Override public AggregateExec replaceChild(PhysicalPlan newChild) { - return new AggregateExec(source(), newChild, groupings, aggregates, filters, mode, intermediateAttributes, estimatedRowSize); + return new AggregateExec(source(), newChild, groupings, aggregates, mode, intermediateAttributes, estimatedRowSize); } public List groupings() { @@ -153,12 +127,8 @@ public List aggregates() { return aggregates; } - public List filters() { - return filters; - } - public AggregateExec withMode(AggregatorMode newMode) { - return new AggregateExec(source(), child(), groupings, aggregates, filters, newMode, intermediateAttributes, estimatedRowSize); + return new AggregateExec(source(), child(), groupings, aggregates, newMode, intermediateAttributes, estimatedRowSize); } /** @@ -175,7 +145,7 @@ public PhysicalPlan estimateRowSize(State state) { int size = state.consumeAllFields(true); return Objects.equals(this.estimatedRowSize, size) ? 
this - : new AggregateExec(source(), child(), groupings, aggregates, filters, mode, intermediateAttributes, size); + : new AggregateExec(source(), child(), groupings, aggregates, mode, intermediateAttributes, size); } public AggregatorMode getMode() { @@ -223,7 +193,7 @@ public List output() { protected AttributeSet computeReferences() { return mode.isInputPartial() ? new AttributeSet(intermediateAttributes) - : Aggregate.computeReferences(aggregates, groupings, filters); + : Aggregate.computeReferences(aggregates, groupings); } @Override @@ -244,7 +214,6 @@ public boolean equals(Object obj) { AggregateExec other = (AggregateExec) obj; return Objects.equals(groupings, other.groupings) && Objects.equals(aggregates, other.aggregates) - && Objects.equals(filters, other.filters) && Objects.equals(mode, other.mode) && Objects.equals(intermediateAttributes, other.intermediateAttributes) && Objects.equals(estimatedRowSize, other.estimatedRowSize) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java index d003227e156f8..44651451eae7b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java @@ -24,12 +24,12 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; -import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.NameId; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.evaluator.EvalMapper; import 
org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; +import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredAggregation; import org.elasticsearch.xpack.esql.plan.physical.AggregateExec; import org.elasticsearch.xpack.esql.plan.physical.ExchangeSourceExec; import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.LocalExecutionPlannerContext; @@ -82,7 +82,6 @@ public final PhysicalOperation groupingPhysicalOperation( // create the agg factories aggregatesToFactory( aggregates, - aggregateExec.filters(), aggregatorMode, sourceLayout, false, // non-grouping @@ -148,7 +147,6 @@ else if (aggregatorMode == AggregatorMode.INITIAL || aggregatorMode == Aggregato // create the agg factories aggregatesToFactory( aggregates, - aggregateExec.filters(), aggregatorMode, sourceLayout, true, // grouping @@ -232,18 +230,14 @@ private record AggFunctionSupplierContext(AggregatorFunctionSupplier supplier, A private void aggregatesToFactory( List aggregates, - List filters, AggregatorMode mode, Layout layout, boolean grouping, Consumer consumer ) { // extract filtering channels - and wrap the aggregation with the new evaluator expression only during the init phase - for (int index = 0, s = aggregates.size(); index < s; index++) { - NamedExpression ne = aggregates.get(index); + for (NamedExpression ne : aggregates) { // the filter is missing for groups - Expression filter = index < filters.size() ? 
filters.get(index) : null; - boolean hasFilter = filter != null && filter != Literal.NULL; if (ne instanceof Alias alias) { var child = alias.child(); @@ -277,8 +271,8 @@ private void aggregatesToFactory( } sourceAttr.add(attr); } - if (hasFilter) { - for (Attribute ref : filter.references()) { + if (aggregateFunction instanceof FilteredAggregation filteredAgg) { + for (Attribute ref : filteredAgg.filter().references()) { sourceAttr.add(ref); } } @@ -300,9 +294,14 @@ else if (mode == AggregatorMode.FINAL || mode == AggregatorMode.INTERMEDIATE) { if (aggregateFunction instanceof ToAggregator agg) { AggregatorFunctionSupplier aggSupplier = agg.supplier(inputChannels); // if a filter is specified, apply wrapping but only on raw data - as the rest of the data is already filtered - if (mode.isInputPartial() == false && hasFilter) { - EvalOperator.ExpressionEvaluator.Factory evalFactory = EvalMapper.toEvaluator(filter, layout); - aggSupplier = new FilteredAggregatorFunctionSupplier(aggSupplier, evalFactory); + // TODO: encapsulate the creation inside the class - requires access to layout + if (mode.isInputPartial() == false && agg instanceof FilteredAggregation filteredAggregate) { + EvalOperator.ExpressionEvaluator.Factory evalFactory = EvalMapper.toEvaluator( + filteredAggregate.filter(), + layout + ); + AggregatorFunctionSupplier delegate = ((FilteredAggregatorFunctionSupplier) aggSupplier).next(); + aggSupplier = new FilteredAggregatorFunctionSupplier(delegate, evalFactory); } consumer.accept(new AggFunctionSupplierContext(aggSupplier, mode)); } else { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java index 5573346b5e29e..80a4737ca4676 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java @@ -25,6 +25,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; import org.elasticsearch.xpack.esql.expression.function.aggregate.CountDistinct; +import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredAggregation; import org.elasticsearch.xpack.esql.expression.function.aggregate.FromPartial; import org.elasticsearch.xpack.esql.expression.function.aggregate.Max; import org.elasticsearch.xpack.esql.expression.function.aggregate.MedianAbsoluteDeviation; @@ -202,6 +203,9 @@ private static Stream groupingAndNonGrouping(Tuple, Tuple * To log the results logResults() should return "true". */ -// @TestLogging(value = "org.elasticsearch.xpack.esql:TRACE,org.elasticsearch.compute:TRACE", reason = "debug") +@TestLogging(value = "org.elasticsearch.xpack.esql:TRACE,org.elasticsearch.compute:TRACE", reason = "debug") public class CsvTests extends ESTestCase { private static final Logger LOGGER = LogManager.getLogger(CsvTests.class);