Skip to content

Instantly share code, notes, and snippets.

@dnhatn
Created April 24, 2024 04:07
Show Gist options
  • Save dnhatn/b48d36cc9d3e93edc6f325e635e01333 to your computer and use it in GitHub Desktop.
Save dnhatn/b48d36cc9d3e93edc6f325e635e01333 to your computer and use it in GitHub Desktop.
MetricsAggregationOperator.java
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.operator;
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.GroupingAggregator;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.ValuesBooleanAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.ValuesBytesRefAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.ValuesDoubleAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.ValuesIntAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.ValuesLongAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.aggregation.blockhash.TimeSeriesBlockHash;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Releasables;
import java.util.ArrayList;
import java.util.List;
/**
 * Hash aggregation operator for time-series queries.
 * <p>
 * Input rows are first grouped by the time-series hash channel and the time-bucket channel
 * (see {@code mainBlockHash}). Unless this operator emits partial output, the per-series results
 * are then re-grouped by the user-specified grouping keys (see {@link #generateOutputPage}).
 * <p>
 * The grouping aggregators handed to the superclass are laid out in this order: rate aggregators,
 * non-rate aggregators, then one {@code values} aggregator for every grouping whose channel is not
 * the time-bucket channel. Those {@code values} aggregators are how the user grouping keys are
 * recovered after the first-level grouping.
 */
public final class TimeSeriesAggregationOperator extends HashAggregationOperator {

    /**
     * Factory for {@link TimeSeriesAggregationOperator}; every component is forwarded verbatim to the
     * operator constructor.
     */
    public record Factory(
        AggregatorMode mode,
        int tsHashChannel,
        int timeBucketChannel,
        List<BlockHash.GroupSpec> groupings,
        List<GroupingAggregator.Factory> rateFactories,
        List<AggregatorFunctionSupplier> outerRateFactories,
        List<GroupingAggregator.Factory> nonRateFactories,
        int[] blocksReordering,
        int maxPageSize
    ) implements Operator.OperatorFactory {
        @Override
        public Operator get(DriverContext driverContext) {
            return new TimeSeriesAggregationOperator(
                mode,
                tsHashChannel,
                timeBucketChannel,
                groupings,
                rateFactories,
                outerRateFactories,
                nonRateFactories,
                blocksReordering,
                maxPageSize,
                driverContext
            );
        }

        @Override
        public String describe() {
            return "TimeSeriesAggregationOperator[mode="
                + mode
                + ", tsHashChannel = "
                + tsHashChannel
                + ", timeBucketChannel = "
                + timeBucketChannel
                + ", maxPageSize = "
                + maxPageSize
                + "]";
        }
    }

    private final AggregatorMode mode;
    private final int timeBucketChannel;
    private final int maxPageSize;
    // user-specified grouping keys used when re-grouping the first-level results
    private final List<BlockHash.GroupSpec> groupings;
    private final List<GroupingAggregator.Factory> nonRateFactories;
    private final List<GroupingAggregator.Factory> rateFactories;
    // aggregations applied over the final rate values after re-grouping
    private final List<AggregatorFunctionSupplier> outerRateFactories;
    // output block i is taken from block blocksReordering[i] of the re-grouped page
    private final int[] blocksReordering;

    /**
     * @param mode              aggregation mode; decides the first-level block hash and whether re-grouping happens
     * @param tsHashChannel     input channel holding the time-series hash
     * @param timeBucketChannel input channel holding the time bucket
     * @param groupings         user-specified grouping keys
     * @param rateFactories     rate aggregators evaluated per (time series, time bucket)
     * @param outerRateFactories aggregations applied over the rate results when re-grouping
     * @param nonRateFactories  non-rate aggregators, merged into fresh aggregators when re-grouping
     * @param blocksReordering  mapping from output block index to re-grouped page block index
     * @param maxPageSize       maximum page size passed to the block hashes
     * @param driverContext     driver context
     */
    public TimeSeriesAggregationOperator(
        AggregatorMode mode,
        int tsHashChannel,
        int timeBucketChannel,
        List<BlockHash.GroupSpec> groupings,
        List<GroupingAggregator.Factory> rateFactories,
        List<AggregatorFunctionSupplier> outerRateFactories,
        List<GroupingAggregator.Factory> nonRateFactories,
        int[] blocksReordering,
        int maxPageSize,
        DriverContext driverContext
    ) {
        super(
            // mergeGroupingAggregatorFactories takes no mode: the values aggregators are always SINGLE
            mergeGroupingAggregatorFactories(timeBucketChannel, rateFactories, nonRateFactories, groupings),
            () -> mainBlockHash(mode, tsHashChannel, timeBucketChannel, driverContext, maxPageSize),
            driverContext
        );
        this.mode = mode;
        this.timeBucketChannel = timeBucketChannel;
        this.nonRateFactories = nonRateFactories;
        this.rateFactories = rateFactories;
        this.outerRateFactories = outerRateFactories;
        this.groupings = groupings;
        this.blocksReordering = blocksReordering;
        this.maxPageSize = maxPageSize;
    }

    /**
     * Builds the aggregator list for the first-level grouping: all rate factories, all non-rate
     * factories, then a SINGLE-mode {@code values} aggregator for each grouping whose channel is not
     * {@code timeBucketChannel}. The {@code values} aggregators capture the grouping key of each
     * time series so it can be re-grouped later.
     *
     * @throws IllegalArgumentException if a grouping has an element type without a values aggregator
     */
    private static List<GroupingAggregator.Factory> mergeGroupingAggregatorFactories(
        int timeBucketChannel,
        List<GroupingAggregator.Factory> rates,
        List<GroupingAggregator.Factory> nonRates,
        List<BlockHash.GroupSpec> groupings
    ) {
        List<GroupingAggregator.Factory> factories = new ArrayList<>(rates.size() + nonRates.size() + groupings.size());
        factories.addAll(rates);
        factories.addAll(nonRates);
        for (BlockHash.GroupSpec groupSpec : groupings) {
            if (groupSpec.channel() != timeBucketChannel) {
                final List<Integer> channels = List.of(groupSpec.channel());
                // TODO: maybe introduce a specialized aggregator instead reusing values
                final AggregatorFunctionSupplier aggregatorSupplier = switch (groupSpec.elementType()) {
                    case BYTES_REF -> new ValuesBytesRefAggregatorFunctionSupplier(channels);
                    case DOUBLE -> new ValuesDoubleAggregatorFunctionSupplier(channels);
                    case INT -> new ValuesIntAggregatorFunctionSupplier(channels);
                    case LONG -> new ValuesLongAggregatorFunctionSupplier(channels);
                    case BOOLEAN -> new ValuesBooleanAggregatorFunctionSupplier(channels);
                    case NULL, DOC, UNKNOWN -> throw new IllegalArgumentException("unsupported grouping type");
                };
                factories.add(aggregatorSupplier.groupingAggregatorFactory(AggregatorMode.SINGLE));
            }
        }
        return factories;
    }

    /**
     * Creates the first-level block hash keyed on (tsHash, timeBucket). Raw input uses the
     * specialized {@link TimeSeriesBlockHash}; partial input uses a generic BYTES_REF/LONG block hash.
     */
    private static BlockHash mainBlockHash(
        AggregatorMode mode,
        int tsHashChannel,
        int timeBucketChannel,
        DriverContext driverContext,
        int maxPageSize
    ) {
        if (mode.isInputPartial() == false) {
            return new TimeSeriesBlockHash(tsHashChannel, timeBucketChannel, driverContext);
        } else {
            return BlockHash.build(
                List.of(
                    new BlockHash.GroupSpec(tsHashChannel, ElementType.BYTES_REF),
                    new BlockHash.GroupSpec(timeBucketChannel, ElementType.LONG)
                ),
                driverContext.blockFactory(),
                maxPageSize,
                false
            );
        }
    }

    /**
     * Re-hashes the output with user-specified grouping keys, handling rate and non-rate aggregations differently:
     * - For rate aggregators, applies outer aggregation over rate final results.
     * - For non-rate aggregators, creates new aggregators and merges raw final rows.
     * Keys are extracted from the values aggregators.
     * <p>
     * When the output is partial, the first-level results are emitted unchanged by the superclass.
     * Otherwise the returned page's blocks are reordered according to {@code blocksReordering}.
     */
    @Override
    protected Page generateOutputPage(BlockHash blockHash, List<GroupingAggregator> inputAggregators) {
        if (mode.isOutputPartial()) {
            return super.generateOutputPage(blockHash, inputAggregators);
        }
        // layout: [rate results..., grouping keys...]
        final Block[] inputBlocks = new Block[rateFactories.size() + groupings.size()];
        Page inputPage = null;
        try (IntVector selected = blockHash.nonEmpty()) {
            loadRateBlocks(inputBlocks, selected, inputAggregators);
            loadGroupingKeys(inputBlocks, selected, blockHash, inputAggregators);
            inputPage = new Page(inputBlocks);
            // rehash owns inputPage from here on and releases its blocks
            final Page outputPage = rehash(inputPage, inputAggregators);
            final Block[] outputBlocks = new Block[outputPage.getBlockCount()];
            boolean success = false;
            try {
                for (int i = 0; i < outputBlocks.length; i++) {
                    final Block block = outputPage.getBlock(blocksReordering[i]);
                    block.incRef();
                    outputBlocks[i] = block;
                }
                final Page result = new Page(outputBlocks);
                success = true;
                return result;
            } finally {
                if (success == false) {
                    // drop the extra references taken above so they are not leaked on failure
                    Releasables.close(outputBlocks);
                }
                outputPage.releaseBlocks();
            }
        } finally {
            if (inputPage == null) {
                // the page was never built, so the loaded blocks are still owned here
                Releasables.close(inputBlocks);
            }
        }
    }

    /**
     * Groups {@code page} by the grouping-key channels (which follow the rate channels) and evaluates
     * the outer rate aggregators and merged non-rate aggregators for each new group. Always releases
     * the blocks of {@code page}.
     */
    Page rehash(Page page, List<GroupingAggregator> inputAggregators) {
        List<GroupingAggregator> outputAggregators = new ArrayList<>(outerRateFactories.size() + nonRateFactories.size());
        try (BlockHash rehash = regroupBlockHash(driverContext)) {
            List<GroupingAggregator> outerAggregators = createOuterRateAggregators(driverContext);
            outputAggregators.addAll(outerAggregators);
            List<GroupingAggregator> nonRateAggregators = createAggregators(nonRateFactories, driverContext);
            outputAggregators.addAll(nonRateAggregators);
            GroupingAggregatorFunction.AddInput[] outerRates = new GroupingAggregatorFunction.AddInput[outerRateFactories.size()];
            for (int i = 0; i < outerRates.length; i++) {
                outerRates[i] = outerAggregators.get(i).prepareProcessPage(rehash, page);
            }
            class AddInput implements GroupingAggregatorFunction.AddInput {
                // non-rate input aggregators come right after the rate aggregators
                private final int nonRateIndex = rateFactories.size();

                @Override
                public void add(int positionOffset, IntBlock groupIds) {
                    IntVector groupIdsVector = groupIds.asVector();
                    if (groupIdsVector != null) {
                        add(positionOffset, groupIdsVector);
                    } else {
                        for (int i = 0; i < nonRateAggregators.size(); i++) {
                            combineAggregator(nonRateAggregators.get(i), inputAggregators.get(nonRateIndex + i), positionOffset, groupIds);
                        }
                        for (GroupingAggregatorFunction.AddInput outerRate : outerRates) {
                            outerRate.add(positionOffset, groupIds);
                        }
                    }
                }

                @Override
                public void add(int positionOffset, IntVector groupIds) {
                    for (int i = 0; i < nonRateAggregators.size(); i++) {
                        combineAggregator(nonRateAggregators.get(i), inputAggregators.get(nonRateIndex + i), positionOffset, groupIds);
                    }
                    for (GroupingAggregatorFunction.AddInput outerRate : outerRates) {
                        outerRate.add(positionOffset, groupIds);
                    }
                }
            }
            AddInput add = new AddInput();
            rehash.add(page, add);
            return super.generateOutputPage(rehash, outputAggregators);
        } finally {
            Releasables.close(Releasables.wrap(outputAggregators), page::releaseBlocks);
        }
    }

    /**
     * Builds the block hash for re-grouping: grouping i is read from channel
     * {@code rateFactories.size() + i}, matching the layout produced by {@link #generateOutputPage}.
     */
    private BlockHash regroupBlockHash(DriverContext driverContext) {
        List<BlockHash.GroupSpec> specs = new ArrayList<>();
        for (int i = 0; i < groupings.size(); i++) {
            specs.add(new BlockHash.GroupSpec(i + rateFactories.size(), groupings.get(i).elementType()));
        }
        return BlockHash.build(specs, driverContext.blockFactory(), maxPageSize, false);
    }

    /**
     * Fills {@code blocks[rateFactories.size() ...]} with one key block per grouping.
     * <ul>
     *   <li>The time-bucket grouping takes the second key of {@code blockHash}; the first key (tsid) is closed.</li>
     *   <li>Every other grouping is evaluated from its {@code values} aggregator.</li>
     * </ul>
     *
     * @throws IllegalStateException if a values aggregator does not produce exactly one block
     */
    void loadGroupingKeys(Block[] blocks, IntVector selected, BlockHash blockHash, List<GroupingAggregator> inputAggregators) {
        int blockOffset = rateFactories.size();
        // values aggregators follow the rate and non-rate aggregators
        int aggregatorOffset = rateFactories.size() + nonRateFactories.size();
        for (BlockHash.GroupSpec group : groupings) {
            if (group.channel() == timeBucketChannel) {
                final Block[] hashKeys = blockHash.getKeys();
                // store the time bucket into the output array first so the caller owns (and can release) it
                blocks[blockOffset++] = hashKeys[1];
                hashKeys[0].close(); // tsid
            } else {
                final GroupingAggregator aggregator = inputAggregators.get(aggregatorOffset++);
                if (aggregator.evaluateBlockCount() != 1) {
                    throw new IllegalStateException("expected one output block; got " + aggregator.evaluateBlockCount());
                }
                aggregator.evaluate(blocks, blockOffset++, selected, driverContext);
            }
        }
    }

    /**
     * Evaluates each rate aggregator into {@code rateBlocks[i]}.
     *
     * @throws IllegalStateException if a rate aggregator does not produce exactly one block
     */
    void loadRateBlocks(Block[] rateBlocks, IntVector selected, List<GroupingAggregator> inputAggregators) {
        for (int i = 0; i < rateFactories.size(); i++) {
            GroupingAggregator aggregator = inputAggregators.get(i);
            if (aggregator.evaluateBlockCount() != 1) {
                throw new IllegalStateException("expected one output block; got " + aggregator.evaluateBlockCount());
            }
            aggregator.evaluate(rateBlocks, i, selected, driverContext);
        }
    }

    /** Creates SINGLE-mode outer rate aggregators, closing any already created if one fails. */
    private List<GroupingAggregator> createOuterRateAggregators(DriverContext driverContext) {
        final List<GroupingAggregator> aggregators = new ArrayList<>(outerRateFactories.size());
        try {
            for (AggregatorFunctionSupplier supplier : outerRateFactories) {
                aggregators.add(supplier.groupingAggregatorFactory(AggregatorMode.SINGLE).apply(driverContext));
            }
            return aggregators;
        } finally {
            if (aggregators.size() != outerRateFactories.size()) {
                Releasables.close(aggregators);
            }
        }
    }

    /** Creates one aggregator per factory, closing any already created if one fails. */
    private static List<GroupingAggregator> createAggregators(List<GroupingAggregator.Factory> factories, DriverContext driverContext) {
        final List<GroupingAggregator> aggregators = new ArrayList<>(factories.size());
        try {
            for (GroupingAggregator.Factory factory : factories) {
                aggregators.add(factory.apply(driverContext));
            }
            return aggregators;
        } finally {
            if (aggregators.size() != factories.size()) {
                Releasables.close(aggregators);
            }
        }
    }

    /**
     * Merges row {@code groupPosition + positionOffset} of {@code in} into each group id listed at
     * that position of {@code groups}; null positions are skipped.
     * NOTE(review): assumes a page position equals the first-level group id in {@code in}, i.e. the
     * {@code selected} vector used to build the page is dense 0..n-1 — TODO confirm.
     */
    private static void combineAggregator(GroupingAggregator out, GroupingAggregator in, int positionOffset, IntBlock groups) {
        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
            if (groups.isNull(groupPosition)) {
                continue;
            }
            int groupStart = groups.getFirstValueIndex(groupPosition);
            int groupEnd = groupStart + groups.getValueCount(groupPosition);
            for (int g = groupStart; g < groupEnd; g++) {
                out.addIntermediateRow(groups.getInt(g), in, groupPosition + positionOffset);
            }
        }
    }

    /** Vector variant of the block overload above: each position carries exactly one group id. */
    private static void combineAggregator(GroupingAggregator out, GroupingAggregator in, int positionOffset, IntVector groups) {
        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
            out.addIntermediateRow(groups.getInt(groupPosition), in, groupPosition + positionOffset);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment