@GrigorievNick
Last active August 4, 2021 09:49
Merge records in two dataframes by id columns
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
import org.scalatest.FunSuite
import java.sql.Timestamp
import java.time.LocalDateTime
import scala.collection.AbstractIterator

case class MergedRecord(id: Long, id2: Long, data: String, ts: Timestamp, hdl: String)
class MergeDataframe extends FunSuite {

  private val initialTimestamp = LocalDateTime.now().minusDays(1)

  implicit val sparkSession: SparkSession = SparkSession.builder()
    .config("spark.sql.adaptive.enabled", "true") // without it, the merge generates too many files
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .master("local")
    .getOrCreate()

  import sparkSession.implicits._
test("merge by key") {
val idCol = "id"
val idCol2 = "id2"
val leftData = (0 until 100).map(id => (id, id + 100, "data1", Timestamp.valueOf(initialTimestamp)))
val rightData = (0 until 100).map(id => (id, id + 100, "hdlData"))
val leftDf = leftData.toDF(idCol, idCol2, "data", "ts")
val rightDf = rightData.toDF(idCol, idCol2, "hdl")
val mergedRecords = leftDf
.mergeBy(rightDf, List(idCol, idCol2), Some(3))
.sort(idCol)
.as[MergedRecord]
mergedRecords.show(numRows = 100, truncate = false)
val expected = leftData.map(entry => MergedRecord(entry._1, entry._2, entry._3, entry._4, "hdlData")).toArray
val result = mergedRecords.collect()
result.zip(expected).foreach { case (l, r) => assert(l == r) }
}
  implicit class DataFrameUnionUtils(df: DataFrame) {

    private val sourceTypeColumn = "type"

    def schemaByName(colName: String): StructField = df.schema.fields(df.schema.fieldIndex(colName))

    def mergeBy(rightDf: DataFrame, keys: Seq[String], numPartitions: Option[Int] = None): DataFrame = {
      // TODO Implement the case when the two dataframes have a different number of rows:
      //  treat a missing row as null (see the sketch after this class).
      assert(
        rightDf.schema.fields.take(keys.size).sameElements(df.schema.fields.take(keys.size)),
        "key fields must have same type and order"
      )
      val keyFields = df.schema.fields.take(keys.size).toList
      val dfFields = df.schema.fields.drop(keys.size).toList
      val rightDfFields = rightDf.schema.fields.drop(keys.size).toList
      val mergedSchema = StructType(keyFields ::: dfFields ::: rightDfFields)
      // Add the columns a side is missing as typed nulls, then reorder to the merged schema.
      def alignSchema(df: DataFrame, missingFields: Seq[StructField]) =
        missingFields.foldLeft(df)((df, field) => df.withColumn(field.name, lit(null).cast(field.dataType)))
          .select(mergedSchema.map(f => col(f.name)): _*)
      // Tag each side so the two rows of a key sort deterministically: left (1) before right (2).
      val dfWithNullForRight = alignSchema(df, rightDfFields).withColumn(sourceTypeColumn, lit(1))
      val rightDfWithNullForDf = alignSchema(rightDf, dfFields).withColumn(sourceTypeColumn, lit(2))
      val dfColumnsByIndex = (keyFields ::: dfFields).map(f => dfWithNullForRight.schema.fieldIndex(f.name))
      val rightDfByIndex = rightDfFields.map(_.name).map(rightDfWithNullForDf.schema.fieldIndex)
      dfWithNullForRight
        .union(rightDfWithNullForDf)
        // Co-locate both rows of every key in the same partition, then order them left before right.
        .transform(t =>
          numPartitions
            .map(num => t.repartition(num, keys.map(col): _*))
            .getOrElse(t.repartition(keys.map(col): _*))
        ).sortWithinPartitions((keys.map(col) :+ col(sourceTypeColumn)): _*)
        .mapPartitions { it =>
          // Assumes exactly one row per key on each side, so consecutive pairs are (left, right).
          it
            .grouped(2)
            .map { case Seq(left, right) =>
              Row.fromSeq(dfColumnsByIndex.map(left.get) ::: rightDfByIndex.map(right.get))
            }
        }(RowEncoder(mergedSchema))
    }
  }
}
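
// Sketch (not part of the gist's tested code): mergeBy's TODO notes that grouped(2)
// breaks when a key is present on only one side. One way to lift that restriction is
// to slice each partition into key groups with the sliceBy helper below and merge each
// group with per-side Options. All names here (mergeKeyGroup, keySize, leftIdx,
// rightIdx, typeIdx) are hypothetical stand-ins for the values computed inside mergeBy.
object MergeMissingRowSketch {
  import org.apache.spark.sql.Row

  def mergeKeyGroup(group: Seq[Row], keySize: Int, leftIdx: Seq[Int], rightIdx: Seq[Int], typeIdx: Int): Row = {
    val left = group.find(_.getInt(typeIdx) == 1)  // row tagged as coming from the left df, if any
    val right = group.find(_.getInt(typeIdx) == 2) // row tagged as coming from the right df, if any
    val any = left.orElse(right).get               // key columns are identical within a group
    Row.fromSeq(
      (0 until keySize).map(any.get) ++
        leftIdx.map(i => left.map(_.get(i)).orNull) ++  // null when the left row is missing
        rightIdx.map(i => right.map(_.get(i)).orNull)   // null when the right row is missing
    )
  }

  // Inside mergeBy, the grouped(2) pipeline could then become:
  //   .mapPartitions { it =>
  //     it.sliceBy(row => (0 until keySize).map(row.get).toList)
  //       .map(group => mergeKeyGroup(group.toSeq, keySize, leftIdx, rightIdx, typeIdx))
  //   }(RowEncoder(mergedSchema))
}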
// Implicit classes cannot be top-level in Scala, so the helper lives in a wrapper object
// (the name IteratorUtils is hypothetical).
object IteratorUtils {

  implicit class SliceBySubsequence[T, K](it: Iterator[T]) extends Serializable {

    /**
     * @param key - The function used to extract the key from an iterator entry.
     * @return - An iterator of subsequences (iterators) of consecutive entries that share the same key.
     *
     * Note: Reuse: After calling this method, one should discard the iterator it was called on,
     * and use only the iterator that was returned. Using the old iterator is undefined, subject to change,
     * and may result in changes to the new iterator as well.
     */
    def sliceBy(key: T => K): Iterator[Iterator[T]] = new AbstractIterator[Iterator[T]] {
      private var bufferedIt = it.buffered

      def hasNext: Boolean = bufferedIt.hasNext

      def next(): Iterator[T] =
        bufferedIt.headOption match {
          case Some(hd) =>
            // Split off the run of entries whose key equals the head's key; keep the rest for later calls.
            val (subsequence, rest) = bufferedIt.span(r => key(r) == key(hd))
            bufferedIt = rest.buffered
            subsequence
          case None =>
            Iterator.empty
        }
    }
  }
}
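
// Hypothetical usage sketch for sliceBy (the object name SliceByExample is illustrative):
// split an iterator that is sorted by key into runs of consecutive equal keys.
object SliceByExample extends App {
  import IteratorUtils._

  val runs = Iterator(1, 1, 2, 2, 2, 3).sliceBy(identity).map(_.toList).toList
  println(runs) // List(List(1, 1), List(2, 2, 2), List(3))
}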