@yoyama
Created January 20, 2017 07:36
Generate a case class from a Spark DataFrame/Dataset schema.
/**
 * Generate a case class from DataFrame.schema
 *
 *   val df: DataFrame = ...
 *
 *   val s2cc = new Schema2CaseClass
 *   import s2cc.implicits._
 *
 *   println(s2cc.schemaToCaseClass(df.schema, "MyClass"))
 *
 */
import org.apache.spark.sql.types._

class Schema2CaseClass {
  type TypeConverter = (DataType) => String

  def schemaToCaseClass(schema:StructType, className:String)(implicit tc:TypeConverter):String = {
    // Nullable columns become Option[...] fields
    def genField(s:StructField):String = {
      val f = tc(s.dataType)
      s match {
        case x if(x.nullable) => s"  ${s.name}:Option[$f]"
        case _ => s"  ${s.name}:$f"
      }
    }

    val fieldsStr = schema.map(genField).mkString(",\n  ")
    s"""
       |case class $className (
       |  $fieldsStr
       |)
    """.stripMargin
  }

  object implicits {
    // Spark SQL type -> Scala type name; anything unmatched falls back to String
    implicit val defaultTypeConverter:TypeConverter = (t:DataType) => { t match {
      case _:ByteType => "Byte"
      case _:ShortType => "Short"
      case _:IntegerType => "Int"
      case _:LongType => "Long"
      case _:FloatType => "Float"
      case _:DoubleType => "Double"
      case _:DecimalType => "java.math.BigDecimal"
      case _:StringType => "String"
      case _:BinaryType => "Array[Byte]"
      case _:BooleanType => "Boolean"
      case _:TimestampType => "java.sql.Timestamp"
      case _:DateType => "java.sql.Date"
      case _:ArrayType => "scala.collection.Seq"
      case _:MapType => "scala.collection.Map"
      case _:StructType => "org.apache.spark.sql.Row"
      case _ => "String"
    }}
  }
}
@nicosuave

This is awesome, especially for use with frameless. Thanks!

@selmahfo

selmahfo commented Nov 9, 2017

Thanks for the script, it came in handy! I'm new to Spark with Scala, but I think in the example you gave you should replace import s2cc.implicit._ with import s2cc.implicits._

@byanjati

byanjati commented Dec 9, 2017

This is cool, it saves me from writing a lot of case classes.

@RayTsui

RayTsui commented Apr 6, 2018

After getting the string with the executable code of the case class, how do I execute it? With Scala reflection or something else?

@shivakomat

@RayTsui You write the string to a file, or append it to a file of case classes.
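As a minimal sketch of that suggestion, assuming the generated source is already in a String (for example the return value of schemaToCaseClass), the hypothetical helper CaseClassFile.append below appends it to a .scala file with java.nio. The file name and location are up to you; placing it under your project's source directory means it is compiled together with the rest of your code.

```scala
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Path, StandardOpenOption}

object CaseClassFile {
  // Hypothetical helper: appends generated case class source to a file,
  // creating the file (and its parent directories) if they do not exist yet.
  def append(code: String, target: Path): Path = {
    Option(target.getParent).foreach(dir => Files.createDirectories(dir))
    Files.write(
      target,
      code.getBytes(StandardCharsets.UTF_8),
      StandardOpenOption.CREATE,
      StandardOpenOption.APPEND
    )
  }
}

// Usage (path is an example only):
// CaseClassFile.append(s2cc.schemaToCaseClass(df.schema, "MyClass"),
//   java.nio.file.Paths.get("src/main/scala/Generated.scala"))
```

Appending rather than overwriting lets several schemas accumulate in one file; delete the file first if you want a fresh result.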

@gstaubli

gstaubli commented May 3, 2018

I love this conversion process! I've had deeply nested schemas, which required me to run this code manually on each level of nesting. I'd love a recursive (or other) version that handles nested schemas, which I hope to contribute back unless someone beats me to it ;)

@exp-smurshed

I concur with @gstaubli. Can you please share what you did for the nested schema?

@skoppar

skoppar commented Jun 21, 2018

Very good idea. I was also looking for a way to execute the generated case class code and found this:
import scala.tools.reflect.ToolBox
import scala.reflect.runtime.universe._
import scala.reflect.runtime.currentMirror

val df = ....
val toolbox = currentMirror.mkToolBox()
val case_class = toolbox.compile(s2cc.schemaToCaseClass(df.schema, "YourName"))

For this, the return type of schemaToCaseClass would have to be runtime.universe.Tree, built with quasiquotes:

def schemaToCaseClass(schema:StructType, className:String)(implicit tc:TypeConverter):runtime.universe.Tree = {
  def genField(s:StructField):String = {
    val f = tc(s.dataType)
    s match {
      case x if(x.nullable) => s"  ${s.name}:Option[$f]"
      case _ => s"  ${s.name}:$f"
    }
  }

  val fieldsStr = schema.map(genField).mkString(",\n  ")
  q"""
     case class $className (
       $fieldsStr
     )"""
}

However, I was trying to apply the compiled class back to the resulting DataFrame, and I don't see a way to do that. Sharing what I found in case it helps someone.
Reference - https://stackoverflow.com/questions/31054237/what-are-the-ways-to-convert-a-string-into-runnable-code
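If the goal is only to get the class defined at runtime, a sketch that keeps the String-returning schemaToCaseClass and instead parses its output with ToolBox.parse, then registers it with ToolBox.define, might look like the following. It assumes scala-compiler is on the classpath (spark-shell ships it); DefineAtRuntime is a hypothetical name. The class exists only inside the toolbox, so the compiler still cannot derive an Encoder for df.as[...], which matches the limitation described above.

```scala
import scala.reflect.runtime.currentMirror
import scala.reflect.runtime.universe._
import scala.tools.reflect.ToolBox

object DefineAtRuntime {
  private val toolbox = currentMirror.mkToolBox()

  // Parses generated source (e.g. the String returned by schemaToCaseClass)
  // and defines the class inside the toolbox, returning its runtime symbol.
  def define(code: String): Symbol = toolbox.parse(code) match {
    case cd: ClassDef => toolbox.define(cd)
    case other        => sys.error(s"expected a single class definition, got: ${showRaw(other)}")
  }
}

// Usage:
// val sym = DefineAtRuntime.define(s2cc.schemaToCaseClass(df.schema, "YourName"))
```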

@geoHeil

geoHeil commented Jun 26, 2018

Would it be possible to create a macro? I can't seem to actually make use of the generated class string, as it won't compile: https://stackoverflow.com/questions/51035313/dynamically-create-case-class-from-structtype#51035313

@BioQwer

BioQwer commented Nov 15, 2018

scala> import org.apache.spark.sql.types._
import org.apache.spark.sql.types._

scala>

scala> class Schema2CaseClass {
     |   type TypeConverter = (DataType) => String
     |
     |   def schemaToCaseClass(schema:StructType, className:String)(implicit tc:TypeConverter):String = {
     |     def genField(s:StructField):String = {
     |       val f = tc(s.dataType)
     |       s match {
     |         case x if(x.nullable) => s"  ${s.name}:Option[$f]"
     |         case _ => s"  ${s.name}:$f"
     |       }
     |     }
     |
     |     val fieldsStr = schema.map(genField).mkString(",\n  ")
     |     s"""
     |        |case class $className (
     |        |  $fieldsStr
     |        |)
     |   """.stripMargin
     |   }
     |
     |   object implicits {
     |     implicit val defaultTypeConverter:TypeConverter = (t:DataType) => { t match {
     |       case _:ByteType => "Byte"
     |       case _:ShortType => "Short"
     |       case _:IntegerType => "Int"
     |       case _:LongType => "Long"
     |       case _:FloatType => "Float"
     |       case _:DoubleType => "Double"
     |       case _:DecimalType => "java.math.BigDecimal"
     |       case _:StringType => "String"
     |       case _:BinaryType => "Array[Byte]"
     |       case _:BooleanType => "Boolean"
     |       case _:TimestampType => "java.sql.Timestamp"
     |       case _:DateType => "java.sql.Date"
     |       case _:ArrayType => "scala.collection.Seq"
     |       case _:MapType => "scala.collection.Map"
     |       case _:StructType => "org.apache.spark.sql.Row"
     |       case _ => "String"
     |     }}
     |   }
     | }
<console>:12: error: not found: type DataType
  type TypeConverter = (DataType) => String
                        ^
<console>:14: error: not found: type StructType
  def schemaToCaseClass(schema:StructType, className:String)(implicit tc:TypeConverter):String = {
                               ^
<console>:15: error: not found: type StructField
    def genField(s:StructField):String = {
                   ^
<console>:32: error: not found: type DataType
    implicit val defaultTypeConverter:TypeConverter = (t:DataType) => { t match {
                                                         ^
<console>:33: error: not found: type ByteType
      case _:ByteType => "Byte"
             ^
<console>:34: error: not found: type ShortType
      case _:ShortType => "Short"
             ^
<console>:35: error: not found: type IntegerType
      case _:IntegerType => "Int"
             ^
<console>:36: error: not found: type LongType
      case _:LongType => "Long"
             ^
<console>:37: error: not found: type FloatType
      case _:FloatType => "Float"
             ^
<console>:38: error: not found: type DoubleType
      case _:DoubleType => "Double"
             ^
<console>:39: error: not found: type DecimalType
      case _:DecimalType => "java.math.BigDecimal"
             ^
<console>:40: error: not found: type StringType
      case _:StringType => "String"
             ^
<console>:41: error: not found: type BinaryType
      case _:BinaryType => "Array[Byte]"
             ^
<console>:42: error: not found: type BooleanType
      case _:BooleanType => "Boolean"
             ^
<console>:43: error: not found: type TimestampType
      case _:TimestampType => "java.sql.Timestamp"
             ^
<console>:44: error: not found: type DateType
      case _:DateType => "java.sql.Date"
             ^
<console>:45: error: not found: type ArrayType
      case _:ArrayType => "scala.collection.Seq"
             ^
<console>:46: error: not found: type MapType
      case _:MapType => "scala.collection.Map"
             ^
<console>:47: error: not found: type StructType
      case _:StructType => "org.apache.spark.sql.Row"
             ^

If you get these errors, enter the code in the console with :paste; that way it works for me.

@BioQwer

BioQwer commented Nov 15, 2018

I concur with @gstaubli. Can you please share what you did for the nested schema?

+1

@eschombu

eschombu commented Aug 7, 2019

This is really helpful, but here's an improvement. In defaultTypeConverter, change the ArrayType case to

      case _: ArrayType => {
        val e = t match { case ArrayType(elementType, _) => elementType }
        s"Seq[${defaultTypeConverter(e)}]"
      }
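The same recursion extends naturally to MapType keys and values. A minimal sketch, written as a standalone function rather than as an edit to the gist's implicit val (the name RecursiveTypes.toScala is hypothetical), that produces full type names such as Seq[Int] or Map[String, Long]; element and value nullability are ignored here:

```scala
import org.apache.spark.sql.types._

object RecursiveTypes {
  // Maps a Spark DataType to a Scala type name, recursing into arrays and maps.
  // containsNull / valueContainsNull are ignored (no Option around elements).
  def toScala(t: DataType): String = t match {
    case _: ByteType      => "Byte"
    case _: ShortType     => "Short"
    case _: IntegerType   => "Int"
    case _: LongType      => "Long"
    case _: FloatType     => "Float"
    case _: DoubleType    => "Double"
    case _: DecimalType   => "java.math.BigDecimal"
    case _: StringType    => "String"
    case _: BinaryType    => "Array[Byte]"
    case _: BooleanType   => "Boolean"
    case _: TimestampType => "java.sql.Timestamp"
    case _: DateType      => "java.sql.Date"
    case ArrayType(elementType, _)      => s"Seq[${toScala(elementType)}]"
    case MapType(keyType, valueType, _) => s"Map[${toScala(keyType)}, ${toScala(valueType)}]"
    case _: StructType    => "org.apache.spark.sql.Row"
    case _                => "String"
  }
}
```

Matching on the ArrayType and MapType extractors directly avoids the inner `t match` needed when the case starts from a type pattern.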

@zpappa

zpappa commented Apr 26, 2020

Perhaps I did something wrong here, but I was unable to get this working with the implicits import.
For anyone who has the same issue and gets "identifier expected but 'implicit' found", I had to pass the converter explicitly, as below:

println(s2cc.schemaToCaseClass(schema, "MyclassName")(s2cc.implicits.defaultTypeConverter))

That said, super helpful, thanks!

@JituS

JituS commented Apr 13, 2021

Thanks for sharing this. A small enhancement would be to handle nested StructTypes in a schema. I have tried incorporating that scenario below:

import java.io.FileWriter

import org.apache.spark.sql.types._

class SchemaToCaseClassWriter(fileWriter: FileWriter) {
  type TypeConverter = DataType => String

  def write(schema: StructType, className: String): Unit = {
    run(schema, className)
    fileWriter.close()
  }

  private def run(schema: StructType, className: String): Unit = {
    def genField(field: StructField): String = {
      val converter = defaultTypeConverter(field.name)
      val dataType = converter(field.dataType)
      field match {
        case x if x.nullable => s"  ${field.name}:Option[$dataType]"
        case _ => s"  ${field.name}:$dataType"
      }
    }

    val fieldsStr = schema.map(genField).mkString(",\n  ")
    val schemaClass =
      s"""case class $className (
         |  $fieldsStr
         |)
         |
         |""".stripMargin
    fileWriter.write(schemaClass)
  }

  private def defaultTypeConverter(colName: String): TypeConverter = {
    val converter: TypeConverter = {
      case _: ByteType => "Byte"
      case _: ShortType => "Short"
      case _: IntegerType => "Int"
      case _: LongType => "Long"
      case _: FloatType => "Float"
      case _: DoubleType => "Double"
      case _: DecimalType => "java.math.BigDecimal"
      case _: StringType => "String"
      case _: BinaryType => "Array[Byte]"
      case _: BooleanType => "Boolean"
      case _: TimestampType => "java.sql.Timestamp"
      case _: DateType => "java.sql.Date"
      case t: ArrayType =>
        val e = t match {
          case ArrayType(elementType, _) => elementType
        }
        s"Seq[${defaultTypeConverter(colName)(e)}]"
      case _: MapType => "scala.collection.Map"
      case t: StructType =>
        run(t, colName.capitalize)
        colName.capitalize
      case _ => "String"
    }
    converter
  }
}

@maxmithun

If a nested schema has a struct with the same name at different levels, two classes with the same name will be created, which breaks the generated code. I think we need to use package names to handle that. Any other alternatives?
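One alternative to packages, sketched under the assumption that long class names are acceptable: name each nested case class after its full field path (for example Root_Customer_Address), so structs that share a field name at different levels get distinct classes. NestedCaseClasses.generate is a hypothetical helper, and its type mapping is trimmed to a few primitives for brevity:

```scala
import org.apache.spark.sql.types._

import scala.collection.mutable.ListBuffer

object NestedCaseClasses {
  // Hypothetical helper: emits one case class per StructType; nested classes are
  // named after their field path, so equal field names at different levels
  // do not collide.
  def generate(schema: StructType, className: String): String = {
    val out = ListBuffer.empty[String]

    def scalaType(t: DataType, path: String): String = t match {
      case _: IntegerType  => "Int"
      case _: LongType     => "Long"
      case _: DoubleType   => "Double"
      case _: BooleanType  => "Boolean"
      case _: StringType   => "String"
      case ArrayType(e, _) => s"Seq[${scalaType(e, path)}]"
      case st: StructType  =>
        emit(st, path) // nested class is emitted before its parent
        path
      case _               => "String" // other types trimmed for brevity
    }

    def emit(st: StructType, name: String): Unit = {
      val fields = st.fields.map { f =>
        val tpe = scalaType(f.dataType, s"${name}_${f.name.capitalize}")
        val fieldType = if (f.nullable) s"Option[$tpe]" else tpe
        s"  ${f.name}: $fieldType"
      }
      out += fields.mkString(s"case class $name(\n", ",\n", "\n)\n")
    }

    emit(schema, className)
    out.mkString("\n")
  }
}
```

Putting each nested class in its own package also works, but path-based names keep all generated classes in one file without extra imports.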

@srimunugoti

How do I use the resulting string as a case class? Any example, please?
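A minimal sketch of one way to use it, assuming the printed output has been pasted into a source file of your project; the names MyClass and UseGeneratedClass and the sample data are hypothetical. Once the class is compiled with your code (at the top level of a file rather than inside a method), spark.implicits._ can derive an Encoder for it, and df.as[MyClass] gives a typed Dataset:

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

// Pasted output of schemaToCaseClass for a hypothetical schema
// (id BIGINT NOT NULL, name STRING nullable).
case class MyClass(id: Long, name: Option[String])

object UseGeneratedClass {
  def load(spark: SparkSession): Dataset[MyClass] = {
    import spark.implicits._
    // Stands in for your real DataFrame; columns are matched to fields by name
    val df = Seq((1L, "a"), (2L, null)).toDF("id", "name")
    df.as[MyClass]
  }
}
```

Nullable columns map to Option fields, so a null name comes back as None instead of failing at decode time.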

@7873737376

7873737376 commented Feb 8, 2024

The code below can be used directly: instead of writing to a file, it returns the generated case classes as a String, so you can keep the result in a variable.

import org.apache.spark.sql.types._

import scala.collection.mutable.ListBuffer

class SchemaToCaseClassWriter {
  type TypeConverter = DataType => String

  // Every generated case class; nested structs are added before the class that uses them
  private val generated = ListBuffer.empty[String]

  def write(schema: StructType, className: String): String = {
    generated.clear()
    run(schema, className)
    generated.mkString
  }

  private def run(schema: StructType, className: String): Unit = {
    def genField(field: StructField): String = {
      val converter = defaultTypeConverter(field.name)
      val dataType = converter(field.dataType)
      field match {
        case x if x.nullable => s"  ${field.name}: Option[$dataType]"
        case _ => s"  ${field.name}: $dataType"
      }
    }

    // Converting the fields first records any nested case classes
    val fieldsStr = schema.map(genField).mkString(",\n  ")
    val schemaClass =
      s"""case class $className (
         |  $fieldsStr
         |)
         |
         |""".stripMargin
    generated += schemaClass
  }

  private def defaultTypeConverter(colName: String): TypeConverter = {
    val converter: TypeConverter = {
      case _: ByteType => "Byte"
      case _: ShortType => "Short"
      case _: IntegerType => "Int"
      case _: LongType => "Long"
      case _: FloatType => "Float"
      case _: DoubleType => "Double"
      case _: DecimalType => "java.math.BigDecimal"
      case _: StringType => "String"
      case _: BinaryType => "Array[Byte]"
      case _: BooleanType => "Boolean"
      case _: TimestampType => "java.sql.Timestamp"
      case _: DateType => "java.sql.Date"
      case t: ArrayType =>
        val e = t match {
          case ArrayType(elementType, _) => elementType
        }
        s"Seq[${defaultTypeConverter(colName)(e)}]"
      case _: MapType => "scala.collection.Map"
      case t: StructType =>
        // run records the nested class; the field type is the capitalized column name
        run(t, colName.capitalize)
        colName.capitalize
      case _ => "String"
    }
    converter
  }
}

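val writer = new SchemaToCaseClassWriter()
// Your StructType schema, e.g. df.schema; a small sample is built here
val schema = StructType(Seq(
  StructField("id", LongType, nullable = false),
  StructField("name", StringType, nullable = true)
))
val className = "MyClass"
val caseClassString = writer.write(schema, className)
println(caseClassString) // Output the generated case class string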
