Spark: create a unique sequential id per partition
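The file below compares three ways to generate a sequential id within each Spark partition: a row_number window function, a mapPartitions pass with zipWithIndex, and a custom stateful Catalyst expression adapted from MonotonicallyIncreasingID.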
import org.apache.spark.sql.Column
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.LeafExpression
import org.apache.spark.sql.catalyst.expressions.Stateful
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode
import org.apache.spark.sql.catalyst.expressions.codegen.FalseLiteral
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.LongType
object SparkSQLGeneratePartitionOffset {

  def main(args: Array[String]): Unit = {
    implicit val spark: SparkSession = SparkSession.builder().master("local").getOrCreate()

    // 50 rows spread over 5 partitions; replace the sequential "id" with a random value to sort by.
    val randomDF = spark
      .range(0, 50, 1, 5)
      .withColumn("id", rand(50) * 10)
spark.sparkContext.setJobDescription("Global sort")
randomDF
.withColumn("part_id", spark_partition_id())
.withColumn("generated_id", row_number().over(Window.partitionBy("part_id").orderBy("id")))
.show(100)
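
    // Approach 2: mapPartitions with zipWithIndex. No shuffle, but every row is
    // decoded to an external Row and re-encoded, which has a serialization cost.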
spark.sparkContext.setJobDescription("custom zipWithIndex")
val df = randomDF.withColumn("part_id", spark_partition_id())
df
.mapPartitions { it =>
it
.zipWithIndex
.map { case (r, index) => Row.fromSeq(r.toSeq :+ index.toLong) }
}(RowEncoder(df.schema.add("generated_id", LongType)))
.show(100)
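
    // Approach 3: a custom stateful Catalyst expression (defined below), evaluated
    // natively per row with no shuffle and no encoder round-trip.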
spark.sparkContext.setJobDescription("custom sql function")
randomDF
.withColumn("part_id", spark_partition_id())
.withColumn("generated_id", new Column(SparkPartitionOffset()))
.show(100)
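
    // Block the driver, presumably so the Spark UI stays reachable for inspection.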
    Thread.sleep(1000000000)
  }

  case class SparkPartitionOffset() extends LeafExpression with Stateful {

    /**
     * Adapted from org.apache.spark.sql.catalyst.expressions.MonotonicallyIncreasingID.
     *
     * Record ID within each partition. By being transient, count's value is reset to 0 every time
     * we serialize and deserialize and initialize it.
     */
    @transient private[this] var count: Long = _

    // The counter restarts at 0 for every partition.
    override protected def initializeInternal(partitionIndex: Int): Unit = count = 0L

    override def nullable: Boolean = false

    override def dataType: DataType = LongType

    // Interpreted path: return the current count, then advance it.
    override protected def evalInternal(input: InternalRow): Long = {
      val currentCount = count
      count += 1
      currentCount
    }
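
    // Codegen path: mirrors evalInternal with a generated mutable long that is
    // zeroed in the partition-initialization block and incremented once per row.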
    override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
      val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count")
      ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
      ev.copy(
        code = code"""
          final ${CodeGenerator.javaType(dataType)} ${ev.value} = $countTerm;
          $countTerm++;""",
        isNull = FalseLiteral
      )
    }

    override def prettyName: String = "spark_partition_offset"

    override def sql: String = s"$prettyName()"
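
    // Stateful expressions must return a fresh, uninitialized copy here so that
    // mutable state is never shared between evaluation instances.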
    override def freshCopy(): SparkPartitionOffset = SparkPartitionOffset()
  }
}
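
A possible follow-up, not part of the original gist: if a globally unique id is needed rather than a per-partition offset, the partition id and the offset can be packed into a single long using the same bit layout as Spark's built-in monotonically_increasing_id (upper 31 bits for the partition id, lower 33 bits for the record number within the partition). The sketch below assumes SparkPartitionOffset from above is on the classpath; the GlobalIdExample object and the global_id column name are illustrative.

import org.apache.spark.sql.Column
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import SparkSQLGeneratePartitionOffset.SparkPartitionOffset

object GlobalIdExample {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").getOrCreate()

    spark
      .range(0, 50, 1, 5)
      .withColumn("part_id", spark_partition_id())
      .withColumn("generated_id", new Column(SparkPartitionOffset()))
      // part_id << 33 leaves the low 33 bits for the per-partition offset, so the
      // result stays unique as long as no partition holds more than 2^33 rows.
      .withColumn("global_id", shiftLeft(col("part_id").cast("long"), 33) + col("generated_id"))
      .show(100)
  }
}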