Last active
June 27, 2019 09:08
-
-
Save qi-qi/3138ff1a7adbe673e7ddb93f147c64a9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
test("hello-rdd") {
  // Sample input: one row per (group, byte-range), ranges are half-open [from, to).
  val df = Seq(
    ("a", 0, 200),
    ("a", 1000, 2000),
    ("a", 150, 160),
    ("b", 0, 2),
    ("b", 2, 8),
    ("b", 5, 15),
    ("c", 5, 15),
    ("c", 5, 15),
    ("d", 0, 0)
  ).toDF("listen_group_id", "range_legit_from", "range_legit_to")
  df.show()
  // Results: count unique bytes per group by expanding each [from, to) range
  // into a BitSet and unioning the sets per key — the union deduplicates
  // overlapping/duplicate ranges, so size is the number of distinct bytes.
  df.select($"listen_group_id", $"range_legit_from", $"range_legit_to")
    .map(r => (r.getString(0), r.getInt(1), r.getInt(2)))
    .rdd
    .keyBy(x => x._1)
    // Empty range (e.g. "d": 0 until 0) yields an empty BitSet => size 0.
    .mapValues(x => BitSet(x._2 until x._3: _*))
    .reduceByKey(_ ++ _)
    .mapValues(_.size)
    .toDF("listen_group_id", "listen_group_bytes_sum_unique")
    .show()
  /**
   * Sample Input:
   * +---------------+----------------+--------------+
   * |listen_group_id|range_legit_from|range_legit_to|
   * +---------------+----------------+--------------+
   * |              a|               0|           200|
   * |              a|            1000|          2000|
   * |              a|             150|           160|
   * |              b|               0|             2|
   * |              b|               2|             8|
   * |              b|               5|            15|
   * |              c|               5|            15|
   * |              c|               5|            15|
   * |              d|               0|             0|
   * +---------------+----------------+--------------+
   *
   * Results:
   * +---------------+-----------------------------+
   * |listen_group_id|listen_group_bytes_sum_unique|
   * +---------------+-----------------------------+
   * |              a|                         1200|
   * |              b|                           15|
   * |              c|                           10|
   * |              d|                            0|
   * +---------------+-----------------------------+
   */
}
// There are other approaches, e.g. a UDAF over a window (a bit more efficient but less clean);
// the approach below using `udf + collect_list + struct` is cleaner.
test("hello-udf") {
  // These imports could be moved to the top of the file; kept local for demo purposes.
  import org.apache.spark.sql.functions.{struct, collect_list, udf}
  import org.apache.spark.sql.expressions.Window
  import scala.collection.BitSet
  // Sample input: one row per (id, group, byte-range), ranges are half-open [from, to).
  val df = Seq(
    (0, "a", 0, 200),
    (1, "a", 1000, 2000),
    (2, "a", 150, 160),
    (3, "b", 0, 10),
    (4, "b", 2, 8),
    (5, "b", 5, 15)
  ).toDF("id", "group", "from", "to")
  val window = Window.partitionBy($"group").orderBy($"id")
  // === udf: Algorithm ===
  // Expand each collected (from, to) struct into a BitSet of covered bytes,
  // union them (deduplicates overlapping ranges), and return the unique count.
  // foldLeft over BitSet.empty is total: unlike reduce, it does not throw if the
  // collected list were ever empty, and gives identical results otherwise.
  // Alternative using curly braces: data.map { case Row(from: Int, to: Int) => BitSet(from until to: _*) }
  val findUniqueBytesUDF = udf { data: Seq[Row] =>
    data
      .map(x => BitSet(x.getAs[Int]("from") until x.getAs[Int]("to"): _*))
      .foldLeft(BitSet.empty)(_ ++ _)
      .size
  }
  // Helpers: collect all (from, to) pairs visible in the frame.
  val ranges = collect_list(struct($"from", $"to"))
  // Running unique-byte count up to and including the current row (by id order).
  val uniqueBytesCumSum = findUniqueBytesUDF(ranges.over(window.rowsBetween(Window.unboundedPreceding, Window.currentRow)))
  // Unique-byte count over the whole group.
  val uniqueBytesTotalSum = findUniqueBytesUDF(ranges.over(window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))
  // Results
  df.withColumn("unique_bytes_cum_sum", uniqueBytesCumSum)
    .withColumn("unique_bytes_total_sum", uniqueBytesTotalSum)
    .show()
}
/**
 * Results:
 *
 * +---+-----+----+----+--------------------+----------------------+
 * | id|group|from|  to|unique_bytes_cum_sum|unique_bytes_total_sum|
 * +---+-----+----+----+--------------------+----------------------+
 * |  3|    b|   0|  10|                  10|                    15|
 * |  4|    b|   2|   8|                  10|                    15|
 * |  5|    b|   5|  15|                  15|                    15|
 * |  0|    a|   0| 200|                 200|                  1200|
 * |  1|    a|1000|2000|                1200|                  1200|
 * |  2|    a| 150| 160|                1200|                  1200|
 * +---+-----+----+----+--------------------+----------------------+
 */
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment