Utility functions for extending Spark Datasets for exploring partitions
package org.anish.spark.skew

import java.io.File

import org.apache.commons.io.FileUtils
import org.apache.spark.sql.{Dataset, SaveMode}

/**
  * A few utility functions for extending Spark Datasets to explore their partitions.
  * Created by anish on 22/10/17.
  */
object Utils {

  implicit class DataSetUtils[T](val dataset: Dataset[T]) {

    /**
      * Prints the record count of each partition, one partition per line.
      */
    def showCountPerPartition(): Unit = {
      println(countPerPartition.map(x => s"${x._1} => ${x._2}").mkString("Idx => Cnt\n", "\n", ""))
    }
    /**
      * Prints the total number of partitions and records in the Dataset, along with
      * the min, 25th, 50th and 75th percentile, and max record counts per partition.
      */
    def showPartitionStats(extended: Boolean = false): Unit = {
      val numPartitions = countPerPartition.length
      val sortedCounts = countPerPartition.map(_._2).sorted

      def percentileIndex(p: Int) = math.ceil((numPartitions - 1) * (p / 100.0)).toInt

      println(s"Total Partitions -> $numPartitions\n" +
        s"Total Records -> ${sortedCounts.map(_.toLong).sum}\n" + // a single partition can't hold more than Int.MaxValue records, but the total can, so sum as Long
        s"Percentiles -> Min \t| 25th \t| 50th \t| 75th \t| Max\n" +
        s"Percentiles -> ${sortedCounts(percentileIndex(0))} \t| " +
        s"${sortedCounts(percentileIndex(25))} \t| " +
        s"${sortedCounts(percentileIndex(50))} \t| ${sortedCounts(percentileIndex(75))} \t| " +
        s"${sortedCounts(percentileIndex(100))}")
      if (extended) showCountPerPartition()
    }
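    // Illustrative output of showPartitionStats for a hypothetical run over
    // 1,000,000 records spread evenly across 8 partitions (values made up
    // for illustration, not from a real run):
    //   Total Partitions -> 8
    //   Total Records -> 1000000
    //   Percentiles -> Min 	| 25th 	| 50th 	| 75th 	| Max
    //   Percentiles -> 125000 	| 125000 	| 125000 	| 125000 	| 125000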
    /**
      * Counts the number of records in each partition. Triggers an action.
      * Declared as a lazy val so the count runs at most once per wrapper instance;
      * because the Dataset is immutable, we don't want to count it multiple times.
      *
      * @return List of tuples of partition index and record count
      */
    lazy val countPerPartition: List[(Int, Int)] = {
      dataset.rdd.mapPartitionsWithIndex { (index, iter) =>
        List((index, iter.size)).iterator
      }.collect.toList
    }
    /**
      * Times a write of the Dataset to disk, then deletes the output. I know this
      * is silly for Spark; better to check the Spark UI for per-stage timings.
      */
    def timedSaveToDisk(operationName: String, tmpFilepath: String = s"data/tmp/${System.currentTimeMillis()}"): Unit = {
      // Local helper that runs a by-name block and prints the elapsed wall-clock time
      def time[R](blockName: String)(block: => R): R = {
        val t0 = System.nanoTime()
        val result = block // call-by-name
        val timeElapsedNano = System.nanoTime() - t0
        println(s"Elapsed time for $blockName : $timeElapsedNano ns or ${timeElapsedNano / 1e6} ms")
        result
      }

      time(operationName) {
        dataset.write.mode(SaveMode.Overwrite)
          .save(tmpFilepath)
      }
      FileUtils.deleteDirectory(new File(tmpFilepath))
    }
  }
}
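A minimal usage sketch (not part of the gist): assuming a live SparkSession named spark running locally, importing the implicit class makes the helpers available on any Dataset. The dataset, app name, and operation name below are hypothetical.

import org.apache.spark.sql.SparkSession
import org.anish.spark.skew.Utils._

// Hypothetical driver code for illustration only
val spark = SparkSession.builder.master("local[4]").appName("skew-demo").getOrCreate()

val ds = spark.range(0, 1000000).repartition(8) // an evenly distributed Dataset
ds.showPartitionStats(extended = true) // totals and percentiles, then per-partition counts
ds.timedSaveToDisk("write after repartition") // times a throwaway write, then deletes it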