Utility functions for extending Spark Datasets for exploring partitions
package org.anish.spark.skew

import java.io.File

import org.apache.commons.io.FileUtils
import org.apache.spark.sql.{Dataset, SaveMode}

/**
  * A few utility functions for extending Spark Datasets for exploring partitions.
  * Created by anish on 22/10/17.
  */
object Utils {

  implicit class DataSetUtils[T](val dataset: Dataset[T]) {

    /**
      * Prints record counts per partition.
      */
    def showCountPerPartition(): Unit = {
      println(countPerPartition.map(x => s"${x._1} => ${x._2}").mkString("Idx => Cnt\n", "\n", ""))
    }

    /**
      * Prints the total number of partitions and records in the Dataset.
      * Counts records in each partition and prints the min, 25th, 50th, 75th percentile and max counts.
      */
    def showPartitionStats(extended: Boolean = false): Unit = {
      val numPartitions = countPerPartition.length
      val sortedCounts = countPerPartition.map(_._2).sorted

      def percentileIndex(p: Int) = math.ceil((numPartitions - 1) * (p / 100.0)).toInt

      println(s"Total Partitions -> $numPartitions\n" +
        s"Total Records -> ${sortedCounts.map(_.toLong).sum}\n" + // A single partition won't have more than Int.MaxValue records
        s"Percentiles -> Min \t| 25th \t| 50th \t| 75th \t| Max\n" +
        s"Percentiles -> ${sortedCounts(percentileIndex(0))} \t| " +
        s"${sortedCounts(percentileIndex(25))} \t| " +
        s"${sortedCounts(percentileIndex(50))} \t| ${sortedCounts(percentileIndex(75))} \t| " +
        s"${sortedCounts(percentileIndex(100))}")
      if (extended) showCountPerPartition()
    }

    /**
      * Counts the number of records per partition. Triggers an action.
      *
      * @return List of tuples of partition index and count of records
      */
    lazy val countPerPartition: List[(Int, Int)] = { // The Dataset is immutable, so we don't want to count it more than once
      dataset.rdd.mapPartitionsWithIndex { (index, iter) =>
        List((index, iter.size)).iterator
      }.collect.toList
    }

    // This is a crude measurement; the Spark UI gives a more accurate view of the time taken by each stage
    def timedSaveToDisk(operationName: String, tmpFilepath: String = s"data/tmp/${System.currentTimeMillis()}"): Unit = {
      def time[R](blockName: String)(block: => R): R = {
        val t0 = System.nanoTime()
        val result = block // call-by-name
        val timeElapsedNano = System.nanoTime() - t0
        println(s"Elapsed time for $blockName : $timeElapsedNano ns or ${timeElapsedNano / 1e6} ms")
        result
      }

      time(operationName) {
        dataset.write.mode(SaveMode.Overwrite)
          .save(tmpFilepath)
      }
      FileUtils.deleteDirectory(new File(tmpFilepath))
    }
  }
}
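
A minimal usage sketch (not part of the original gist), assuming a local SparkSession and that the file above is on the classpath; the object name PartitionStatsExample and the skewed key expression are illustrative assumptions:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

import org.anish.spark.skew.Utils._

object PartitionStatsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[4]")
      .appName("partition-stats-example")
      .getOrCreate()

    // Repartition by a low-cardinality key so only a few of the shuffle partitions hold data
    val skewed = spark.range(0, 1000000L)
      .withColumn("key", col("id") % 10)
      .repartition(col("key"))

    skewed.showPartitionStats()                 // min / 25th / 50th / 75th / max record counts per partition
    skewed.showCountPerPartition()              // one line per partition: index => count
    skewed.timedSaveToDisk("write skewed data") // writes to a temp path, prints elapsed time, then deletes it

    spark.stop()
  }
}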