package com.example

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{StructField, StructType}

import scala.collection.mutable
import scala.reflect.runtime.universe.{TypeTag, typeTag}
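
// Groups a Dataset[T] by an arbitrary key function and folds each group through a
// set of GroupingCalculations, producing a DataFrame with a "key" column followed
// by one column per calculation.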
case class CustomGroupBy[T, K: Encoder: TypeTag](data: Dataset[T], keyFunc: T => K) {
  def agg(aggCalcs: GroupingCalculation[T, _, _]*): DataFrame = {
    val keySchema = ScalaReflection.schemaFor[K]
    val keyField = StructField("key", keySchema.dataType, keySchema.nullable)
    val outputFields = aggCalcs.map { calc =>
      StructField(calc.colName, calc.resultSchema.dataType, calc.resultSchema.nullable)
    }
    val combinedStruct = StructType(keyField +: outputFields)

    println("DEBUGGING:") // TODO: REMOVE
    println(s"type tag of key: ${typeTag[K]}")
    println(s"keyField: ${keyField}")
    println(s"outputFields: ${outputFields}")
    println(s"combinedStruct: ${combinedStruct}")

    val rowEnc: Encoder[Row] = RowEncoder(combinedStruct)
    data
      .groupByKey(keyFunc)
      .mapGroups { (k: K, ts: Iterator[T]) =>
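        // One accumulator slot per calculation; each slot is folded over the
        // group's rows via doUpdate and finalized with doEvaluate below.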
        val state = mutable.ArrayBuffer[Any](aggCalcs.map(_.init()): _*)
        for (t <- ts) {
          for ((calc, i) <- aggCalcs.zipWithIndex) {
            state(i) = calc.doUpdate(t, state(i))
          }
        }
        val results = mutable.ArrayBuffer[Any]()
        results.append(k)
        aggCalcs.zipWithIndex.foreach { case (calc, i) =>
          val res = calc.doEvaluate(state(i))
          results.append(res)
        }
        Row(results: _*)
      }(rowEnc)
  }
}
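
// A serializable fold over the rows of a group: `init` seeds state S, `update`
// folds each row T into it, and `evaluate` turns the final state into result R.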
class GroupingCalculation[T, S, R: TypeTag: Encoder] private[example](
    private[example] val colName: String,
    private[example] val init: () => S,
    private[example] val update: (T, S) => S,
    private[example] val evaluate: S => R) extends Serializable {

  def as(newName: String): GroupingCalculation[T, S, R] = {
    new GroupingCalculation[T, S, R](newName, init, update, evaluate)
  }

  val resultSchema: ScalaReflection.Schema = ScalaReflection.schemaFor[R]

  // private val resultRowConverter: InternalRow => Row = CatalystTypeConverters.createToScalaConverter(resultSchema.dataType).asInstanceOf[InternalRow => Row]
  // private val resultExprEnc: ExpressionEncoder[R] = encoderFor[R].resolveAndBind()

  def doUpdate(t: T, x: Any): S = update(t, x.asInstanceOf[S])

  def doEvaluate(x: Any): R = {
    val result: R = evaluate(x.asInstanceOf[S])
    result
    //val resultEncoded: Row = resultRowConverter(resultExprEnc.toRow(result))
    //resultEncoded
  }
}

object GroupingCalculation {
  def apply[T, S, R: TypeTag: Encoder](
      colName: String,
      init: () => S,
      update: (T, S) => S,
      evaluate: S => R
  ): GroupingCalculation[T, S, R] = {
    new GroupingCalculation(colName, init, update, evaluate)
  }
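
  // Convenience overload for commutative folds. Note that `merge` is currently
  // unused: mapGroups sees every row of a group on a single task, so no
  // partial-aggregate merging ever happens.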
  def apply[T, S: TypeTag: Encoder](
      colName: String,
      init: () => S,
      update: (T, S) => S,
      merge: (S, S) => S
  ): GroupingCalculation[T, S, S] = {
    apply(colName, init, update, evaluate = identity[S])
  }
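
  // Asserts that every row in a group carries the same value for `getter` and
  // returns it (None if the value is always null); throws on a mismatch.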
  def firstAndOnly[T, S: TypeTag]
      (colName: String, getter: T => S)
      (implicit enc: Encoder[Option[S]]): GroupingCalculation[T, Option[S], Option[S]] = {
    new GroupingCalculation[T, Option[S], Option[S]](
      colName = colName,
      init = () => Option.empty[S],
      update = { (item: T, accum: Option[S]) =>
        val current = getter(item)
        accum match {
          case None if current == null => None
          case None => Some(current)
          case Some(x) if x == current => accum
          case _ => throw new RuntimeException(s"Mismatch found in firstAndOnly update for ${colName}: ${current} versus ${accum}")
        }
      },
      evaluate = identity[Option[S]]
    )
  }
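
  // Builds a histogram of `byGetter` values per group, then reduces the
  // resulting Map[G, Long] of counts with the supplied `evaluate` function.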
  def fromCountsBy[T, G, R: TypeTag: Encoder]
      (colName: String, byGetter: T => G)
      (evaluate: Map[G, Long] => R): GroupingCalculation[T, Map[G, Long], R] = {
    GroupingCalculation.apply[T, Map[G, Long], R](
      colName = colName,
      init = () => Map.empty[G, Long],
      update = (item: T, accum: Map[G, Long]) => {
        val by = byGetter(item)
        accum + ((by, accum.getOrElse(by, 0L) + 1L))
      },
      evaluate = evaluate
    )
  }

  def count[T](colName: String): GroupingCalculation[T, Long, Long] = {
    GroupingCalculation[T, Long, Long](
      colName,
      init = () => 0L,
      update = (_: T, accum: Long) => 1 + accum,
      evaluate = identity[Long]
    )(typeTag[Long], Encoders.scalaLong)
  }

  def countDistinct[T, S](colName: String, getter: T => S): GroupingCalculation[T, Set[S], Long] = {
    GroupingCalculation[T, Set[S], Long](
      colName = colName,
      init = () => Set[S](),
      update = (item: T, accum: Set[S]) => accum + getter(item),
      evaluate = (accum: Set[S]) => accum.size.toLong
    )(typeTag[Long], Encoders.scalaLong)
  }

  implicit class CustomGroupByHolder[T](data: Dataset[T]) {
    def customGroupBy[K: Encoder: TypeTag](keyFunc: T => K): CustomGroupBy[T, K] = {
      CustomGroupBy(data, keyFunc)
    }
  }
}
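
For reference, a minimal usage sketch. The `Sale` type, its values, and the ambient `spark` session are hypothetical, assumed only for illustration:

case class Sale(store: String, item: String, amount: Double) // hypothetical input type

import GroupingCalculation.CustomGroupByHolder
import spark.implicits._

val sales: Dataset[Sale] = Seq(
  Sale("s1", "apple", 1.5),
  Sale("s1", "pear", 2.0),
  Sale("s2", "apple", 1.5)
).toDS

// Produces a DataFrame with columns: key, n_sales, total, distinct_items.
val summary = sales
  .customGroupBy(_.store)
  .agg(
    GroupingCalculation.count[Sale]("n_sales"),
    GroupingCalculation[Sale, Double, Double]("total",
      init = () => 0.0,
      update = (s: Sale, acc: Double) => s.amount + acc,
      evaluate = identity[Double]
    ),
    GroupingCalculation.countDistinct[Sale, String]("distinct_items", _.item)
  )

The test suite below exercises these calculations end to end.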
package com.example

import com.holdenkarau.spark.testing.DatasetSuiteBase
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Dataset, Row}
import org.scalatest.FunSuite

case class TestCase1(a: Int, b: String, c: Long, d: Double, e: Array[Int], f: Set[Int], g: Map[Long, String])
case class TestCase2(a: TestCase3, b: TestCase3)
case class TestCase3(a: Int, b: Double)
case class TestCase4(key: (Int, String), test: TestCase3)
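
// End-to-end tests for CustomGroupBy/GroupingCalculation, run against a local
// single-threaded SparkSession provided by spark-testing-base's DatasetSuiteBase.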
class GroupingCalculationTest extends FunSuite with DatasetSuiteBase {
  override def conf: SparkConf = {
    new SparkConf().
      setMaster("local[1]").
      setAppName("test").
      set("spark.ui.enabled", "false").
      set("spark.app.id", appID).
      set("spark.driver.host", "localhost")
  }

  import GroupingCalculation.CustomGroupByHolder
  import spark.implicits._

  lazy val data: Dataset[TestCase1] = Seq(
    TestCase1(1, "b1", 1L, 11.0, Array(), Set(), Map(1L -> "x")),
    TestCase1(1, "b1", 1L, 21.0, Array(0), Set(), Map(1L -> "x")),
    TestCase1(1, "b1", 1L, 21.0, Array(0), Set(), Map(1L -> "x")),
    TestCase1(2, "b2", 1L, 31.0, Array(1), Set(), Map(1L -> "x")),
    TestCase1(2, "b3", 1L, 41.0, Array(1), Set(), Map(1L -> "x"))
  ).toDS
test("Test basic grouping calculations") { | |
val calcs: Seq[GroupingCalculation[TestCase1, _, _]] = Seq( | |
GroupingCalculation("sum", | |
init = () => 0, | |
update = (x: TestCase1, accum: Int) => x.a + accum, | |
merge = (a1: Int, a2: Int) => a1 + a2 | |
), | |
GroupingCalculation("product", | |
init = () => 1.0, | |
update = (x: TestCase1, accum: Double) => x.d * accum, | |
merge = (a1: Double, a2: Double) => a1 * a2 | |
), | |
GroupingCalculation("concat", | |
init = () => "", | |
update = (x: TestCase1, accum: String) => x.b + accum, | |
merge = (a1: String, a2: String) => a1 + a2 | |
), | |
GroupingCalculation("count", | |
init = () => 0L, | |
update = (_: TestCase1, accum: Long) => 1 + accum, | |
merge = (a1: Long, a2: Long) => a1 + a2 | |
) | |
) | |
val result_df = data | |
.customGroupBy(x => (x.a, x.b)) | |
.agg(calcs: _*) | |
.sort('key("_1"), 'key("_2")) | |
println("DEBUGGING:") // TODO: REMOVE | |
result_df.explain() | |
val result = result_df.collect | |
val expected = Array( | |
Row(Row(1, "b1"), 3L, 11.0 * 21.0 * 21.0, "b1b1b1", 3), | |
Row(Row(2, "b2"), 2L, 31.0, "b2", 1), | |
Row(Row(2, "b3"), 2L, 41.0, "b3", 1) | |
    )
    assert(result_df.columns === Seq("key", "sum", "product", "concat", "count"))
    assert(result === expected)
  }
test("Test collection-based grouping calculations") { | |
val calcs: Seq[GroupingCalculation[TestCase1, _, _]] = Seq( | |
GroupingCalculation("count_distinct_d_via_seq", | |
init = () => Seq[Double](), | |
update = (x: TestCase1, accum: Seq[Double]) => accum :+ x.d, | |
evaluate = (accum: Seq[Double]) => accum.toSet.size | |
), | |
GroupingCalculation("count_distinct_d_via_map", | |
init = () => Map[Double, Int](), | |
update = (x: TestCase1, accum: Map[Double, Int]) => accum + ((x.d, 1)), | |
evaluate = (accum: Map[Double, Int]) => accum.size | |
), | |
GroupingCalculation("count_duplicates_d", | |
init = () => Map[Double, Int](), | |
update = (x: TestCase1, accum: Map[Double, Int]) => accum + ((x.d, accum.getOrElse(x.d, 0) + 1)), | |
evaluate = (accum: Map[Double, Int]) => accum | |
), | |
GroupingCalculation("collect_set_d", | |
init = () => Map[Double, Int](), | |
update = (x: TestCase1, accum: Map[Double, Int]) => accum + ((x.d, 1)), | |
evaluate = (accum: Map[Double, Int]) => accum.keys.toSeq.sorted | |
), | |
GroupingCalculation("map_with_string", | |
init = () => Map[String, Int](), | |
update = (x: TestCase1, accum: Map[String, Int]) => { | |
accum ++ x.g.values.map(y => (y, 1)) | |
}, | |
evaluate = (accum: Map[String, Int]) => accum.size | |
) | |
) | |
val result_df = data | |
.customGroupBy(x => (x.a, x.b)) | |
.agg(calcs: _*) | |
.sort('key("_1"), 'key("_2")) | |
println("DEBUGGING:") // TODO: REMOVE | |
result_df.explain() | |
val result = result_df.collect | |
val expected = Array( | |
Row(Row(1, "b1"), 2, 2, Map(11.0 -> 1, 21.0 -> 2), List(11.0, 21.0), 1), | |
Row(Row(2, "b2"), 1, 1, Map(31.0 -> 1), List(31.0), 1), | |
Row(Row(2, "b3"), 1, 1, Map(41.0 -> 1), List(41.0), 1) | |
) | |
assert(result === expected) | |
} | |
test("Test grouping calculation with maps and options") { | |
val result_df = data | |
.customGroupBy(x => (x.a, x.b)) | |
.agg( | |
GroupingCalculation("map_with_option", | |
init = () => Map[Option[String], Int](), | |
update = (_: TestCase1, accum: Map[Option[String], Int]) => { | |
accum + (None -> 1) | |
}, | |
evaluate = (accum: Map[Option[String], Int]) => accum.size | |
) | |
) | |
.sort('key("_1"), 'key("_2")) | |
println("DEBUGGING:") // TODO: REMOVE | |
result_df.explain() | |
val result = result_df.collect | |
val expected = Array( | |
Row(Row(1, "b1"), 1), | |
Row(Row(2, "b2"), 1), | |
Row(Row(2, "b3"), 1) | |
) | |
assert(result === expected) | |
} | |
test("Test grouping calculations with various combinations of case classes") { | |
val result_df = data | |
.customGroupBy(x => (x.a, x.b)) | |
.agg( | |
GroupingCalculation[TestCase1, TestCase2, TestCase3]("test", | |
init = () => TestCase2(TestCase3(0, 0.0), TestCase3(0, 0.0)), | |
update = (item: TestCase1, accum: TestCase2) => { | |
val TestCase2(TestCase3(x1, x2), TestCase3(x3, x4)) = accum | |
TestCase2( | |
TestCase3(1 + x1, item.d + x2), | |
TestCase3(3 + x3, -item.d + x4) | |
) | |
}, | |
evaluate = (accum: TestCase2) => accum.a | |
) | |
) | |
.sort('key("_1"), 'key("_2")) | |
.as[TestCase4] | |
println("DEBUGGING:") // TODO: REMOVE | |
result_df.explain() | |
val result = result_df.collect | |
val expected = Array( | |
TestCase4((1, "b1"), TestCase3(3, 53.0)), | |
TestCase4((2, "b2"), TestCase3(1, 31.0)), | |
TestCase4((2, "b3"), TestCase3(1, 41.0)) | |
) | |
assert(result === expected) | |
} | |
test("Test firstAndOnly") { | |
val result_df = data | |
.customGroupBy(x => (x.a, x.b)) | |
.agg( | |
GroupingCalculation.firstAndOnly[TestCase1, String]("B", _.b), | |
GroupingCalculation.firstAndOnly[TestCase1, Long]("C", _.c) | |
) | |
.sort('key("_1"), 'key("_2")) | |
println("DEBUGGING:") // TODO: REMOVE | |
result_df.explain() | |
val result = result_df.collect | |
val expected = Array( | |
Row(Row(1, "b1"), "b1", 1L), | |
Row(Row(2, "b2"), "b2", 1L), | |
Row(Row(2, "b3"), "b3", 1L) | |
) | |
assert(result === expected) | |
} | |
test("Test fromCountsBy") { | |
val result_df = data | |
.customGroupBy(x => (x.a, x.b)) | |
.agg( | |
GroupingCalculation.fromCountsBy[TestCase1, Double, Long]("test1", _.d) { x: Map[Double, Long] => x.size }, | |
GroupingCalculation.fromCountsBy[TestCase1, Double, Map[Double, Long]]("test2", _.d)(identity[Map[Double, Long]]) | |
) | |
.sort('key("_1"), 'key("_2")) | |
println("DEBUGGING:") // TODO: REMOVE | |
result_df.explain() | |
val result = result_df.collect | |
val expected = Array( | |
Row(Row(1, "b1"), 2L, Map(11.0 -> 1L, 21.0 -> 2L)), | |
Row(Row(2, "b2"), 1L, Map(31.0 -> 1L)), | |
Row(Row(2, "b3"), 1L, Map(41.0 -> 1L)) | |
) | |
assert(result === expected) | |
} | |
} |