Skip to content

Instantly share code, notes, and snippets.

@wangjingke
Created April 24, 2018 17:39
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save wangjingke/7df2af851cab5d9e52d0fe439239389d to your computer and use it in GitHub Desktop.
Save wangjingke/7df2af851cab5d9e52d0fe439239389d to your computer and use it in GitHub Desktop.
Some helper functions for importing data from Hadoop into Spark.
/** Helpers for importing delimited text from Hadoop (HDFS) into Spark
  * DataFrames via `SQLContext`, and for bulk column manipulation.
  *
  * Extends `Serializable`, presumably so that RDD closures created in
  * `readHadoopToSparkDF` that reference this object can be shipped to executors.
  */
object sparkHelpers extends Serializable {

  /** Reads a delimited text file and converts it into a DataFrame.
    *
    * @param sc         SparkContext used to read `hdfs_path`
    * @param sqlContext SQLContext used to build the DataFrame
    * @param hdfs_path  path passed to `SparkContext.textFile`
    * @param schema     data type of each column, in column order; supported:
    *                   StringType, IntegerType, LongType, DoubleType, FloatType,
    *                   BooleanType, DateType, TimestampType
    * @param sep        field separator; `String.split` treats it as a regular
    *                   expression, so e.g. a pipe must be passed as "\\|"
    * @param cols       column names; when empty, the first line of the file is
    *                   used as the header and every line equal to it is removed
    * @return a DataFrame whose columns are all nullable. Header names and
    *         `schema` are paired positionally and should have the same length.
    *         Values that fail to parse (or an unsupported type in `schema`)
    *         raise an exception when the rows are evaluated by a Spark action.
    */
  def readHadoopToSparkDF(
      sc: org.apache.spark.SparkContext,
      sqlContext: org.apache.spark.sql.SQLContext,
      hdfs_path: String,
      schema: List[org.apache.spark.sql.types.DataType],
      sep: String = "\t",
      cols: Array[String] = Array()): org.apache.spark.sql.DataFrame = {
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types._

    val rdd = sc.textFile(hdfs_path)

    // When the header comes from the file, filter on the raw header *line*:
    // a String line never equals the split Array[String] header. Filtering
    // (instead of dropping one line) also removes a header repeated in each
    // part-file, at the cost of dropping data lines identical to the header.
    val (header, body) =
      if (cols.isEmpty) {
        val headerLine = rdd.first
        (headerLine.split(sep).map(_.trim), rdd.filter(_ != headerLine))
      } else {
        (cols, rdd)
      }

    val dfSchema = StructType(header.toList.zip(schema).map { case (name, dataType) =>
      StructField(name, dataType, nullable = true)
    })

    // Converts one raw field into the JVM value Spark expects for `dataType`.
    // Empty fields, and whitespace-only fields of non-string columns, become null.
    def parseField(dataType: DataType, raw: String): Any = {
      val value = raw.trim
      if (raw.isEmpty || (value.isEmpty && dataType != StringType)) null
      else dataType match {
        case StringType    => value
        case IntegerType   => value.toInt
        case LongType      => value.toLong
        case DoubleType    => value.toDouble
        case FloatType     => value.toFloat
        case BooleanType   => value.toBoolean
        case DateType      => java.sql.Date.valueOf(value)      // expects yyyy-[m]m-[d]d
        case TimestampType => java.sql.Timestamp.valueOf(value) // expects yyyy-[m]m-[d]d hh:mm:ss[.f...]
        case other =>
          throw new IllegalArgumentException(s"Unsupported column type: $other")
      }
    }

    // A negative split limit keeps trailing empty fields, so a line ending in
    // a separator still yields one value (null) per schema column instead of
    // a short Row that does not match the StructType.
    val rows = body.map { line =>
      Row.fromSeq(schema.zip(line.split(sep, -1)).map { case (t, f) => parseField(t, f) })
    }
    sqlContext.createDataFrame(rows, dfSchema)
  }

  /** Expands run-length type specs into a per-column type list, e.g.
    * `List(("Int", 2), ("String", 1))` gives `List(IntegerType, IntegerType, StringType)`.
    * Type names are those accepted by `typeMatch`.
    */
  def constructSchema(schema: List[(String, Int)]): List[org.apache.spark.sql.types.DataType] =
    schema.flatMap { case (typeName, count) => List.fill(count)(typeMatch(typeName)) }

  /** Returns `df` without the columns named in `cols`; names that are not
    * present are ignored. Spark 2.0+ offers `df.drop(colNames: _*)` directly;
    * this helper provides the same on Spark 1.x.
    */
  def dropColumns(df: org.apache.spark.sql.DataFrame, cols: List[String]): org.apache.spark.sql.DataFrame = {
    import org.apache.spark.sql.Column
    // `: _*` expands the array into the varargs parameter of `select`.
    df.select(df.columns.diff(cols).map(x => new Column(x)): _*)
  }

  /** Maps a short type name ("Int", "String", "Date", "Double", "Timestamp",
    * "Long", "Float", "Boolean") to the corresponding Spark SQL `DataType`.
    *
    * @throws IllegalArgumentException for any other name
    */
  def typeMatch(input: String): org.apache.spark.sql.types.DataType = {
    import org.apache.spark.sql.types._
    input match {
      case "Int"       => IntegerType
      case "String"    => StringType
      case "Date"      => DateType
      case "Double"    => DoubleType
      case "Timestamp" => TimestampType
      case "Long"      => LongType
      case "Float"     => FloatType
      case "Boolean"   => BooleanType
      case other =>
        throw new IllegalArgumentException(s"Unknown type name: $other")
    }
  }

  /** Casts column `col` of `df` to the type named by `newType` (see `typeMatch`). */
  def changeColType(df: org.apache.spark.sql.DataFrame, col: String, newType: String): org.apache.spark.sql.DataFrame =
    df.withColumn(col, df(col).cast(typeMatch(newType)))

  /*
  Earlier Either-based variant, kept for reference; the active version below
  enforces the same length check with `require`.

  def changeMulColType(df: org.apache.spark.sql.DataFrame, colName: List[String], newType: List[String]): Either[String, org.apache.spark.sql.DataFrame] = {
    // use either to handle errors, so need to use Right(x).right.get to retrieve the df
    if (newType.length > 1 && newType.length != colName.length) {
      Left("Column name and type have different lengths")
    } else {
      val types = if (newType.length == 1) List.fill(colName.length)(newType.head) else newType
      Right(
        colName.zip(types).foldLeft(df) {
          (table, zipped_col_type) => changeColType(table, zipped_col_type._1, zipped_col_type._2)
        }
      )
    }
  }
  */

  /** Casts each column in `colName` to the corresponding type name in
    * `newType`; a single-element `newType` is applied to every column.
    *
    * @throws IllegalArgumentException if `newType` has neither exactly one
    *         element nor one element per column, rather than silently ignoring
    *         the unmatched names/types
    */
  def changeMulColType(df: org.apache.spark.sql.DataFrame, colName: List[String], newType: List[String]): org.apache.spark.sql.DataFrame = {
    require(newType.length == 1 || newType.length == colName.length,
      s"Column name and type have different lengths: ${colName.length} names, ${newType.length} types")
    val types = if (newType.length == 1) List.fill(colName.length)(newType.head) else newType
    colName.zip(types).foldLeft(df) { case (table, (name, typeName)) =>
      changeColType(table, name, typeName)
    }
  }

  /** Casts every column whose type is `sourceType` to `newType` (both are
    * type names accepted by `typeMatch`).
    */
  def changeAllColType(df: org.apache.spark.sql.DataFrame, sourceType: String, newType: String): org.apache.spark.sql.DataFrame = {
    val fromType = typeMatch(sourceType)
    df.schema.filter(_.dataType == fromType).foldLeft(df) { (table, field) =>
      changeColType(table, field.name, newType)
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment