Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ESQL: Compute engine support for stateful grouping functions #112757

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
import org.elasticsearch.compute.aggregation.CountAggregatorFunction;
import org.elasticsearch.compute.aggregation.CountDistinctDoubleAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.CountDistinctLongAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.GroupingKey;
import org.elasticsearch.compute.aggregation.MaxDoubleAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.MaxLongAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.MinDoubleAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.MinLongAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.SumDoubleAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.SumLongAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BooleanBlock;
Expand Down Expand Up @@ -125,29 +125,32 @@ private static Operator operator(DriverContext driverContext, String grouping, S
driverContext
);
}
List<BlockHash.GroupSpec> groups = switch (grouping) {
case LONGS -> List.of(new BlockHash.GroupSpec(0, ElementType.LONG));
case INTS -> List.of(new BlockHash.GroupSpec(0, ElementType.INT));
case DOUBLES -> List.of(new BlockHash.GroupSpec(0, ElementType.DOUBLE));
case BOOLEANS -> List.of(new BlockHash.GroupSpec(0, ElementType.BOOLEAN));
case BYTES_REFS -> List.of(new BlockHash.GroupSpec(0, ElementType.BYTES_REF));
case TWO_LONGS -> List.of(new BlockHash.GroupSpec(0, ElementType.LONG), new BlockHash.GroupSpec(1, ElementType.LONG));
List<GroupingKey.Factory> groups = switch (grouping) {
case LONGS -> List.of(GroupingKey.forStatelessGrouping(0, ElementType.LONG).get(AggregatorMode.SINGLE));
case INTS -> List.of(GroupingKey.forStatelessGrouping(0, ElementType.INT).get(AggregatorMode.SINGLE));
case DOUBLES -> List.of(GroupingKey.forStatelessGrouping(0, ElementType.DOUBLE).get(AggregatorMode.SINGLE));
case BOOLEANS -> List.of(GroupingKey.forStatelessGrouping(0, ElementType.BOOLEAN).get(AggregatorMode.SINGLE));
case BYTES_REFS -> List.of(GroupingKey.forStatelessGrouping(0, ElementType.BYTES_REF).get(AggregatorMode.SINGLE));
case TWO_LONGS -> List.of(
GroupingKey.forStatelessGrouping(0, ElementType.LONG).get(AggregatorMode.SINGLE),
GroupingKey.forStatelessGrouping(1, ElementType.LONG).get(AggregatorMode.SINGLE)
);
case LONGS_AND_BYTES_REFS -> List.of(
new BlockHash.GroupSpec(0, ElementType.LONG),
new BlockHash.GroupSpec(1, ElementType.BYTES_REF)
GroupingKey.forStatelessGrouping(0, ElementType.LONG).get(AggregatorMode.SINGLE),
GroupingKey.forStatelessGrouping(1, ElementType.BYTES_REF).get(AggregatorMode.SINGLE)
);
case TWO_LONGS_AND_BYTES_REFS -> List.of(
new BlockHash.GroupSpec(0, ElementType.LONG),
new BlockHash.GroupSpec(1, ElementType.LONG),
new BlockHash.GroupSpec(2, ElementType.BYTES_REF)
GroupingKey.forStatelessGrouping(0, ElementType.LONG).get(AggregatorMode.SINGLE),
GroupingKey.forStatelessGrouping(1, ElementType.LONG).get(AggregatorMode.SINGLE),
GroupingKey.forStatelessGrouping(2, ElementType.BYTES_REF).get(AggregatorMode.SINGLE)
);
default -> throw new IllegalArgumentException("unsupported grouping [" + grouping + "]");
};
return new HashAggregationOperator(
return new HashAggregationOperator.HashAggregationOperatorFactory(
groups,
List.of(supplier(op, dataType, groups.size()).groupingAggregatorFactory(AggregatorMode.SINGLE)),
() -> BlockHash.build(groups, driverContext.blockFactory(), 16 * 1024, false),
driverContext
);
16 * 1024
).get(driverContext);
}

private static AggregatorFunctionSupplier supplier(String op, String dataType, int dataChannel) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.compute.aggregation;

import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.core.Releasable;

import java.util.ArrayList;
import java.util.List;

public record GroupingKey(AggregatorMode mode, Thing thing, BlockFactory blockFactory) implements EvalOperator.ExpressionEvaluator {
public interface Thing extends Releasable {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is "Thing" the final name here?

int extraIntermediateBlocks();

Block evalRawInput(Page page);

Block evalIntermediateInput(BlockFactory blockFactory, Page page);

void fetchIntermediateState(BlockFactory blockFactory, Block[] blocks, int positionCount);

void replaceIntermediateKeys(BlockFactory blockFactory, Block[] blocks);
}

public interface Supplier {
Factory get(AggregatorMode mode);
}

public interface Factory {
GroupingKey apply(DriverContext context, int resultOffset);

ElementType intermediateElementType();

GroupingAggregator.Factory valuesAggregatorForGroupingsInTimeSeries(int timeBucketChannel);
}

public static GroupingKey.Supplier forStatelessGrouping(int channel, ElementType elementType) {
return mode -> new Factory() {
@Override
public GroupingKey apply(DriverContext context, int resultOffset) {
return new GroupingKey(mode, new Load(channel, resultOffset), context.blockFactory());
}

@Override
public ElementType intermediateElementType() {
return elementType;
}

@Override
public GroupingAggregator.Factory valuesAggregatorForGroupingsInTimeSeries(int timeBucketChannel) {
if (channel != timeBucketChannel) {
final List<Integer> channels = List.of(channel);
// TODO: perhaps introduce a specialized aggregator for this?
return (switch (intermediateElementType()) {
case BYTES_REF -> new ValuesBytesRefAggregatorFunctionSupplier(channels);
case DOUBLE -> new ValuesDoubleAggregatorFunctionSupplier(channels);
case INT -> new ValuesIntAggregatorFunctionSupplier(channels);
case LONG -> new ValuesLongAggregatorFunctionSupplier(channels);
case BOOLEAN -> new ValuesBooleanAggregatorFunctionSupplier(channels);
case FLOAT, NULL, DOC, COMPOSITE, UNKNOWN -> throw new IllegalArgumentException("unsupported grouping type");
}).groupingAggregatorFactory(AggregatorMode.SINGLE);
}
return null;
}
};
}

public static List<BlockHash.GroupSpec> toBlockHashGroupSpec(List<GroupingKey.Factory> keys) {
List<BlockHash.GroupSpec> result = new ArrayList<>(keys.size());
for (int k = 0; k < keys.size(); k++) {
result.add(new BlockHash.GroupSpec(k, keys.get(k).intermediateElementType()));
}
return result;
}

@Override
public Block eval(Page page) {
return mode.isInputPartial() ? thing.evalIntermediateInput(blockFactory, page) : thing.evalRawInput(page);
}

public int finishBlockCount() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: We're calling this "finish", while in the aggregator it's "evaluate". Some reason to keep those names separated? From what I understand, the operation is nearly the same (?)

return mode.isOutputPartial() ? 1 + thing.extraIntermediateBlocks() : 1;
}

public void finish(Block[] blocks, IntVector selected, DriverContext driverContext) {
if (mode.isOutputPartial()) {
thing.fetchIntermediateState(driverContext.blockFactory(), blocks, selected.getPositionCount());
} else {
thing.replaceIntermediateKeys(driverContext.blockFactory(), blocks);
}
}

public int extraIntermediateBlocks() {
return thing.extraIntermediateBlocks();
}

@Override
public void close() {
thing.close();
}

private record Load(int channel, int resultOffset) implements Thing {
@Override
public int extraIntermediateBlocks() {
return 0;
}

@Override
public Block evalRawInput(Page page) {
Block b = page.getBlock(channel);
b.incRef();
return b;
}

@Override
public Block evalIntermediateInput(BlockFactory blockFactory, Page page) {
Block b = page.getBlock(resultOffset);
b.incRef();
return b;
}

@Override
public void fetchIntermediateState(BlockFactory blockFactory, Block[] blocks, int positionCount) {}

@Override
public void replaceIntermediateKeys(BlockFactory blockFactory, Block[] blocks) {}

@Override
public void close() {}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.compute.Describable;
import org.elasticsearch.compute.aggregation.GroupingAggregator;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.GroupingKey;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.IntBlock;
Expand All @@ -27,26 +28,24 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.joining;

public class HashAggregationOperator implements Operator {

public record HashAggregationOperatorFactory(
List<BlockHash.GroupSpec> groups,
List<GroupingKey.Factory> groups,
List<GroupingAggregator.Factory> aggregators,
int maxPageSize
) implements OperatorFactory {
@Override
public Operator get(DriverContext driverContext) {
return new HashAggregationOperator(
aggregators,
() -> BlockHash.build(groups, driverContext.blockFactory(), maxPageSize, false),
groups,
() -> BlockHash.build(GroupingKey.toBlockHashGroupSpec(groups), driverContext.blockFactory(), maxPageSize, false),
driverContext
);
}
Expand All @@ -61,15 +60,17 @@ public String describe() {
}
}

private boolean finished;
private Page output;
private final List<GroupingAggregator> aggregators;

private final BlockHash blockHash;
private final List<GroupingKey> groups;

private final List<GroupingAggregator> aggregators;
private final BlockHash blockHash;

private final DriverContext driverContext;

private boolean finished;
private Page output;

/**
* Nanoseconds this operator has spent hashing grouping keys.
*/
Expand All @@ -86,17 +87,25 @@ public String describe() {
@SuppressWarnings("this-escape")
public HashAggregationOperator(
List<GroupingAggregator.Factory> aggregators,
List<GroupingKey.Factory> groups,
Supplier<BlockHash> blockHash,
DriverContext driverContext
) {
this.aggregators = new ArrayList<>(aggregators.size());
this.groups = new ArrayList<>(groups.size());
this.driverContext = driverContext;
boolean success = false;
try {
this.blockHash = blockHash.get();
for (GroupingAggregator.Factory a : aggregators) {
this.aggregators.add(a.apply(driverContext));
}
int offset = 0;
for (GroupingKey.Factory g : groups) {
GroupingKey key = g.apply(driverContext, offset);
this.groups.add(key);
offset += key.extraIntermediateBlocks() + 1;
}
success = true;
} finally {
if (success == false) {
Expand All @@ -112,6 +121,8 @@ public boolean needsInput() {

@Override
public void addInput(Page page) {
checkState(needsInput(), "Operator is already finishing");

try {
GroupingAggregatorFunction.AddInput[] prepared = new GroupingAggregatorFunction.AddInput[aggregators.size()];
class AddInput implements GroupingAggregatorFunction.AddInput {
Expand Down Expand Up @@ -156,16 +167,21 @@ public void close() {
Releasables.closeExpectNoException(prepared);
}
}
Block[] keys = new Block[groups.size()];
page = wrapPage(page);
try (AddInput add = new AddInput()) {
checkState(needsInput(), "Operator is already finishing");
requireNonNull(page, "page is null");
for (int g = 0; g < groups.size(); g++) {
keys[g] = groups.get(g).eval(page);
}

for (int i = 0; i < prepared.length; i++) {
prepared[i] = aggregators.get(i).prepareProcessPage(blockHash, page);
}

blockHash.add(wrapPage(page), add);
blockHash.add(new Page(keys), add);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to modify BlockHash to take a Block[] with the blocks in the right position. But that seems like something for another time.

hashNanos += System.nanoTime() - add.hashStart;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's probably worth timing the evaluation here.

} finally {
Releasables.close(keys);
}
} finally {
page.releaseBlocks();
Expand All @@ -192,15 +208,29 @@ public void finish() {
try {
selected = blockHash.nonEmpty();
Block[] keys = blockHash.getKeys();
int[] aggBlockCounts = aggregators.stream().mapToInt(GroupingAggregator::evaluateBlockCount).toArray();
blocks = new Block[keys.length + Arrays.stream(aggBlockCounts).sum()];
System.arraycopy(keys, 0, blocks, 0, keys.length);
int offset = keys.length;
for (int i = 0; i < aggregators.size(); i++) {
var aggregator = aggregators.get(i);
aggregator.evaluate(blocks, offset, selected, driverContext);
offset += aggBlockCounts[i];

int blockCount = 0;
for (int g = 0; g < groups.size(); g++) {
blockCount += groups.get(g).finishBlockCount();
}
int[] aggBlockCounts = new int[aggregators.size()];
for (int a = 0; a < aggregators.size(); a++) {
aggBlockCounts[a] = aggregators.get(a).evaluateBlockCount();
blockCount += aggBlockCounts[a];
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found it a lot easier to read if I encoded the resultOffsets into the GroupKeys. It'd be even easier to read if the offsets were encoded into the aggregators too. Or if we returns Block[].

}

blocks = new Block[blockCount];
int offset = 0;
for (int g = 0; g < groups.size(); g++) {
blocks[offset] = keys[g];
groups.get(g).finish(blocks, selected, driverContext);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No offset passed to the finish() here? How does it know where to place the blocks?

offset += groups.get(g).finishBlockCount();
}
for (int a = 0; a < aggregators.size(); a++) {
aggregators.get(a).evaluate(blocks, offset, selected, driverContext);
offset += aggBlockCounts[a];
}

output = new Page(blocks);
success = true;
} finally {
Expand All @@ -224,7 +254,7 @@ public void close() {
if (output != null) {
output.releaseBlocks();
}
Releasables.close(blockHash, () -> Releasables.close(aggregators));
Releasables.close(blockHash, Releasables.wrap(aggregators), Releasables.wrap(groups));
}

@Override
Expand Down
Loading