Skip to content

Commit

Permalink
Refactor design by moving filter away from the Aggregate plan node into
Browse files Browse the repository at this point in the history
 the aggregate function. This helps keep the filter embedded in the agg,
 which signals to existing rules that these aggs shouldn't be optimized or
 confused with those that do not define a filter.
  • Loading branch information
costin committed Sep 22, 2024
1 parent 2785a29 commit c706569
Show file tree
Hide file tree
Showing 14 changed files with 277 additions and 156 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_AGGREGATE_EXEC_TRACKS_INTERMEDIATE_ATTRS = def(8_738_00_0);
public static final TransportVersion CCS_TELEMETRY_STATS = def(8_739_00_0);
public static final TransportVersion GLOBAL_RETENTION_TELEMETRY = def(8_740_00_0);
public static final TransportVersion ESQL_PER_AGG_FILTER = def(8_741_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.plan.TableIdentifier;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.Drop;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
Expand Down Expand Up @@ -489,28 +488,8 @@ private LogicalPlan resolveStats(Stats stats, List<Attribute> childrenOutput) {
newAggregates.add(agg);
}

Aggregate aggregatePlan = stats instanceof Aggregate aggPlan ? aggPlan : null;
// TODO: remove this when Stats interface is removed
List<Expression> newFilters = new ArrayList<>();
if (aggregatePlan != null) {
// use same resolution for filters - allowing them to use both input and groupings
for (Expression filter : aggregatePlan.filters()) {
Expression resolvedFilter = filter.transformUp(
UnresolvedAttribute.class,
ua -> maybeResolveAttribute(ua, resolvedList)
);
if (resolvedFilter != filter) {
changed.set(true);
}
newFilters.add(resolvedFilter);
}
}

stats = changed.get()
? aggregatePlan != null
? new Aggregate(stats.source(), stats.child(), aggregatePlan.aggregateType(), groupings, newAggregates, newFilters)
: stats.with(stats.child(), groupings, newAggregates)
: stats;
stats = changed.get() ? stats.with(stats.child(), groupings, newAggregates) : stats;
}

return (LogicalPlan) stats;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredExpression;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate;
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg;
Expand Down Expand Up @@ -301,6 +302,11 @@ private static void checkInvalidNamedExpressionUsage(
Set<Failure> failures,
int level
) {
// unwrap filtered expression
if (e instanceof FilteredExpression fe) {
e = fe.delegate();
// TODO add verification for filter clause
}
// found an aggregate, constant or a group, bail out
if (e instanceof AggregateFunction af) {
af.field().forEachDown(AggregateFunction.class, f -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
Avg.ENTRY,
Count.ENTRY,
CountDistinct.ENTRY,
FilteredAggregation.ENTRY,
Max.ENTRY,
Median.ENTRY,
MedianAbsoluteDeviation.ENTRY,
Expand Down Expand Up @@ -71,6 +72,7 @@ protected AggregateFunction(StreamInput in) throws IOException {
public void writeTo(StreamOutput out) throws IOException {
Source.EMPTY.writeTo(out);
out.writeNamedWriteable(field);
// FIXME: the arguments need to be serialized as well
}

public Expression field() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function.aggregate;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.FilteredAggregatorFunctionSupplier;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.planner.ToAggregator;

import java.io.IOException;
import java.util.List;

import static java.util.Collections.singletonList;

/**
 * Aggregate function that pairs an underlying aggregation with the filter it was declared with
 * (typically inside stats).
 * <p>
 * The optimizer swaps the plain aggregations found under a {@code FilteredExpression} for instances of this
 * class, which wrap the supplier of the underlying aggregation.
 */
public class FilteredAggregation extends AggregateFunction implements ToAggregator {
    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
        Expression.class,
        "FilteredAggregation",
        FilteredAggregation::new
    );

    /** The aggregation being filtered. */
    private final AggregateFunction delegate;
    /** The filter declared alongside the aggregation. */
    private final Expression filter;

    /**
     * Creates the wrapper; the delegate and a single-element list holding the filter are handed to the
     * superclass constructor.
     *
     * @param source   source of this node
     * @param delegate the aggregation being wrapped
     * @param filter   the filter attached to the aggregation
     */
    public FilteredAggregation(Source source, AggregateFunction delegate, Expression filter) {
        super(source, delegate, singletonList(filter));
        this.delegate = delegate;
        this.filter = filter;
    }

    /**
     * Deserializing constructor: reads the source, then the delegate, then the filter - the same order
     * {@link #writeTo(StreamOutput)} emits them.
     *
     * @param in stream to read from; it is cast to {@code PlanStreamInput} for reading the source
     * @throws IOException if reading from the stream fails
     */
    public FilteredAggregation(StreamInput in) throws IOException {
        this(
            Source.readFrom((PlanStreamInput) in),
            (AggregateFunction) in.readNamedWriteable(Expression.class),
            in.readNamedWriteable(Expression.class)
        );
    }

    /**
     * Serializes {@code Source.EMPTY} followed by the delegate and the filter.
     *
     * @param out stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        Source.EMPTY.writeTo(out);
        out.writeNamedWriteable(delegate);
        out.writeNamedWriteable(filter);
    }

    @Override
    public String getWriteableName() {
        return ENTRY.name;
    }

    /** @return the wrapped aggregation */
    public AggregateFunction delegate() {
        return delegate;
    }

    /** @return the filter attached to the wrapped aggregation */
    public Expression filter() {
        return filter;
    }

    /** @return the data type of the wrapped aggregation */
    @Override
    public DataType dataType() {
        return delegate.dataType();
    }

    /**
     * Returns a copy of this node that keeps the delegate but uses the given filter.
     *
     * @param filter the replacement filter
     * @return a new instance sharing this node's source and delegate
     */
    public FilteredAggregation with(Expression filter) {
        return new FilteredAggregation(source(), delegate, filter);
    }

    /**
     * Rebuilds this node from new children: the first becomes the delegate, the second the filter.
     *
     * @param newChildren replacement children; the first element must be an {@code AggregateFunction}
     * @return a new instance built from the given children
     */
    @Override
    public FilteredAggregation replaceChildren(List<Expression> newChildren) {
        AggregateFunction newDelegate = (AggregateFunction) newChildren.get(0);
        Expression newFilter = newChildren.get(1);
        return new FilteredAggregation(source(), newDelegate, newFilter);
    }

    @Override
    protected NodeInfo<FilteredAggregation> info() {
        return NodeInfo.create(this, FilteredAggregation::new, delegate, filter);
    }

    /**
     * Wraps the delegate's aggregator supplier in a {@code FilteredAggregatorFunctionSupplier}.
     *
     * @param inputChannels channels forwarded untouched to the delegate's supplier
     * @return the wrapping supplier
     * @throws EsqlIllegalArgumentException if the delegate does not implement {@code ToAggregator}
     */
    @Override
    public AggregatorFunctionSupplier supplier(List<Integer> inputChannels) {
        if (!(delegate instanceof ToAggregator toAggregator)) {
            throw new EsqlIllegalArgumentException("Cannot create aggregator for " + delegate);
        }
        AggregatorFunctionSupplier delegateSupplier = toAggregator.supplier(inputChannels);
        // NOTE(review): null is passed as the second argument, so the filter held by this node is not
        // forwarded to the supplier here - confirm whether that is still pending work.
        return new FilteredAggregatorFunctionSupplier(delegateSupplier, null);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function.aggregate;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;

import java.io.IOException;
import java.util.List;

import static java.util.Arrays.asList;

/**
 * Expression node that pairs an arbitrary expression with a filter declared next to it
 * (typically inside stats). Type and nullability are taken from the wrapped expression.
 */
public class FilteredExpression extends Expression {
    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
        Expression.class,
        "FilteredExpression",
        FilteredExpression::new
    );

    /** The expression being filtered. */
    private final Expression delegate;
    /** The filter declared alongside the expression. */
    private final Expression filter;

    /**
     * Creates the wrapper with the delegate and the filter, in that order, as its children.
     *
     * @param source   source of this node
     * @param delegate the expression being wrapped
     * @param filter   the filter attached to the expression
     */
    public FilteredExpression(Source source, Expression delegate, Expression filter) {
        super(source, asList(delegate, filter));
        this.delegate = delegate;
        this.filter = filter;
    }

    /**
     * Deserializing constructor: reads the source, then the delegate, then the filter - the same order
     * {@link #writeTo(StreamOutput)} emits them.
     *
     * @param in stream to read from; it is cast to {@code PlanStreamInput} for reading the source
     * @throws IOException if reading from the stream fails
     */
    public FilteredExpression(StreamInput in) throws IOException {
        this(
            Source.readFrom((PlanStreamInput) in),
            in.readNamedWriteable(Expression.class),
            in.readNamedWriteable(Expression.class)
        );
    }

    /**
     * Serializes {@code Source.EMPTY} followed by the delegate and the filter.
     *
     * @param out stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        Source.EMPTY.writeTo(out);
        out.writeNamedWriteable(delegate);
        out.writeNamedWriteable(filter);
    }

    @Override
    public String getWriteableName() {
        return ENTRY.name;
    }

    /** @return the wrapped expression */
    public Expression delegate() {
        return delegate;
    }

    /** @return the filter attached to the wrapped expression */
    public Expression filter() {
        return filter;
    }

    /** @return the data type of the wrapped expression */
    @Override
    public DataType dataType() {
        return delegate.dataType();
    }

    /** @return the nullability of the wrapped expression */
    @Override
    public Nullability nullable() {
        return delegate.nullable();
    }

    /**
     * Rebuilds this node from new children: the first becomes the delegate, the second the filter.
     *
     * @param newChildren replacement children, delegate first and filter second
     * @return a new instance built from the given children
     */
    @Override
    public Expression replaceChildren(List<Expression> newChildren) {
        Expression newDelegate = newChildren.get(0);
        Expression newFilter = newChildren.get(1);
        return new FilteredExpression(source(), newDelegate, newFilter);
    }

    @Override
    protected NodeInfo<FilteredExpression> info() {
        return NodeInfo.create(this, FilteredExpression::new, delegate, filter);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import org.elasticsearch.xpack.esql.core.util.CollectionUtils;
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredAggregation;
import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredExpression;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
Expand Down Expand Up @@ -69,11 +71,25 @@ protected LogicalPlan rule(Aggregate aggregate) {
Holder<Boolean> changed = new Holder<>(false);
int[] counter = new int[] { 0 };

for (NamedExpression agg : aggs) {
for (int i = 0, s = aggs.size(); i < s; i++) {
NamedExpression agg = aggs.get(i);

if (agg instanceof Alias as) {
// if the child a nested expression
Expression child = as.child();
// use intermediate variable to mark child as final for lambda use
Expression unwrapped = as.child();

// unwrap any filtered expression by pushing down the filter into the aggregate function; the rest is returned as is
if (unwrapped instanceof FilteredExpression fe) {
changed.set(true);
Expression filter = fe.filter();
unwrapped = fe.delegate()
.transformUp(AggregateFunction.class, af -> new FilteredAggregation(af.source(), af, filter));
as = as.replaceChild(unwrapped);
}
//
final Expression child = unwrapped;

//
// common case - handle duplicates
if (child instanceof AggregateFunction af) {
AggregateFunction canonical = (AggregateFunction) af.canonical();
Expand Down Expand Up @@ -130,7 +146,7 @@ protected LogicalPlan rule(Aggregate aggregate) {
LogicalPlan plan = aggregate;
if (changed.get()) {
Source source = aggregate.source();
plan = new Aggregate(source, aggregate.child(), aggregate.aggregateType(), aggregate.groupings(), newAggs);
plan = aggregate.with(aggregate.child(), aggregate.groupings(), newAggs);
if (newEvals.size() > 0) {
plan = new Eval(source, plan, newEvals);
}
Expand Down
Loading

0 comments on commit c706569

Please sign in to comment.