-
Notifications
You must be signed in to change notification settings - Fork 24.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ESQL: Compute engine support for stateful grouping functions #112757
base: main
Are you sure you want to change the base?
Changes from all commits
bc389e4
712b672
37415d2
1fb09b1
5a088b6
522c0bc
96e6505
f6ef350
a385441
5a46823
7685a0c
f35fccb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.compute.aggregation; | ||
|
||
import org.elasticsearch.compute.aggregation.blockhash.BlockHash; | ||
import org.elasticsearch.compute.data.Block; | ||
import org.elasticsearch.compute.data.BlockFactory; | ||
import org.elasticsearch.compute.data.ElementType; | ||
import org.elasticsearch.compute.data.IntVector; | ||
import org.elasticsearch.compute.data.Page; | ||
import org.elasticsearch.compute.operator.DriverContext; | ||
import org.elasticsearch.compute.operator.EvalOperator; | ||
import org.elasticsearch.core.Releasable; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
public record GroupingKey(AggregatorMode mode, Thing thing, BlockFactory blockFactory) implements EvalOperator.ExpressionEvaluator { | ||
public interface Thing extends Releasable { | ||
int extraIntermediateBlocks(); | ||
|
||
Block evalRawInput(Page page); | ||
|
||
Block evalIntermediateInput(BlockFactory blockFactory, Page page); | ||
|
||
void fetchIntermediateState(BlockFactory blockFactory, Block[] blocks, int positionCount); | ||
|
||
void replaceIntermediateKeys(BlockFactory blockFactory, Block[] blocks); | ||
} | ||
|
||
public interface Supplier { | ||
Factory get(AggregatorMode mode); | ||
} | ||
|
||
public interface Factory { | ||
GroupingKey apply(DriverContext context, int resultOffset); | ||
|
||
ElementType intermediateElementType(); | ||
|
||
GroupingAggregator.Factory valuesAggregatorForGroupingsInTimeSeries(int timeBucketChannel); | ||
} | ||
|
||
public static GroupingKey.Supplier forStatelessGrouping(int channel, ElementType elementType) { | ||
return mode -> new Factory() { | ||
@Override | ||
public GroupingKey apply(DriverContext context, int resultOffset) { | ||
return new GroupingKey(mode, new Load(channel, resultOffset), context.blockFactory()); | ||
} | ||
|
||
@Override | ||
public ElementType intermediateElementType() { | ||
return elementType; | ||
} | ||
|
||
@Override | ||
public GroupingAggregator.Factory valuesAggregatorForGroupingsInTimeSeries(int timeBucketChannel) { | ||
if (channel != timeBucketChannel) { | ||
final List<Integer> channels = List.of(channel); | ||
// TODO: perhaps introduce a specialized aggregator for this? | ||
return (switch (intermediateElementType()) { | ||
case BYTES_REF -> new ValuesBytesRefAggregatorFunctionSupplier(channels); | ||
case DOUBLE -> new ValuesDoubleAggregatorFunctionSupplier(channels); | ||
case INT -> new ValuesIntAggregatorFunctionSupplier(channels); | ||
case LONG -> new ValuesLongAggregatorFunctionSupplier(channels); | ||
case BOOLEAN -> new ValuesBooleanAggregatorFunctionSupplier(channels); | ||
case FLOAT, NULL, DOC, COMPOSITE, UNKNOWN -> throw new IllegalArgumentException("unsupported grouping type"); | ||
}).groupingAggregatorFactory(AggregatorMode.SINGLE); | ||
} | ||
return null; | ||
} | ||
}; | ||
} | ||
|
||
public static List<BlockHash.GroupSpec> toBlockHashGroupSpec(List<GroupingKey.Factory> keys) { | ||
List<BlockHash.GroupSpec> result = new ArrayList<>(keys.size()); | ||
for (int k = 0; k < keys.size(); k++) { | ||
result.add(new BlockHash.GroupSpec(k, keys.get(k).intermediateElementType())); | ||
} | ||
return result; | ||
} | ||
|
||
@Override | ||
public Block eval(Page page) { | ||
return mode.isInputPartial() ? thing.evalIntermediateInput(blockFactory, page) : thing.evalRawInput(page); | ||
} | ||
|
||
public int finishBlockCount() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: We're calling this "finish", while in the aggregator it's "evaluate". Some reason to keep those names separated? From what I understand, the operation is nearly the same (?) |
||
return mode.isOutputPartial() ? 1 + thing.extraIntermediateBlocks() : 1; | ||
} | ||
|
||
public void finish(Block[] blocks, IntVector selected, DriverContext driverContext) { | ||
if (mode.isOutputPartial()) { | ||
thing.fetchIntermediateState(driverContext.blockFactory(), blocks, selected.getPositionCount()); | ||
} else { | ||
thing.replaceIntermediateKeys(driverContext.blockFactory(), blocks); | ||
} | ||
} | ||
|
||
public int extraIntermediateBlocks() { | ||
return thing.extraIntermediateBlocks(); | ||
} | ||
|
||
@Override | ||
public void close() { | ||
thing.close(); | ||
} | ||
|
||
private record Load(int channel, int resultOffset) implements Thing { | ||
@Override | ||
public int extraIntermediateBlocks() { | ||
return 0; | ||
} | ||
|
||
@Override | ||
public Block evalRawInput(Page page) { | ||
Block b = page.getBlock(channel); | ||
b.incRef(); | ||
return b; | ||
} | ||
|
||
@Override | ||
public Block evalIntermediateInput(BlockFactory blockFactory, Page page) { | ||
Block b = page.getBlock(resultOffset); | ||
b.incRef(); | ||
return b; | ||
} | ||
|
||
@Override | ||
public void fetchIntermediateState(BlockFactory blockFactory, Block[] blocks, int positionCount) {} | ||
|
||
@Override | ||
public void replaceIntermediateKeys(BlockFactory blockFactory, Block[] blocks) {} | ||
|
||
@Override | ||
public void close() {} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
import org.elasticsearch.compute.Describable; | ||
import org.elasticsearch.compute.aggregation.GroupingAggregator; | ||
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction; | ||
import org.elasticsearch.compute.aggregation.GroupingKey; | ||
import org.elasticsearch.compute.aggregation.blockhash.BlockHash; | ||
import org.elasticsearch.compute.data.Block; | ||
import org.elasticsearch.compute.data.IntBlock; | ||
|
@@ -27,26 +28,24 @@ | |
|
||
import java.io.IOException; | ||
import java.util.ArrayList; | ||
import java.util.Arrays; | ||
import java.util.List; | ||
import java.util.Objects; | ||
import java.util.function.Supplier; | ||
|
||
import static java.util.Objects.requireNonNull; | ||
import static java.util.stream.Collectors.joining; | ||
|
||
public class HashAggregationOperator implements Operator { | ||
|
||
public record HashAggregationOperatorFactory( | ||
List<BlockHash.GroupSpec> groups, | ||
List<GroupingKey.Factory> groups, | ||
List<GroupingAggregator.Factory> aggregators, | ||
int maxPageSize | ||
) implements OperatorFactory { | ||
@Override | ||
public Operator get(DriverContext driverContext) { | ||
return new HashAggregationOperator( | ||
aggregators, | ||
() -> BlockHash.build(groups, driverContext.blockFactory(), maxPageSize, false), | ||
groups, | ||
() -> BlockHash.build(GroupingKey.toBlockHashGroupSpec(groups), driverContext.blockFactory(), maxPageSize, false), | ||
driverContext | ||
); | ||
} | ||
|
@@ -61,15 +60,17 @@ public String describe() { | |
} | ||
} | ||
|
||
private boolean finished; | ||
private Page output; | ||
private final List<GroupingAggregator> aggregators; | ||
|
||
private final BlockHash blockHash; | ||
private final List<GroupingKey> groups; | ||
|
||
private final List<GroupingAggregator> aggregators; | ||
private final BlockHash blockHash; | ||
|
||
private final DriverContext driverContext; | ||
|
||
private boolean finished; | ||
private Page output; | ||
|
||
/** | ||
* Nanoseconds this operator has spent hashing grouping keys. | ||
*/ | ||
|
@@ -86,17 +87,25 @@ public String describe() { | |
@SuppressWarnings("this-escape") | ||
public HashAggregationOperator( | ||
List<GroupingAggregator.Factory> aggregators, | ||
List<GroupingKey.Factory> groups, | ||
Supplier<BlockHash> blockHash, | ||
DriverContext driverContext | ||
) { | ||
this.aggregators = new ArrayList<>(aggregators.size()); | ||
this.groups = new ArrayList<>(groups.size()); | ||
this.driverContext = driverContext; | ||
boolean success = false; | ||
try { | ||
this.blockHash = blockHash.get(); | ||
for (GroupingAggregator.Factory a : aggregators) { | ||
this.aggregators.add(a.apply(driverContext)); | ||
} | ||
int offset = 0; | ||
for (GroupingKey.Factory g : groups) { | ||
GroupingKey key = g.apply(driverContext, offset); | ||
this.groups.add(key); | ||
offset += key.extraIntermediateBlocks() + 1; | ||
} | ||
success = true; | ||
} finally { | ||
if (success == false) { | ||
|
@@ -112,6 +121,8 @@ public boolean needsInput() { | |
|
||
@Override | ||
public void addInput(Page page) { | ||
checkState(needsInput(), "Operator is already finishing"); | ||
|
||
try { | ||
GroupingAggregatorFunction.AddInput[] prepared = new GroupingAggregatorFunction.AddInput[aggregators.size()]; | ||
class AddInput implements GroupingAggregatorFunction.AddInput { | ||
|
@@ -156,16 +167,21 @@ public void close() { | |
Releasables.closeExpectNoException(prepared); | ||
} | ||
} | ||
Block[] keys = new Block[groups.size()]; | ||
page = wrapPage(page); | ||
try (AddInput add = new AddInput()) { | ||
checkState(needsInput(), "Operator is already finishing"); | ||
requireNonNull(page, "page is null"); | ||
for (int g = 0; g < groups.size(); g++) { | ||
keys[g] = groups.get(g).eval(page); | ||
} | ||
|
||
for (int i = 0; i < prepared.length; i++) { | ||
prepared[i] = aggregators.get(i).prepareProcessPage(blockHash, page); | ||
} | ||
|
||
blockHash.add(wrapPage(page), add); | ||
blockHash.add(new Page(keys), add); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd like to modify |
||
hashNanos += System.nanoTime() - add.hashStart; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's probably worth timing the evaluation here. |
||
} finally { | ||
Releasables.close(keys); | ||
} | ||
} finally { | ||
page.releaseBlocks(); | ||
|
@@ -192,15 +208,29 @@ public void finish() { | |
try { | ||
selected = blockHash.nonEmpty(); | ||
Block[] keys = blockHash.getKeys(); | ||
int[] aggBlockCounts = aggregators.stream().mapToInt(GroupingAggregator::evaluateBlockCount).toArray(); | ||
blocks = new Block[keys.length + Arrays.stream(aggBlockCounts).sum()]; | ||
System.arraycopy(keys, 0, blocks, 0, keys.length); | ||
int offset = keys.length; | ||
for (int i = 0; i < aggregators.size(); i++) { | ||
var aggregator = aggregators.get(i); | ||
aggregator.evaluate(blocks, offset, selected, driverContext); | ||
offset += aggBlockCounts[i]; | ||
|
||
int blockCount = 0; | ||
for (int g = 0; g < groups.size(); g++) { | ||
blockCount += groups.get(g).finishBlockCount(); | ||
} | ||
int[] aggBlockCounts = new int[aggregators.size()]; | ||
for (int a = 0; a < aggregators.size(); a++) { | ||
aggBlockCounts[a] = aggregators.get(a).evaluateBlockCount(); | ||
blockCount += aggBlockCounts[a]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I found it a lot easier to read if I encoded the |
||
} | ||
|
||
blocks = new Block[blockCount]; | ||
int offset = 0; | ||
for (int g = 0; g < groups.size(); g++) { | ||
blocks[offset] = keys[g]; | ||
groups.get(g).finish(blocks, selected, driverContext); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No offset passed to the |
||
offset += groups.get(g).finishBlockCount(); | ||
} | ||
for (int a = 0; a < aggregators.size(); a++) { | ||
aggregators.get(a).evaluate(blocks, offset, selected, driverContext); | ||
offset += aggBlockCounts[a]; | ||
} | ||
|
||
output = new Page(blocks); | ||
success = true; | ||
} finally { | ||
|
@@ -224,7 +254,7 @@ public void close() { | |
if (output != null) { | ||
output.releaseBlocks(); | ||
} | ||
Releasables.close(blockHash, () -> Releasables.close(aggregators)); | ||
Releasables.close(blockHash, Releasables.wrap(aggregators), Releasables.wrap(groups)); | ||
} | ||
|
||
@Override | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is "Thing" the final name here?