Skip to content

Instantly share code, notes, and snippets.

@Leonti
Last active April 2, 2019 22:29
Show Gist options
  • Save Leonti/e3af8472ac92dfe7cdbd689c11ecd03b to your computer and use it in GitHub Desktop.
Save Leonti/e3af8472ac92dfe7cdbd689c11ecd03b to your computer and use it in GitHub Desktop.
Generate case classes for Spark DataFrame from a schema
/**
 * Generates Scala case-class source code from a Spark SQL `StructType` schema.
 *
 * The schema is first converted into a small intermediate tree ([[SchemaToCaseClass.Field]]),
 * which is then flattened into a list of [[SchemaToCaseClass.PrintClass]] descriptions and
 * finally rendered to text. Nullable fields, nullable array elements and nullable map values
 * are rendered as `Option[...]`.
 */
object SchemaToCaseClass {
  import org.apache.spark.sql.types._

  /** Intermediate representation of a single schema field. Sealed so matches are checked. */
  sealed trait Field

  /** A field whose type is rendered directly as the type name `t`. */
  case class FlatField(name: String, t: String, isNullable: Boolean) extends Field

  /** A struct field; it becomes its own case class named after the capitalized field name. */
  case class CaseClass(name: String, fields: List[Field], isNullable: Boolean) extends Field

  /**
   * An array field. `field` describes the element (its `isNullable` is the array's
   * `containsNull`), while `isNullable` here is the nullability of the array field itself.
   */
  case class ArrayClass(name: String, field: Field, isNullable: Boolean) extends Field

  /** One constructor parameter of a generated case class. */
  case class PrintField(name: String, typeName: String, isOptional: Boolean)

  /** One generated case class: its name and its constructor parameters. */
  case class PrintClass(name: String, fields: List[PrintField])

  /**
   * Prints the case classes generated for `schema` to standard output, framed by
   * separator lines.
   *
   * @param schema    the Spark schema to convert
   * @param className name of the top-level case class (capitalized on output)
   */
  def schemaToCaseClass(schema: StructType, className: String): Unit = {
    val topLevelFields = schema.fields.toList.map(f => tc(f.name, f.dataType, f.nullable))
    // toPrintClasses puts the enclosing class first; reversing prints nested classes before it.
    val result = toPrintClasses(Some(className), topLevelFields).reverse.map(printClass).mkString("")
    println("============================")
    println(result)
    println("============================")
  }

  /**
   * Renders one case class definition as source text.
   *
   * @param toPrint the class description to render
   * @return the definition, starting and ending with a newline
   */
  def printClass(toPrint: PrintClass): String = {
    val fields = toPrint.fields
      .map(f => s"${f.name}: ${optionalIf(f.isOptional, f.typeName)}")
      .mkString(",\n ")
    s"""
       |case class ${toPrint.name}(
       | $fields
       |)
       |""".stripMargin
  }

  /**
   * Flattens a list of fields into the case classes that must be generated for them.
   *
   * @param className when defined, a class with this (capitalized) name holding `fields`
   *                  is placed at the head of the result; when empty, only the classes
   *                  required by nested structs are returned
   * @param fields    the fields of the class being described
   * @return the enclosing class (if requested) followed by all nested classes
   */
  def toPrintClasses(className: Option[String], fields: List[Field]): List[PrintClass] = {
    val rendered: List[(PrintField, List[PrintClass])] = fields.map {
      case FlatField(name, t, nullable) =>
        (PrintField(name, t, nullable), List.empty[PrintClass])
      case CaseClass(name, nestedFields, nullable) =>
        val nestedName = name.capitalize
        (PrintField(name, nestedName, nullable), toPrintClasses(Some(nestedName), nestedFields))
      case ArrayClass(name, element, nullable) =>
        // The field is optional only if the array itself is nullable; element
        // nullability is expressed inside the Seq by elementType.
        val (elementTypeName, elementClasses) = elementType(element)
        (PrintField(name, s"Seq[$elementTypeName]", nullable), elementClasses)
    }
    val fieldsToPrint = rendered.map(_._1)
    val classesToPrint = rendered.flatMap(_._2)
    className match {
      case Some(cn) => PrintClass(cn.capitalize, fieldsToPrint) :: classesToPrint
      case None     => classesToPrint
    }
  }

  /**
   * Computes the rendered type of an array element together with the case classes it
   * requires. Recurses through nested arrays, so an array of arrays yields `Seq[Seq[...]]`.
   * A nullable element is wrapped in `Option[...]`.
   */
  private def elementType(element: Field): (String, List[PrintClass]) = {
    val (bareType, nullable, classes): (String, Boolean, List[PrintClass]) = element match {
      case FlatField(_, t, n) =>
        (t, n, List.empty[PrintClass])
      case CaseClass(name, nestedFields, n) =>
        val nestedName = name.capitalize
        (nestedName, n, toPrintClasses(Some(nestedName), nestedFields))
      case ArrayClass(_, inner, n) =>
        val (innerType, innerClasses) = elementType(inner)
        (s"Seq[$innerType]", n, innerClasses)
    }
    (optionalIf(nullable, bareType), classes)
  }

  /** Wraps `typeName` in `Option[...]` when `nullable` is true. */
  private def optionalIf(nullable: Boolean, typeName: String): String =
    if (nullable) s"Option[$typeName]" else typeName

  /** Converts a Spark field (name, type, nullability) into the intermediate [[Field]] tree. */
  private def tc(name: String, dataType: DataType, isNullable: Boolean): Field = dataType match {
    case struct: StructType =>
      CaseClass(name, struct.fields.toList.map(f => tc(f.name, f.dataType, f.nullable)), isNullable)
    case array: ArrayType =>
      // The element reuses the field name so that a struct element gets a class named after it.
      ArrayClass(name, tc(name, array.elementType, array.containsNull), isNullable)
    case _ =>
      FlatField(name, dataTypeToString(dataType), isNullable)
  }

  /**
   * Maps a Spark data type to the Scala type name used in generated code.
   *
   * Top-level arrays and structs are handled by `tc`; the array and struct cases here are
   * reached only through map key/value types. Types without an explicit case fall back
   * to `String`.
   */
  private def dataTypeToString(dt: DataType): String = dt match {
    case _: ByteType      => "Byte"
    case _: ShortType     => "Short"
    case _: IntegerType   => "Int"
    case _: LongType      => "Long"
    case _: FloatType     => "Float"
    case _: DoubleType    => "Double"
    case _: DecimalType   => "java.math.BigDecimal"
    case _: StringType    => "String"
    case _: BinaryType    => "Array[Byte]"
    case _: BooleanType   => "Boolean"
    case _: TimestampType => "java.sql.Timestamp"
    case _: DateType      => "java.sql.Date"
    case a: ArrayType =>
      s"Seq[${optionalIf(a.containsNull, dataTypeToString(a.elementType))}]"
    case m: MapType =>
      val keyType = dataTypeToString(m.keyType)
      val valueType = optionalIf(m.valueContainsNull, dataTypeToString(m.valueType))
      s"scala.collection.Map[$keyType, $valueType]"
    case _: StructType => "org.apache.spark.sql.Row"
    case _             => "String"
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment