Skip to content

Commit

Permalink
Fixes plus basic tests
Browse files Browse the repository at this point in the history
  • Loading branch information
costin committed Sep 23, 2024
1 parent c706569 commit a7bc7ea
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2294,9 +2294,20 @@ m:integer |a:double |x:integer

statsWithFiltering#[skip:-8.16.0,reason:implemented in 8.16]
from employees
| stats max = max(salary), max_f = max(salary) where salary < 5000, min = min(salary), min_f = min(salary) where salary > 5000
| stats max = max(salary), max_f = max(salary) where salary < 50000, max_a = max(salary) where salary > 100,
min = min(salary), min_f = min(salary) where salary > 50000, min_a = min(salary) where salary > 100
;

max:integer |max_f:integer | min:integer | min_f:integer
74999 |50000 | 50000 | 50000
max:integer |max_f:integer |max_a:integer | min:integer | min_f:integer | min_a:integer
74999 |49818 |74999 | 25324 | 50064 | 25324
;

statsWithEverythingFiltered#[skip:-8.16.0,reason:implemented in 8.16]
from employees
| stats max = max(salary), max_a = max(salary) where salary < 100,
min = min(salary), min_a = min(salary) where salary > 99999
;

max:integer |max_a:integer|min:integer | min_a:integer
74999 |null |25324 | null
;
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,7 @@ public FilteredExpression(Source source, Expression delegate, Expression filter)
}

public FilteredExpression(StreamInput in) throws IOException {
this(
Source.readFrom((PlanStreamInput) in),
in.readNamedWriteable(Expression.class),
in.readNamedWriteable(Expression.class)
);
this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,7 @@ protected LogicalPlan rule(Aggregate aggregate) {
if (unwrapped instanceof FilteredExpression fe) {
changed.set(true);
Expression filter = fe.filter();
unwrapped = fe.delegate()
.transformUp(AggregateFunction.class, af -> new FilteredAggregation(af.source(), af, filter));
unwrapped = fe.delegate().transformUp(AggregateFunction.class, af -> new FilteredAggregation(af.source(), af, filter));
as = as.replaceChild(unwrapped);
}
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,7 @@ public LogicalPlan visitFromCommand(EsqlBaseParser.FromCommandContext ctx) {
@Override
public PlanFactory visitStatsCommand(EsqlBaseParser.StatsCommandContext ctx) {
final Stats stats = stats(source(ctx), ctx.grouping, ctx.stats);
return input -> new Aggregate(
source(ctx),
input,
Aggregate.AggregateType.STANDARD,
stats.groupings,
stats.aggregates
);
return input -> new Aggregate(source(ctx), input, Aggregate.AggregateType.STANDARD, stats.groupings, stats.aggregates);
}

private record Stats(List<Expression> groupings, List<? extends NamedExpression> aggregates) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,7 @@ protected AttributeSet computeReferences() {
return computeReferences(aggregates, groupings);
}

public static AttributeSet computeReferences(
List<? extends NamedExpression> aggregates,
List<? extends Expression> groupings
) {
public static AttributeSet computeReferences(List<? extends NamedExpression> aggregates, List<? extends Expression> groupings) {
return Expressions.references(groupings).combine(Expressions.references(aggregates));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,7 @@ public String getWriteableName() {

@Override
protected NodeInfo<AggregateExec> info() {
return NodeInfo.create(
this,
AggregateExec::new,
child(),
groupings,
aggregates,
mode,
intermediateAttributes,
estimatedRowSize
);
return NodeInfo.create(this, AggregateExec::new, child(), groupings, aggregates, mode, intermediateAttributes, estimatedRowSize);
}

@Override
Expand Down Expand Up @@ -191,9 +182,7 @@ public List<Attribute> output() {

@Override
protected AttributeSet computeReferences() {
return mode.isInputPartial()
? new AttributeSet(intermediateAttributes)
: Aggregate.computeReferences(aggregates, groupings);
return mode.isInputPartial() ? new AttributeSet(intermediateAttributes) : Aggregate.computeReferences(aggregates, groupings);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,8 @@ private void aggregatesToFactory(
);
}
} else {
// TODO: this needs rework since the name is confusing and redundant and identical to references
// extra dependencies like TS ones (that require a timestamp)
for (Expression input : aggregateFunction.inputExpressions()) {
for (Expression input : aggregateFunction.references()) {
Attribute attr = Expressions.attribute(input);
if (attr == null) {
throw new EsqlIllegalArgumentException(
Expand All @@ -271,11 +270,6 @@ private void aggregatesToFactory(
}
sourceAttr.add(attr);
}
if (aggregateFunction instanceof FilteredAggregation filteredAgg) {
for (Attribute ref : filteredAgg.filter().references()) {
sourceAttr.add(ref);
}
}
}
}
// coordinator/exchange phase
Expand All @@ -291,27 +285,36 @@ else if (mode == AggregatorMode.FINAL || mode == AggregatorMode.INTERMEDIATE) {
List<Integer> inputChannels = sourceAttr.stream().map(attr -> layout.get(attr.id()).channel()).toList();
assert inputChannels.stream().allMatch(i -> i >= 0) : inputChannels;

if (aggregateFunction instanceof ToAggregator agg) {
AggregatorFunctionSupplier aggSupplier = agg.supplier(inputChannels);
// if a filter is specified, apply wrapping but only on raw data - as the rest of the data is already filtered
// TODO: encapsulate the creation inside the class - requires access to layout
if (mode.isInputPartial() == false && agg instanceof FilteredAggregation filteredAggregate) {
AggregatorFunctionSupplier aggSupplier = null;
if (aggregateFunction instanceof FilteredAggregation filteredAggregate) {
AggregateFunction fa = filteredAggregate.delegate();

aggSupplier = supplier(fa, inputChannels);

// apply the filter only in the initial phase - as the rest of the data is already filtered
if (mode.isInputPartial() == false) {
EvalOperator.ExpressionEvaluator.Factory evalFactory = EvalMapper.toEvaluator(
filteredAggregate.filter(),
layout
);
AggregatorFunctionSupplier delegate = ((FilteredAggregatorFunctionSupplier) aggSupplier).next();
aggSupplier = new FilteredAggregatorFunctionSupplier(delegate, evalFactory);
aggSupplier = new FilteredAggregatorFunctionSupplier(aggSupplier, evalFactory);
}
consumer.accept(new AggFunctionSupplierContext(aggSupplier, mode));
} else {
throw new EsqlIllegalArgumentException("aggregate functions must extend ToAggregator");
aggSupplier = supplier(aggregateFunction, inputChannels);
}
consumer.accept(new AggFunctionSupplierContext(aggSupplier, mode));
}
}
}
}

private static AggregatorFunctionSupplier supplier(AggregateFunction aggregateFunction, List<Integer> inputChannels) {
if (aggregateFunction instanceof ToAggregator delegate) {
return delegate.supplier(inputChannels);
}
throw new EsqlIllegalArgumentException("aggregate functions must extend ToAggregator");
}

private record GroupSpec(Integer channel, Attribute attribute) {
BlockHash.GroupSpec toHashGroupSpec() {
if (channel == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.junit.annotations.TestLogging;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -67,10 +66,7 @@
import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer;
import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.PhysicalPlanOptimizer;
import org.elasticsearch.xpack.esql.optimizer.TestLocalPhysicalPlanOptimizer;
import org.elasticsearch.xpack.esql.optimizer.TestPhysicalPlanOptimizer;
import org.elasticsearch.xpack.esql.parser.EsqlParser;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
Expand Down Expand Up @@ -149,7 +145,7 @@
* <p>
* To log the results logResults() should return "true".
*/
@TestLogging(value = "org.elasticsearch.xpack.esql:TRACE,org.elasticsearch.compute:TRACE", reason = "debug")
// @TestLogging(value = "org.elasticsearch.xpack.esql:TRACE,org.elasticsearch.compute:TRACE", reason = "debug")
public class CsvTests extends ESTestCase {

private static final Logger LOGGER = LogManager.getLogger(CsvTests.class);
Expand All @@ -167,7 +163,6 @@ public class CsvTests extends ESTestCase {
private final EsqlFunctionRegistry functionRegistry = new EsqlFunctionRegistry();
private final EsqlParser parser = new EsqlParser();
private final Mapper mapper = new Mapper(functionRegistry);
private final PhysicalPlanOptimizer physicalPlanOptimizer = new TestPhysicalPlanOptimizer(new PhysicalOptimizerContext(configuration));
private ThreadPool threadPool;
private Executor executor;

Expand Down

0 comments on commit a7bc7ea

Please sign in to comment.