MergeMapUDAF: a Spark user-defined aggregate function (UDAF) that aggregates rows of (key: String, values: Map<String, Long>) into a single Map<String, Map<String, Long>> keyed by `key`.
package com.tomron; | |
import org.apache.spark.sql.Row; | |
import org.apache.spark.sql.expressions.MutableAggregationBuffer; | |
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; | |
import org.apache.spark.sql.types.DataType; | |
import org.apache.spark.sql.types.DataTypes; | |
import org.apache.spark.sql.types.StructField; | |
import org.apache.spark.sql.types.StructType; | |
import java.util.ArrayList; | |
import java.util.HashMap; | |
import java.util.List; | |
import java.util.Map; | |
/** | |
* Created by tomron on 7/10/17. | |
*/ | |
public class MergeMapUDAF extends UserDefinedAggregateFunction { | |
private StructType _inputDataType; | |
private StructType _bufferSchema; | |
private DataType _returnDataType; | |
private static DataType _valueType = DataTypes.LongType; | |
private static DataType _innerKeyType = DataTypes.StringType; | |
private static DataType _outerKeyType = DataTypes.StringType; | |
private static DataType _innerMap = DataTypes.createMapType(_innerKeyType, _valueType); | |
private static DataType _outerMap = DataTypes.createMapType(_outerKeyType, _innerMap); | |
public MergeMapUDAF() { | |
List<StructField> inputFields = new ArrayList<>(); | |
inputFields.add(DataTypes.createStructField("key", _outerKeyType, true)); | |
inputFields.add(DataTypes.createStructField("values", _innerMap, true)); | |
_inputDataType = DataTypes.createStructType(inputFields); | |
List<StructField> bufferFields = new ArrayList<>(); | |
bufferFields.add(DataTypes.createStructField("data", _outerMap, true)); | |
_bufferSchema = DataTypes.createStructType(bufferFields); | |
_returnDataType = _outerMap; | |
} | |
@Override | |
public StructType inputSchema() { | |
return _inputDataType; | |
} | |
@Override | |
public StructType bufferSchema() { | |
return _bufferSchema; | |
} | |
@Override | |
public DataType dataType() { | |
return _returnDataType; | |
} | |
@Override | |
public boolean deterministic() { | |
return false; | |
} | |
@Override | |
public void initialize(MutableAggregationBuffer buffer) { | |
buffer.update(0, new HashMap<String, Map<String, Long>>()); | |
} | |
@Override | |
public void update(MutableAggregationBuffer buffer, Row input) { | |
if (!input.isNullAt(0)) { | |
String inputKey = input.getString(0); | |
Map<String, Long> inputValues = input.<String, Long>getJavaMap(1); | |
Map<String, Map<String, Long>> newData = new HashMap<>(); | |
if (!buffer.isNullAt(0)) { | |
Map<String, Map<String, Long>> currData = buffer.<String, Map<String, Long>>getJavaMap(0); | |
newData.putAll(currData); | |
} | |
newData.put(inputKey, inputValues); | |
buffer.update(0, newData); | |
} | |
} | |
@Override | |
public void merge(MutableAggregationBuffer buffer1, Row buffer2) { | |
Map<String, Map<String, Long>> data1 = buffer1.<String, Map<String, Long>>getJavaMap(0); | |
Map<String, Map<String, Long>> data2 = buffer2.<String, Map<String, Long>>getJavaMap(0); | |
Map<String, Map<String, Long>> newData = new HashMap<>(); | |
newData.putAll(data1); | |
newData.putAll(data2); | |
buffer1.update(0, newData); | |
} | |
@Override | |
public Object evaluate(Row buffer) { | |
return buffer.<String, Map<String, Long>>getJavaMap(0); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment