Last active
April 2, 2019 22:29
-
-
Save Leonti/e3af8472ac92dfe7cdbd689c11ecd03b to your computer and use it in GitHub Desktop.
Generate case classes for a Spark DataFrame from its schema
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
object SchemaToCaseClass {

  import org.apache.spark.sql.types._

  /** Intermediate tree describing one schema field before rendering. Sealed so
    * the compiler checks pattern-match exhaustiveness in `toPrintClasses`. */
  sealed trait Field
  final case class FlatField(name: String, t: String, isNullable: Boolean) extends Field
  final case class CaseClass(name: String, fields: List[Field], isNullable: Boolean) extends Field
  final case class ArrayClass(name: String, field: Field, isNullable: Boolean) extends Field

  /** Render-ready field: `typeName` is the final Scala type text, `isOptional`
    * controls wrapping in `Option[...]`. */
  final case class PrintField(name: String, typeName: String, isOptional: Boolean)

  /** Render-ready case class: a name plus its already-resolved fields. */
  final case class PrintClass(name: String, fields: List[PrintField])

  /** Prints Scala case-class source generated from `schema`, with the root
    * class named `className`. Nested structs become additional case classes.
    */
  def schemaToCaseClass(schema: StructType, className: String): Unit = {
    val fieldTree = schema.map(field => tc(field.name, field.dataType, field.nullable)).toList
    // Nested classes are collected after their parent; reverse so inner
    // classes are printed first, then the root class last.
    val result = toPrintClasses(Some(className), fieldTree).reverse.map(printClass).mkString("")
    println("============================")
    println(result)
    println("============================")
  }

  /** Renders one case class as source text; nullable fields are wrapped in `Option`. */
  def printClass(toPrint: PrintClass): String = {
    val fields = toPrint.fields.map { f =>
      if (f.isOptional)
        s"${f.name}: Option[${f.typeName}]"
      else
        s"${f.name}: ${f.typeName}"
    }.mkString(",\n ")
    s"""
|case class ${toPrint.name}(
| $fields
|)
""".stripMargin
  }

  /** Flattens the field tree into the list of classes to print.
    *
    * @param className `Some(name)` emits a class for these fields named `name`
    *                  (capitalized); `None` emits only the nested classes —
    *                  used for array elements that are themselves case classes.
    */
  def toPrintClasses(className: Option[String], fields: List[Field]): List[PrintClass] = {
    val result: List[(PrintField, List[PrintClass])] = fields.map {
      case FlatField(name, t, isOptional) =>
        (PrintField(name, t, isOptional), Nil)
      case CaseClass(name, nestedFields, isOptional) =>
        (PrintField(name, name.capitalize, isOptional), toPrintClasses(Some(name.capitalize), nestedFields))
      case ArrayClass(name, nestedField, isOptional) =>
        nestedField match {
          // BUG FIX: the original pattern `case FlatField(name, t, isOptional)`
          // shadowed the array's own `name` and `isOptional`, so Option-wrapping
          // of an array-of-primitives column was decided by the element's
          // nullability (ArrayType.containsNull) instead of the array column's
          // own nullability — inconsistent with the nested-class branch below.
          // Bind the element's fields under distinct names so the outer values win.
          case FlatField(_, elementType, _) =>
            (PrintField(name, s"Seq[$elementType]", isOptional), Nil)
          case _ =>
            (PrintField(name, s"Seq[${name.capitalize}]", isOptional), toPrintClasses(None, List(nestedField)))
        }
    }
    val fieldsToPrint: List[PrintField] = result.map(_._1)
    val classesToPrint = result.flatMap(_._2)
    className match {
      case Some(cn) => PrintClass(cn.capitalize, fieldsToPrint) :: classesToPrint
      case None     => classesToPrint
    }
  }

  /** Translates a Spark `DataType` into the intermediate `Field` tree.
    * Structs recurse per sub-field; arrays wrap their element type. */
  private def tc(name: String, dataType: DataType, isNullable: Boolean): Field = dataType match {
    case struct: StructType =>
      CaseClass(name, struct.fields.toList.map(f => tc(f.name, f.dataType, f.nullable)), isNullable)
    case array: ArrayType =>
      // The element Field carries containsNull; the ArrayClass carries the
      // column's own nullability.
      ArrayClass(name, tc(name, array.elementType, array.containsNull), isNullable)
    case _ =>
      FlatField(name, dataTypeToString(dataType), isNullable)
  }

  /** Maps a scalar Spark `DataType` to the Scala/Java type name emitted in the
    * generated source. Unknown types fall back to `String`. */
  private def dataTypeToString(dt: DataType): String = dt match {
    case _: ByteType      => "Byte"
    case _: ShortType     => "Short"
    case _: IntegerType   => "Int"
    case _: LongType      => "Long"
    case _: FloatType     => "Float"
    case _: DoubleType    => "Double"
    case _: DecimalType   => "java.math.BigDecimal"
    case _: StringType    => "String"
    case _: BinaryType    => "Array[Byte]"
    case _: BooleanType   => "Boolean"
    case _: TimestampType => "java.sql.Timestamp"
    case _: DateType      => "java.sql.Date"
    // NOTE(review): unreachable in practice — `tc` intercepts ArrayType before
    // calling this; kept for safety. (Removed the needless `s` interpolator.)
    case _: ArrayType     => "scala.collection.UnknownSeq"
    case _: MapType       => "scala.collection.Map"
    case _: StructType    => "org.apache.spark.sql.Row"
    case _                => "String" // fallback for unsupported types
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment