Created
April 24, 2024 04:07
-
-
Save dnhatn/b48d36cc9d3e93edc6f325e635e01333 to your computer and use it in GitHub Desktop.
MetricsAggregationOperator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | |
* or more contributor license agreements. Licensed under the Elastic License | |
* 2.0; you may not use this file except in compliance with the Elastic License | |
* 2.0. | |
*/ | |
package org.elasticsearch.compute.operator; | |
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; | |
import org.elasticsearch.compute.aggregation.AggregatorMode; | |
import org.elasticsearch.compute.aggregation.GroupingAggregator; | |
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction; | |
import org.elasticsearch.compute.aggregation.ValuesBooleanAggregatorFunctionSupplier; | |
import org.elasticsearch.compute.aggregation.ValuesBytesRefAggregatorFunctionSupplier; | |
import org.elasticsearch.compute.aggregation.ValuesDoubleAggregatorFunctionSupplier; | |
import org.elasticsearch.compute.aggregation.ValuesIntAggregatorFunctionSupplier; | |
import org.elasticsearch.compute.aggregation.ValuesLongAggregatorFunctionSupplier; | |
import org.elasticsearch.compute.aggregation.blockhash.BlockHash; | |
import org.elasticsearch.compute.aggregation.blockhash.TimeSeriesBlockHash; | |
import org.elasticsearch.compute.data.Block; | |
import org.elasticsearch.compute.data.ElementType; | |
import org.elasticsearch.compute.data.IntBlock; | |
import org.elasticsearch.compute.data.IntVector; | |
import org.elasticsearch.compute.data.Page; | |
import org.elasticsearch.core.Releasables; | |
import java.util.ArrayList; | |
import java.util.List; | |
public final class TimeSeriesAggregationOperator extends HashAggregationOperator { | |
public record Factory( | |
AggregatorMode mode, | |
int tsHashChannel, | |
int timeBucketChannel, | |
List<BlockHash.GroupSpec> groupings, | |
List<GroupingAggregator.Factory> rateFactories, | |
List<AggregatorFunctionSupplier> outerRateFactories, | |
List<GroupingAggregator.Factory> nonRateFactories, | |
int[] blocksReordering, | |
int maxPageSize | |
) implements Operator.OperatorFactory { | |
@Override | |
public Operator get(DriverContext driverContext) { | |
return new TimeSeriesAggregationOperator( | |
mode, | |
tsHashChannel, | |
timeBucketChannel, | |
groupings, | |
rateFactories, | |
outerRateFactories, | |
nonRateFactories, | |
blocksReordering, | |
maxPageSize, | |
driverContext | |
); | |
} | |
@Override | |
public String describe() { | |
return "TimeSeriesAggregationOperator[mode=" | |
+ mode | |
+ ", tsHashChannel = " | |
+ tsHashChannel | |
+ ", timeBucketChannel = " | |
+ timeBucketChannel | |
+ ", maxPageSize = " | |
+ maxPageSize | |
+ "]"; | |
} | |
} | |
private final AggregatorMode mode; | |
private final int timeBucketChannel; | |
private final int maxPageSize; | |
private final List<BlockHash.GroupSpec> groupings; | |
private final List<GroupingAggregator.Factory> nonRateFactories; | |
private final List<GroupingAggregator.Factory> rateFactories; | |
private final List<AggregatorFunctionSupplier> outerRateFactories; | |
private final int[] blocksReordering; | |
public TimeSeriesAggregationOperator( | |
AggregatorMode mode, | |
int tsHashChannel, | |
int timeBucketChannel, | |
List<BlockHash.GroupSpec> groupings, | |
List<GroupingAggregator.Factory> rateFactories, | |
List<AggregatorFunctionSupplier> outerRateFactories, | |
List<GroupingAggregator.Factory> nonRateFactories, | |
int[] blocksReordering, | |
int maxPageSize, | |
DriverContext driverContext | |
) { | |
super( | |
mergeGroupingAggregatorFactories(mode, timeBucketChannel, rateFactories, nonRateFactories, groupings), | |
() -> mainBlockHash(mode, tsHashChannel, timeBucketChannel, driverContext, maxPageSize), | |
driverContext | |
); | |
this.mode = mode; | |
this.timeBucketChannel = timeBucketChannel; | |
this.nonRateFactories = nonRateFactories; | |
this.rateFactories = rateFactories; | |
this.outerRateFactories = outerRateFactories; | |
this.groupings = groupings; | |
this.blocksReordering = blocksReordering; | |
this.maxPageSize = maxPageSize; | |
} | |
private static List<GroupingAggregator.Factory> mergeGroupingAggregatorFactories( | |
int timeBucketChannel, | |
List<GroupingAggregator.Factory> rates, | |
List<GroupingAggregator.Factory> nonRates, | |
List<BlockHash.GroupSpec> groupings | |
) { | |
List<GroupingAggregator.Factory> factories = new ArrayList<>(rates.size() + nonRates.size() + groupings.size()); | |
factories.addAll(rates); | |
factories.addAll(nonRates); | |
for (BlockHash.GroupSpec groupSpec : groupings) { | |
if (groupSpec.channel() != timeBucketChannel) { | |
final List<Integer> channels = List.of(groupSpec.channel()); | |
// TODO: maybe introduce a specialized aggregator instead reusing values | |
var aggregatorSupplier = (switch (groupSpec.elementType()) { | |
case BYTES_REF -> new ValuesBytesRefAggregatorFunctionSupplier(channels); | |
case DOUBLE -> new ValuesDoubleAggregatorFunctionSupplier(channels); | |
case INT -> new ValuesIntAggregatorFunctionSupplier(channels); | |
case LONG -> new ValuesLongAggregatorFunctionSupplier(channels); | |
case BOOLEAN -> new ValuesBooleanAggregatorFunctionSupplier(channels); | |
case NULL, DOC, UNKNOWN -> throw new IllegalArgumentException("unsupported grouping type"); | |
}); | |
factories.add(aggregatorSupplier.groupingAggregatorFactory(AggregatorMode.SINGLE)); | |
} | |
} | |
return factories; | |
} | |
private static BlockHash mainBlockHash( | |
AggregatorMode mode, | |
int tsHashChannel, | |
int timeBucketChannel, | |
DriverContext driverContext, | |
int maxPageSize | |
) { | |
if (mode.isInputPartial() == false) { | |
return new TimeSeriesBlockHash(tsHashChannel, timeBucketChannel, driverContext); | |
} else { | |
return BlockHash.build( | |
List.of( | |
new BlockHash.GroupSpec(tsHashChannel, ElementType.BYTES_REF), | |
new BlockHash.GroupSpec(timeBucketChannel, ElementType.LONG) | |
), | |
driverContext.blockFactory(), | |
maxPageSize, | |
false | |
); | |
} | |
} | |
/** | |
* Re-hashes the output with user-specified grouping keys, handling rate and non-rate aggregations differently: | |
* - For rate aggregators, applies outer aggregation over rate final results. | |
* - For non-rate aggregators, creates new aggregators and merges raw final rows. | |
* Keys are extracted from the values aggregators. | |
*/ | |
@Override | |
protected Page generateOutputPage(BlockHash blockHash, List<GroupingAggregator> inputAggregators) { | |
if (mode.isOutputPartial()) { | |
return super.generateOutputPage(blockHash, inputAggregators); | |
} | |
final Block[] inputBlocks = new Block[rateFactories.size() + groupings.size()]; | |
Page inputPage = null; | |
try (IntVector selected = blockHash.nonEmpty()) { | |
loadRateBlocks(inputBlocks, selected, inputAggregators); | |
loadGroupingKeys(inputBlocks, selected, blockHash, inputAggregators); | |
inputPage = new Page(inputBlocks); | |
final Page outputPage = rehash(inputPage, inputAggregators); | |
try { | |
final Block[] outputBlocks = new Block[outputPage.getBlockCount()]; | |
for (int i = 0; i < outputBlocks.length; i++) { | |
outputBlocks[i] = outputPage.getBlock(blocksReordering[i]); | |
outputBlocks[i].incRef(); | |
} | |
return new Page(outputBlocks); | |
} finally { | |
outputPage.releaseBlocks(); | |
} | |
} finally { | |
if (inputPage == null) { | |
Releasables.close(inputBlocks); | |
} | |
} | |
} | |
Page rehash(Page page, List<GroupingAggregator> inputAggregators) { | |
List<GroupingAggregator> outputAggregators = new ArrayList<>(outerRateFactories.size() + nonRateFactories.size()); | |
try (BlockHash rehash = regroupBlockHash(driverContext)) { | |
List<GroupingAggregator> outerAggregators = createOuterRateAggregators(driverContext); | |
outputAggregators.addAll(outerAggregators); | |
List<GroupingAggregator> nonRateAggregators = createAggregators(nonRateFactories, driverContext); | |
outputAggregators.addAll(nonRateAggregators); | |
GroupingAggregatorFunction.AddInput[] outerRates = new GroupingAggregatorFunction.AddInput[outerRateFactories.size()]; | |
for (int i = 0; i < outerRates.length; i++) { | |
outerRates[i] = outerAggregators.get(i).prepareProcessPage(rehash, page); | |
} | |
class AddInput implements GroupingAggregatorFunction.AddInput { | |
private final int nonRateIndex = rateFactories.size(); | |
@Override | |
public void add(int positionOffset, IntBlock groupIds) { | |
IntVector groupIdsVector = groupIds.asVector(); | |
if (groupIdsVector != null) { | |
add(positionOffset, groupIdsVector); | |
} else { | |
for (int i = 0; i < nonRateAggregators.size(); i++) { | |
combineAggregator(nonRateAggregators.get(i), inputAggregators.get(nonRateIndex + i), positionOffset, groupIds); | |
} | |
for (GroupingAggregatorFunction.AddInput outerRate : outerRates) { | |
outerRate.add(positionOffset, groupIds); | |
} | |
} | |
} | |
@Override | |
public void add(int positionOffset, IntVector groupIds) { | |
for (int i = 0; i < nonRateAggregators.size(); i++) { | |
combineAggregator(nonRateAggregators.get(i), inputAggregators.get(nonRateIndex + i), positionOffset, groupIds); | |
} | |
for (GroupingAggregatorFunction.AddInput outerRate : outerRates) { | |
outerRate.add(positionOffset, groupIds); | |
} | |
} | |
} | |
AddInput add = new AddInput(); | |
rehash.add(page, add); | |
return super.generateOutputPage(rehash, outputAggregators); | |
} finally { | |
Releasables.close(Releasables.wrap(outputAggregators), page::releaseBlocks); | |
} | |
} | |
private BlockHash regroupBlockHash(DriverContext driverContext) { | |
List<BlockHash.GroupSpec> specs = new ArrayList<>(); | |
for (int i = 0; i < groupings.size(); i++) { | |
specs.add(new BlockHash.GroupSpec(i + rateFactories.size(), groupings.get(i).elementType())); | |
} | |
return BlockHash.build(specs, driverContext.blockFactory(), maxPageSize, false); | |
} | |
void loadGroupingKeys(Block[] blocks, IntVector selected, BlockHash blockHash, List<GroupingAggregator> inputAggregators) { | |
int blockOffset = rateFactories.size(); | |
int aggregatorOffset = rateFactories.size() + nonRateFactories.size(); | |
for (BlockHash.GroupSpec group : groupings) { | |
if (group.channel() == timeBucketChannel) { | |
Block[] hashKeys = blockHash.getKeys(); | |
hashKeys[blockOffset++] = hashKeys[1]; | |
hashKeys[0].close(); // tsid | |
} else { | |
final GroupingAggregator aggregator = inputAggregators.get(aggregatorOffset++); | |
if (aggregator.evaluateBlockCount() != 1) { | |
throw new IllegalStateException("expected one output block; got " + aggregator.evaluateBlockCount()); | |
} | |
aggregator.evaluate(blocks, blockOffset++, selected, driverContext); | |
} | |
} | |
} | |
void loadRateBlocks(Block[] rateBlocks, IntVector selected, List<GroupingAggregator> inputAggregators) { | |
for (int i = 0; i < rateFactories.size(); i++) { | |
GroupingAggregator aggregator = inputAggregators.get(i); | |
if (aggregator.evaluateBlockCount() != 1) { | |
throw new IllegalStateException("expected one output block; got " + aggregator.evaluateBlockCount()); | |
} | |
aggregator.evaluate(rateBlocks, i, selected, driverContext); | |
} | |
} | |
private List<GroupingAggregator> createOuterRateAggregators(DriverContext driverContext) { | |
final List<GroupingAggregator> aggregators = new ArrayList<>(outerRateFactories.size()); | |
try { | |
for (AggregatorFunctionSupplier supplier : outerRateFactories) { | |
aggregators.add(supplier.groupingAggregatorFactory(AggregatorMode.SINGLE).apply(driverContext)); | |
} | |
return aggregators; | |
} finally { | |
if (aggregators.size() != outerRateFactories.size()) { | |
Releasables.close(aggregators); | |
} | |
} | |
} | |
private static List<GroupingAggregator> createAggregators(List<GroupingAggregator.Factory> factories, DriverContext driverContext) { | |
final List<GroupingAggregator> aggregators = new ArrayList<>(factories.size()); | |
try { | |
for (GroupingAggregator.Factory factory : factories) { | |
aggregators.add(factory.apply(driverContext)); | |
} | |
return aggregators; | |
} finally { | |
if (aggregators.size() != factories.size()) { | |
Releasables.close(aggregators); | |
} | |
} | |
} | |
private static void combineAggregator(GroupingAggregator out, GroupingAggregator in, int positionOffset, IntBlock groups) { | |
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { | |
if (groups.isNull(groupPosition)) { | |
continue; | |
} | |
int groupStart = groups.getFirstValueIndex(groupPosition); | |
int groupEnd = groupStart + groups.getValueCount(groupPosition); | |
for (int g = groupStart; g < groupEnd; g++) { | |
int groupId = Math.toIntExact(groups.getInt(g)); | |
out.addIntermediateRow(groupId, in, groupPosition + positionOffset); | |
} | |
} | |
} | |
private static void combineAggregator(GroupingAggregator out, GroupingAggregator in, int positionOffset, IntVector groups) { | |
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { | |
int groupId = Math.toIntExact(groups.getInt(groupPosition)); | |
out.addIntermediateRow(groupId, in, groupPosition + positionOffset); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment