Gist by yoyama (ce83f688717719fc8ca145c3b3ff43fd), created January 20, 2017 07:36
Generate case class from spark DataFrame/Dataset schema.
/**
 * Generate Case class from DataFrame.schema
 *
 *  val df:DataFrame = ...
 *
 *  val s2cc = new Schema2CaseClass
 *  import s2cc.implicits._
 *
 *  println(s2cc.schemaToCaseClass(df.schema, "MyClass"))
 */
import org.apache.spark.sql.types._

class Schema2CaseClass {
  type TypeConverter = (DataType) => String

  def schemaToCaseClass(schema:StructType, className:String)(implicit tc:TypeConverter):String = {
    // Nullable columns become Option[...] fields
    def genField(s:StructField):String = {
      val f = tc(s.dataType)
      s match {
        case x if(x.nullable) => s"  ${s.name}:Option[$f]"
        case _ => s"  ${s.name}:$f"
      }
    }

    val fieldsStr = schema.map(genField).mkString(",\n  ")
    s"""
       |case class $className (
       |  $fieldsStr
       |)
       |""".stripMargin
  }

  object implicits {
    // Spark SQL type -> Scala type name used in the generated source
    implicit val defaultTypeConverter:TypeConverter = (t:DataType) => { t match {
      case _:ByteType => "Byte"
      case _:ShortType => "Short"
      case _:IntegerType => "Int"
      case _:LongType => "Long"
      case _:FloatType => "Float"
      case _:DoubleType => "Double"
      case _:DecimalType => "java.math.BigDecimal"
      case _:StringType => "String"
      case _:BinaryType => "Array[Byte]"
      case _:BooleanType => "Boolean"
      case _:TimestampType => "java.sql.Timestamp"
      case _:DateType => "java.sql.Date"
      case _:ArrayType => "scala.collection.Seq"
      case _:MapType => "scala.collection.Map"
      case _:StructType => "org.apache.spark.sql.Row"
      case _ => "String"
    }}
  }
}

RayTsui commented Apr 6, 2018

After getting the case class code as a string, how do I execute it? With Scala reflection or something else?

@RayTsui You write the string to a file, or append it to a file of case classes, and compile it along with the rest of your code.
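A minimal sketch of that workflow using only java.nio from the JDK. The object name CaseClassFile and the target path are hypothetical; the generator itself compiles nothing, so this assumes the file is part of a project that you compile afterwards. The `trim` drops the leading and trailing newlines that the gist's `stripMargin` output carries.

```scala
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Path, StandardOpenOption}

// Hypothetical helper: appends generated case class source to a .scala file,
// creating the file (and its parent directories) on first use.
object CaseClassFile {
  def append(file: Path, source: String): Unit = {
    Option(file.getParent).foreach(dir => Files.createDirectories(dir))
    Files.write(
      file,
      (source.trim + "\n\n").getBytes(StandardCharsets.UTF_8),
      StandardOpenOption.CREATE,
      StandardOpenOption.APPEND
    )
  }
}
```

Each call adds one class followed by a blank line, so repeated runs build up a single file of case classes.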

gstaubli commented May 3, 2018

I love this conversion process! I've had deeply nested schemas, which required me to run this code manually at each level of nesting. I would love a recursive or other version that handles nested schemas, which I hope to contribute back unless someone beats me to it ;)
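A sketch of one recursive variant, assuming spark-sql is on the classpath. NestedSchema2CaseClass and its methods are hypothetical names, not the gist author's code. Each nested StructType becomes its own case class named after its parent class plus the capitalized field name (so `address` under `Person` becomes `PersonAddress`), which also keeps same-named structs at different levels apart. Arrays and maps recurse into their element types instead of mapping to untyped collections, and nested classes are emitted before the classes that use them.

```scala
import org.apache.spark.sql.types._
import scala.collection.mutable.ListBuffer

// Hypothetical recursive variant of Schema2CaseClass.
object NestedSchema2CaseClass {

  // Returns all generated case classes, nested ones first.
  def generate(schema: StructType, className: String): String = {
    val classes = ListBuffer.empty[String]
    emitClass(schema, className, classes)
    classes.mkString("\n\n")
  }

  // Renders one case class into `out` and returns its name.
  private def emitClass(schema: StructType, className: String, out: ListBuffer[String]): String = {
    val fields = schema.fields.map { f =>
      // A struct found in field `f` becomes the class <className><FieldName>.
      val tpe = scalaType(f.dataType, className + f.name.capitalize, out)
      val fieldType = if (f.nullable) s"Option[$tpe]" else tpe
      s"  ${f.name}: $fieldType"
    }
    out += s"case class $className(\n${fields.mkString(",\n")}\n)"
    className
  }

  // Maps a Spark DataType to a Scala type name, recursing into arrays, maps and structs.
  private def scalaType(dt: DataType, nestedName: String, out: ListBuffer[String]): String =
    dt match {
      case _: ByteType      => "Byte"
      case _: ShortType     => "Short"
      case _: IntegerType   => "Int"
      case _: LongType      => "Long"
      case _: FloatType     => "Float"
      case _: DoubleType    => "Double"
      case _: DecimalType   => "java.math.BigDecimal"
      case _: StringType    => "String"
      case _: BinaryType    => "Array[Byte]"
      case _: BooleanType   => "Boolean"
      case _: TimestampType => "java.sql.Timestamp"
      case _: DateType      => "java.sql.Date"
      case ArrayType(elem, _) => s"Seq[${scalaType(elem, nestedName, out)}]"
      case MapType(key, value, _) =>
        s"Map[${scalaType(key, nestedName + "Key", out)}, ${scalaType(value, nestedName, out)}]"
      case s: StructType => emitClass(s, nestedName, out)
      case _             => "String" // same fallback as the gist
    }
}
```

For a schema with a non-nullable `id` (long), a nullable `address` struct holding one `city` string, and a nullable `tags` array of strings, `generate(schema, "Person")` yields `PersonAddress(city: Option[String])` followed by `Person(id: Long, address: Option[PersonAddress], tags: Option[Seq[String]])`. Element nullability inside arrays and map values (`containsNull`, `valueContainsNull`) is ignored in this sketch.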

I concur with @gstaubli. Can you please share what you did for the nested schema?

skoppar commented Jun 21, 2018

Very good idea. I was also looking for a way to execute the generated case class definition and found this:
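import scala.reflect.runtime.universe._
import scala.reflect.runtime.currentMirror
import scala.tools.reflect.ToolBox // provides mkToolBox; needs scala-compiler on the classpath

val df = ....
val toolbox = currentMirror.mkToolBox()
// f is your Schema2CaseClass instance; compile() expects a Tree, not a String
val case_class = toolbox.compile(f.schemaToCaseClass(df.schema, "YourName"))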

For this, the return type of schemaToCaseClass would have to be scala.reflect.runtime.universe.Tree, built with quasiquotes:

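// Needs the three imports above in the same file as Schema2CaseClass
def schemaToCaseClass(schema:StructType, className:String)(implicit tc:TypeConverter): Tree = {
  val tb = currentMirror.mkToolBox()
  // Parse a type string such as "Option[Int]" into a type tree via a throwaway type alias
  def typeTree(tpe: String): Tree = tb.parse(s"type T = $tpe") match {
    case TypeDef(_, _, _, rhs) => rhs
  }
  val params = schema.fields.toList.map { s =>
    val f = tc(s.dataType)
    val tpt = if (s.nullable) typeTree(s"Option[$f]") else typeTree(f)
    q"val ${TermName(s.name)}: $tpt"
  }
  // Names must be spliced as TypeName/TermName; a plain String would become a string literal
  q"case class ${TypeName(className)}(..$params)"
}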

However, I was trying to apply the compiled class back to the resulting DataFrame, and I don't see a way to do that. Sharing what I found in case it helps someone.
Reference -
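A sketch of why the compiled class is hard to use afterwards, assuming scala-compiler (which provides scala.tools.reflect.ToolBox) is on the classpath; RuntimeCompileDemo and instantiate are hypothetical names. ToolBox.eval can compile the generated source together with an expression that uses it, but the result is statically only an Any: the class lives in the toolbox's own class loader, so ordinary compiled code cannot name it as a type, and therefore cannot write `df.as[YourName]` against it.

```scala
import scala.reflect.runtime.currentMirror
import scala.tools.reflect.ToolBox

// Hypothetical helper showing what runtime compilation of the generated source can do.
object RuntimeCompileDemo {
  private val tb = currentMirror.mkToolBox()

  // Parses the class definition followed by an expression that uses it, then evaluates both.
  // The value comes back typed as Any; its class exists only inside the toolbox.
  def instantiate(caseClassSource: String, expr: String): Any =
    tb.eval(tb.parse(s"$caseClassSource\n$expr"))
}
```

Usage would look like `RuntimeCompileDemo.instantiate(s2cc.schemaToCaseClass(df.schema, "MyClass"), "MyClass(...)")`; anything beyond printing or reflecting on the result needs the class compiled ahead of time.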

geoHeil commented Jun 26, 2018

Would it be possible to do this with a macro? I can't seem to actually make use of the generated class string, because it won't compile.

BioQwer commented Nov 15, 2018

scala> import org.apache.spark.sql.types._
import org.apache.spark.sql.types._


scala> class Schema2CaseClass {
     |   type TypeConverter = (DataType) => String
     |   def schemaToCaseClass(schema:StructType, className:String)(implicit tc:TypeConverter):String = {
     |     def genField(s:StructField):String = {
     |       val f = tc(s.dataType)
     |       s match {
     |         case x if(x.nullable) => s"  ${s.name}:Option[$f]"
     |         case _ => s"  ${s.name}:$f"
     |       }
     |     }
     |     val fieldsStr = schema.map(genField).mkString(",\n  ")
     |     s"""
     |        |case class $className (
     |        |  $fieldsStr
     |        |)
     |   """.stripMargin
     |   }
     |   object implicits {
     |     implicit val defaultTypeConverter:TypeConverter = (t:DataType) => { t match {
     |       case _:ByteType => "Byte"
     |       case _:ShortType => "Short"
     |       case _:IntegerType => "Int"
     |       case _:LongType => "Long"
     |       case _:FloatType => "Float"
     |       case _:DoubleType => "Double"
     |       case _:DecimalType => "java.math.BigDecimal"
     |       case _:StringType => "String"
     |       case _:BinaryType => "Array[Byte]"
     |       case _:BooleanType => "Boolean"
     |       case _:TimestampType => "java.sql.Timestamp"
     |       case _:DateType => "java.sql.Date"
     |       case _:ArrayType => "scala.collection.Seq"
     |       case _:MapType => "scala.collection.Map"
     |       case _:StructType => "org.apache.spark.sql.Row"
     |       case _ => "String"
     |     }}
     |   }
     | }
<console>:12: error: not found: type DataType
  type TypeConverter = (DataType) => String
<console>:14: error: not found: type StructType
  def schemaToCaseClass(schema:StructType, className:String)(implicit tc:TypeConverter):String = {
<console>:15: error: not found: type StructField
    def genField(s:StructField):String = {
<console>:32: error: not found: type DataType
    implicit val defaultTypeConverter:TypeConverter = (t:DataType) => { t match {
<console>:33: error: not found: type ByteType
      case _:ByteType => "Byte"
<console>:34: error: not found: type ShortType
      case _:ShortType => "Short"
<console>:35: error: not found: type IntegerType
      case _:IntegerType => "Int"
<console>:36: error: not found: type LongType
      case _:LongType => "Long"
<console>:37: error: not found: type FloatType
      case _:FloatType => "Float"
<console>:38: error: not found: type DoubleType
      case _:DoubleType => "Double"
<console>:39: error: not found: type DecimalType
      case _:DecimalType => "java.math.BigDecimal"
<console>:40: error: not found: type StringType
      case _:StringType => "String"
<console>:41: error: not found: type BinaryType
      case _:BinaryType => "Array[Byte]"
<console>:42: error: not found: type BooleanType
      case _:BooleanType => "Boolean"
<console>:43: error: not found: type TimestampType
      case _:TimestampType => "java.sql.Timestamp"
<console>:44: error: not found: type DateType
      case _:DateType => "java.sql.Date"
<console>:45: error: not found: type ArrayType
      case _:ArrayType => "scala.collection.Seq"
<console>:46: error: not found: type MapType
      case _:MapType => "scala.collection.Map"
<console>:47: error: not found: type StructType
      case _:StructType => "org.apache.spark.sql.Row"

If you get these errors, enter the code in the console with :paste instead; that way it works for me.

BioQwer commented Nov 15, 2018

I concur with @gstaubli. Can you please share what you did for the nested schema?


eschombu commented Aug 7, 2019

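This is really helpful, but here's an improvement. In defaultTypeConverter, change the ArrayType case to the following, so that arrays get a typed element (e.g. Seq[Int]) instead of a bare scala.collection.Seq:

      case _: ArrayType => {
        val e = t match { case ArrayType(elementType, _) => elementType }
        s"Seq[${defaultTypeConverter(e)}]"
      }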

zpappa commented Apr 26, 2020

Perhaps I did something wrong here, but I was unable to get this working with the implicits import.
I had to pass the converter in explicitly, as below. This is for anyone who hits the same issue and gets "identifier expected but 'implicit' found".

println(s2cc.schemaToCaseClass(schema, "MyclassName")(s2cc.implicits.defaultTypeConverter))

That said, super helpful, thanks!
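That error is what Scala reports for an import such as `import s2cc.implicit._`: `implicit` is a reserved word, whereas the object inside Schema2CaseClass is named `implicits`. A self-contained sketch with a hypothetical stand-in class (Converter, which only mirrors the gist's structure) shows both working alternatives:

```scala
// Hypothetical stand-in for Schema2CaseClass: an implicit converter
// declared inside a nested object of a class instance.
class Converter {
  type TypeConverter = Int => String
  def render(x: Int)(implicit tc: TypeConverter): String = tc(x)
  object implicits {
    implicit val defaultTypeConverter: TypeConverter = (x: Int) => s"#$x"
  }
}

object ImplicitImportDemo {
  val c = new Converter
  // `import c.implicit._` would not parse: `implicit` is a reserved word.
  import c.implicits._
  val viaImport: String = c.render(1)
  // Passing the converter explicitly works without any import.
  val explicitly: String = c.render(2)(c.implicits.defaultTypeConverter)
}
```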

JituS commented Apr 13, 2021

Thanks for sharing this. A small enhancement would be to handle nested StructTypes in a schema. I have tried incorporating that scenario below:


import java.io.FileWriter
import org.apache.spark.sql.types._

class SchemaToCaseClassWriter(fileWriter: FileWriter) {
  type TypeConverter = DataType => String

  def write(schema: StructType, className: String): Unit = {
    run(schema, className)
  }

  // Writes one case class to the file (nested classes first) and returns its name
  private def run(schema: StructType, className: String): String = {
    def genField(field: StructField): String = {
      val converter = defaultTypeConverter(field.name)
      val dataType = converter(field.dataType)
      field match {
        case x if x.nullable => s"  ${x.name}:Option[$dataType]"
        case _ => s"  ${field.name}:$dataType"
      }
    }

    val fieldsStr = schema.map(genField).mkString(",\n  ")
    val schemaClass =
      s"""case class $className (
         |  $fieldsStr
         |)
         |""".stripMargin
    fileWriter.write(schemaClass)
    className
  }

  private def defaultTypeConverter(colName: String): TypeConverter = {
    val converter: TypeConverter = {
      case _: ByteType => "Byte"
      case _: ShortType => "Short"
      case _: IntegerType => "Int"
      case _: LongType => "Long"
      case _: FloatType => "Float"
      case _: DoubleType => "Double"
      case _: DecimalType => "java.math.BigDecimal"
      case _: StringType => "String"
      case _: BinaryType => "Array[Byte]"
      case _: BooleanType => "Boolean"
      case _: TimestampType => "java.sql.Timestamp"
      case _: DateType => "java.sql.Date"
      case t: ArrayType =>
        val e = t match {
          case ArrayType(elementType, _) => elementType
        }
        s"Seq[${defaultTypeConverter(colName)(e)}]"
      case _: MapType => "scala.collection.Map"
      case t: StructType =>
        // Nested struct: write its own case class, named after the column
        run(t, colName.capitalize)
      case _ => "String"
    }
    converter
  }
}

If a nested schema has structs with the same name at different levels, two classes with the same name are generated, which breaks the generated code. I think we need package names to handle that. Any other alternatives?
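Besides packages, one alternative is to deduplicate names while generating. The sketch below is a hypothetical helper (ClassNameRegistry, not part of Spark or of the code above): a nested struct whose rendered field list matches an already generated class reuses that name, and a conflicting one gets a numeric suffix (Address, Address2, ...). Because it compares rendered field lists, the generator must produce nested classes before their parent, as JituS's recursive version already does.

```scala
import scala.collection.mutable

// Hypothetical helper: hands out collision-free case class names.
// Key = class name, value = rendered field list of that class.
class ClassNameRegistry {
  private val classes = mutable.LinkedHashMap.empty[String, String]

  // Returns baseName, baseName2, baseName3, ...: the first name that is either
  // unused or already bound to exactly the same field list.
  def nameFor(baseName: String, fields: String): String = {
    val candidates = Iterator.from(1).map(i => if (i == 1) baseName else s"$baseName$i")
    val name = candidates.find(n => classes.get(n).forall(_ == fields)).get
    classes(name) = fields
    name
  }

  // Renders every registered class in registration order.
  def render: String =
    classes.map { case (name, fields) => s"case class $name($fields)" }.mkString("\n\n")
}
```

Identical structs collapse into one class, which also keeps the output shorter than qualifying every nested class with its parent's name.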

How do I use the resulting string as a case class? Any example, please?
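One way, sketched under the assumption of a Spark 2.x or later project with spark-sql on the classpath: paste the generated class into a source file at top level (outside any method, so Spark can derive an Encoder for it), then convert the DataFrame with `as[...]`. The class below is a hand-written equivalent of what the generator prints for a non-nullable `id` (long) and a nullable `name` (string); UseGeneratedClass and typed are hypothetical names.

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

// Generated source pasted at top level; nullable columns are Option fields.
case class MyClass(
  id: Long,
  name: Option[String]
)

object UseGeneratedClass {
  // Converts an untyped DataFrame into a typed Dataset[MyClass].
  def typed(spark: SparkSession): Dataset[MyClass] = {
    import spark.implicits._
    val df = Seq((1L, "a"), (2L, null)).toDF("id", "name")
    df.as[MyClass] // columns are matched to fields by name
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("generated-class").getOrCreate()
    try typed(spark).collect().foreach(println)
    finally spark.stop()
  }
}
```

The null in `name` arrives as None, which is why the generator wraps nullable columns in Option.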

7873737376 commented Feb 8, 2024

/* The code below can be used directly. Instead of writing to a file, it returns the generated code as a String, so you can keep it in a variable. */

import org.apache.spark.sql.types._

class SchemaToCaseClassWriter {
  type TypeConverter = DataType => String

  // Collects the generated classes (nested ones first) instead of writing to a file
  private val out = new StringBuilder

  def write(schema: StructType, className: String): String = {
    out.clear()
    run(schema, className)
    out.toString
  }

  // Appends one case class and returns its name, so a nested struct's
  // field type is the class name rather than the whole class definition
  private def run(schema: StructType, className: String): String = {
    def genField(field: StructField): String = {
      val converter = defaultTypeConverter(field.name)
      val dataType = converter(field.dataType)
      field match {
        case x if x.nullable => s"  ${x.name}: Option[$dataType]"
        case _ => s"  ${field.name}: $dataType"
      }
    }

    val fieldsStr = schema.map(genField).mkString(",\n  ")
    val schemaClass =
      s"""case class $className (
         |  $fieldsStr
         |)
         |""".stripMargin
    out.append(schemaClass)
    className
  }

  private def defaultTypeConverter(colName: String): TypeConverter = {
    val converter: TypeConverter = {
      case _: ByteType => "Byte"
      case _: ShortType => "Short"
      case _: IntegerType => "Int"
      case _: LongType => "Long"
      case _: FloatType => "Float"
      case _: DoubleType => "Double"
      case _: DecimalType => "java.math.BigDecimal"
      case _: StringType => "String"
      case _: BinaryType => "Array[Byte]"
      case _: BooleanType => "Boolean"
      case _: TimestampType => "java.sql.Timestamp"
      case _: DateType => "java.sql.Date"
      case t: ArrayType =>
        val e = t match {
          case ArrayType(elementType, _) => elementType
        }
        s"Seq[${defaultTypeConverter(colName)(e)}]"
      case _: MapType => "scala.collection.Map"
      case t: StructType =>
        run(t, colName.capitalize)
      case _ => "String"
    }
    converter
  }
}

val writer = new SchemaToCaseClassWriter()
val schema = // Your StructType schema
val className = "MyClass"
val caseClassString = writer.write(schema, className)
println(caseClassString) // Output the generated case class string
