import org.apache.spark.sql.catalyst.expressions.{Add, If, Literal}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, Dataset, Row}
import org.apache.spark.sql.SaveMode._
// auto-imported in spark-shell; required elsewhere for MemoryStream's tuple encoder
import spark.implicits._
// initialize `employee_country` table with CDC enabled
val employeeCountryTablePath = "/tmp/hudi/employee_country"
val employeeCountrySchema = StructType(Seq(
  StructField("employeeId", IntegerType),
  StructField("country", StringType),
  StructField("ts", StringType)
))
spark.createDataFrame(spark.sparkContext.emptyRDD[Row], employeeCountrySchema).
  write.format("hudi").
  option("hoodie.datasource.write.recordkey.field", "employeeId").
  option("hoodie.datasource.write.precombine.field", "ts").
  option("hoodie.table.name", "employee_country").
  option("hoodie.table.cdc.enabled", "true").
  mode(Overwrite).
  save(employeeCountryTablePath)
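// With "hoodie.table.cdc.enabled" set, Hudi persists change data (including the
// before/after images of each mutated record) alongside every commit; the
// CDC-format incremental query further below consumes exactly that data.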
// simulate input stream and write to `employee_country`
val inputStream = new MemoryStream[(Int, String, String)](100, spark.sqlContext)
inputStream.toDS().toDF("employeeId", "country", "ts").
  writeStream.format("hudi").
  foreachBatch { (batch: Dataset[Row], _: Long) =>
    batch.write.format("hudi").
      option("hoodie.datasource.write.recordkey.field", "employeeId").
      option("hoodie.datasource.write.precombine.field", "ts").
      mode(Append).
      save(employeeCountryTablePath)
  }.start()
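// Note: the Hudi write operation defaults to upsert, so a later record with an
// existing employeeId (e.g. employee 4 moving from US to SG in the simulated
// input below) updates the row in place and surfaces as an update in the CDC feed.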
// initialize `country_headcount` table
val countryHeadcountTablePath = "/tmp/hudi/country_headcount"
val countryHeadcountSchema = StructType(Seq(
  StructField("country", StringType),
  StructField("headcount", IntegerType),
  StructField("ts", StringType)
))
spark.createDataFrame(spark.sparkContext.emptyRDD[Row], countryHeadcountSchema).
  write.format("hudi").
  option("hoodie.datasource.write.recordkey.field", "country").
  option("hoodie.datasource.write.precombine.field", "ts").
  option("hoodie.table.name", "country_headcount").
  mode(Overwrite).
  save(countryHeadcountTablePath)
// create a CDC processing stream to aggregate the changed data and update `country_headcount` table
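// Each CDC record carries an `op` field ("i" / "u" / "d") plus JSON-encoded
// `before` and `after` images of the row: inserts have a null `before`, deletes
// a null `after`, and updates carry both. The batch logic below turns those
// images into -1/+1 headcount deltas per country.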
spark.readStream.format("hudi").
  option("hoodie.datasource.query.type", "incremental").
  option("hoodie.datasource.query.incremental.format", "cdc").
  load(employeeCountryTablePath).
  writeStream.format("hudi").
  foreachBatch { (batch: Dataset[Row], _: Long) =>
    val current = spark.read.format("hudi").load(countryHeadcountTablePath)
    batch.select(
        // extract the country from the `before` and `after` images
        get_json_object(col("before"), "$.country").as("bf_country"),
        get_json_object(col("after"), "$.country").as("af_country"),
        get_json_object(col("after"), "$.ts").as("ts")
      ).
      // if the record has a `before` image, count -1 against that country's headcount
      withColumn("bf_ct", new Column(If(isnull(col("bf_country")).expr, typedLit(0).expr, typedLit(-1).expr))).
      // if the record has an `after` image, count +1 towards that country's headcount
      withColumn("af_ct", new Column(If(isnull(col("af_country")).expr, typedLit(0).expr, typedLit(1).expr))).
      select(explode(array(Array(
        struct(col("bf_country").as("country"), col("bf_ct").as("ct"), col("ts")),
        struct(col("af_country").as("country"), col("af_ct").as("ct"), col("ts"))): _*))).
      select(col("col.country").as("country"), col("col.ct").as("ct"), col("col.ts").as("ts")).
      where("country is not null").
      groupBy("country").
      agg("ct" -> "sum", "ts" -> "max").
      // add the per-batch deltas to the current headcount values
      join(current, Seq("country"), "left").
      select(
        col("country"),
        new Column(Add(col("sum(ct)").expr, If(isnull(col("headcount")).expr, Literal(0), col("headcount").expr))).as("headcount"),
        col("max(ts)").as("ts")
      ).
      write.format("hudi").
      option("hoodie.datasource.write.recordkey.field", "country").
      option("hoodie.datasource.write.precombine.field", "ts").
      mode(Append).
      save(countryHeadcountTablePath)
  }.start()
// simulate input data
inputStream.addData(Seq((1, "US", "1000"), (2, "IN", "1000"), (3, "CN", "1000")))
inputStream.addData(Seq((4, "US", "1100"), (5, "US", "1100"), (6, "IN", "1100"), (7, "CN", "1100")))
inputStream.addData(Seq((4, "SG", "1200")))
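// The two streams run asynchronously; to block until all pending input has been
// processed before querying (useful in tests), one could uncomment:
// spark.streams.active.foreach(_.processAllAvailable())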
// read the latest country_headcount data
spark.read.format("hudi").load(countryHeadcountTablePath).
select("country", "headcount", "ts").show(false)
/*
+-------+---------+----+
|country|headcount|ts |
+-------+---------+----+
|US |2 |1200|
|IN |2 |1100|
|CN |2 |1100|
|SG |1 |1200|
+-------+---------+----+
Note that the result may not be immediately up-to-date, since the two streams process micro-batches asynchronously.
*/
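// For ad-hoc inspection, the same change feed can also be read as a one-off batch
// query; a sketch (the begin instant "000" is an assumption meaning "from the
// earliest commit", a common convention in Hudi incremental queries):
spark.read.format("hudi").
  option("hoodie.datasource.query.type", "incremental").
  option("hoodie.datasource.query.incremental.format", "cdc").
  option("hoodie.datasource.read.begin.instanttime", "000").
  load(employeeCountryTablePath).
  show(false)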