Skip to content

Instantly share code, notes, and snippets.

@lancegatlin
Created April 29, 2016 13:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lancegatlin/247fadb3768ccf8e9636e4c9b546eae8 to your computer and use it in GitHub Desktop.
Save lancegatlin/247fadb3768ccf8e9636e4c9b546eae8 to your computer and use it in GitHub Desktop.
package s_mach.codetools.bigcaseclass
import s_mach.string._
object BigCaseClassPrinter {

  /** The fields of each sub case class paired with its zero-based position */
  private type Chunks = Vector[(Vector[CaseClassField], Int)]

  /** @return the wrapper member that holds the sub case class at zero-based position idx (e.g. "_1") */
  private def slot(idx: Int) : String = s"_${idx + 1}"

  /**
   * Print Scala source for a case class that is split into several nested
   * sub case classes. The fields are cut into consecutive groups of at most
   * CASECLASS_MAX_FIELDS; the wrapper case class holds one member per group
   * (_1, _2, ...) and forwards to them.
   *
   * The generated source consists of:
   *  - the wrapper case class with one accessor per field, a copy method
   *    taking every field and overrides of productElement, productArity and
   *    productIterator
   *  - a companion object with the sub case classes (name1, name2, ...) and
   *    an apply method taking every field
   *
   * @param name name of the wrapper case class
   * @param fields all fields, in declaration order
   * @return the generated source text
   */
  def print(name: String, fields: Vector[CaseClassField]) : String = {
    val chunks : Chunks = fields.grouped(CASECLASS_MAX_FIELDS).toVector.zipWithIndex

    val wrapperParms = printWrapperParms(name, chunks).indent(2)
    val accessors = printAccessors(chunks).indent(2)
    val copyMethod = printCopy(name, fields, chunks).indent(2)
    val productMethods = printProductMethods(fields.size, chunks.size).indent(2)
    val subCaseClasses = printSubCaseClasses(name, chunks).indent(2)
    val applyMethod = printApply(name, fields, chunks).indent(2)

    s"""
       |case class $name(
       |$wrapperParms
       |) {
       |$accessors
       |
       |$copyMethod
       |
       |$productMethods
       |}
       |
       |object $name {
       |$subCaseClasses
       |$applyMethod
       |}
     """.stripMargin.trim
  }

  /**
   * Constructor parameters of the wrapper, one per sub case class:
   * {{{
   * _1 : Licensee.Licensee1,
   * _2 : Licensee.Licensee2
   * }}}
   */
  private def printWrapperParms(name: String, chunks: Chunks) : String =
    chunks.map { case (_, idx) =>
      Seq(slot(idx), ":", s"$name.$name${idx + 1}")
    }.printGrid(" ",",\n")

  /**
   * One forwarding accessor per field, followed by the field's comment (if any):
   * {{{
   * def licenseeId : Long = _1.licenseeId
   * ...
   * def consoleAccessRate : Double = _2.consoleAccessRate
   * }}}
   */
  private def printAccessors(chunks: Chunks) : String =
    chunks.flatMap { case (chunk, idx) =>
      chunk.map { field =>
        val trailingComment = field.optComment.map("// " + _).getOrElse("")
        Seq("def", field.name, ":", field._type, "=", s"${slot(idx)}.${field.name}", trailingComment)
      }
    }.printGrid(" ","\n")

  /**
   * Product overrides that dispatch a flat element index to the sub case
   * class owning it:
   * {{{
   * override def productElement(i: Int) : Any = i match {
   *   case n if n < 22 => _1.productElement(i - 0)
   *   case n if n < 44 => _2.productElement(i - 22)
   *   case _ => throw new IndexOutOfBoundsException
   * }
   * override def productArity : Int = 33
   * override def productIterator = _1.productIterator ++ _2.productIterator
   * }}}
   */
  private def printProductMethods(arity: Int, chunkCount: Int) : String = {
    val positions = 0 until chunkCount
    val cases = positions.map { idx =>
      val offset = idx * CASECLASS_MAX_FIELDS
      val bound = offset + CASECLASS_MAX_FIELDS
      s"case n if n < $bound => ${slot(idx)}.productElement(i - $offset)"
    }.mkString("\n").indent(2)
    val iterators = positions.map(idx => s"${slot(idx)}.productIterator").mkString(" ++ ")
    s"""
       |override def productElement(i: Int) : Any = i match {
       |$cases
       | case _ => throw new IndexOutOfBoundsException
       |}
       |override def productArity : Int = $arity
       |override def productIterator = $iterators
     """.stripMargin.trim
  }

  /** The sub case classes name1, name2, ... each printed as a normal case class */
  private def printSubCaseClasses(name: String, chunks: Chunks) : String =
    chunks.map { case (chunk, idx) =>
      CaseClassPrinter.printNormCaseClass(s"$name${idx + 1}", chunk)
    }.mkString("\n")

  /**
   * Copy method that takes every field, each defaulting to the current value:
   * {{{
   * def copy(
   *   licenseeId : Long = _1.licenseeId,
   *   code : String = _1.code,
   *   ...
   * ) : Licensee = Licensee(
   *   licenseeId = licenseeId,
   *   code = code,
   *   ...
   * )
   * }}}
   */
  private def printCopy(name: String, fields: Vector[CaseClassField], chunks: Chunks) : String = {
    val parmRows = chunks.flatMap { case (chunk, idx) =>
      chunk.map(field => Seq(field.name, ":", field._type, "=", s"${slot(idx)}.${field.name}"))
    }
    // The chunks concatenated in order are exactly `fields`
    val argRows = fields.map(field => Seq(field.name, "=", field.name))
    val parms = parmRows.printGrid(" ",",\n").indent(2)
    val args = argRows.printGrid(" ",",\n").indent(2)
    s"""
       |def copy(
       |$parms
       |) : $name = $name(
       |$args
       |)
     """.stripMargin.trim
  }

  /**
   * Apply method that takes every field and distributes them over the sub
   * case classes:
   * {{{
   * def apply(
   *   licenseeId : Long,
   *   code : String,
   *   ...
   * ) : Licensee = Licensee(
   *   _1 = Licensee1(
   *     licenseeId = licenseeId,
   *     ...
   *   ),
   *   _2 = Licensee2(
   *     ssoProvider = ssoProvider,
   *     ...
   *   )
   * )
   * }}}
   */
  private def printApply(name: String, fields: Vector[CaseClassField], chunks: Chunks) : String = {
    val parms = CaseClassPrinter.printFieldDecls(fields).indent(2)
    val subCtorCalls = chunks.map { case (chunk, idx) =>
      val lastPos = chunk.size - 1
      val argRows = chunk.indices.map { pos =>
        val fieldName = chunk(pos).name
        // Every argument except the last carries the separating comma
        val value = if(pos < lastPos) fieldName + "," else fieldName
        Vector(fieldName, "=", value)
      }
      val args = argRows.printGrid(" ","\n").indent(2)
      s"""
         |${slot(idx)} = $name${idx + 1}(
         |$args
         |)
       """.stripMargin.trim
    }.mkString(",\n").indent(2)
    s"""
       |def apply(
       |$parms
       |) : $name = $name(
       |$subCtorCalls
       |)
     """.stripMargin.trim
  }
}
package s_mach.codetools.bigcaseclass
/**
 * Description of one field of a case class to be printed.
 *
 * @param name field name, printed verbatim
 * @param _type Scala type of the field as source text, printed after the ":"
 * @param optDefault optional default value as source text; when set the
 *                   printers emit it as "= default"
 * @param optComment optional comment text; when set the printers emit it
 *                   as a trailing "// comment"
 */
case class CaseClassField(
name: String,
_type: String,
optDefault: Option[String],
optComment: Option[String]
)
package s_mach.codetools.bigcaseclass
import s_mach.string._
object CaseClassPrinter {

  /**
   * Print Scala source for a case class, switching to BigCaseClassPrinter
   * when there are more than CASECLASS_MAX_FIELDS fields.
   *
   * @param name name of the case class
   * @param fields fields in declaration order
   * @return the generated source text
   */
  def print(name: String, fields: Vector[CaseClassField]) : String =
    if(fields.size > CASECLASS_MAX_FIELDS) {
      BigCaseClassPrinter.print(name, fields)
    } else {
      printNormCaseClass(name, fields)
    }

  /**
   * Print Scala source for a plain case class. No check against
   * CASECLASS_MAX_FIELDS is made here.
   *
   * @param name name of the case class
   * @param fields fields in declaration order
   * @return the generated source text
   */
  def printNormCaseClass(name: String, fields: Vector[CaseClassField]) : String = {
    val decls = printFieldDecls(fields).indent(2)
    s"""
       |case class $name(
       |$decls
       |)
     """.stripMargin.trim
  }

  /**
   * Print one "name : Type [= default][,] [// comment]" row per field and
   * lay the rows out with printGrid. Every row but the last ends in a comma.
   *
   * @param fields fields in declaration order
   * @return the field declarations, one per line
   */
  def printFieldDecls(fields: Vector[CaseClassField]) : String = {
    val someFieldHasDefault = fields.exists(_.optDefault.isDefined)
    val lastPos = fields.size - 1
    fields.indices.map { pos =>
      val field = fields(pos)
      val decl = Vector(field.name, ":", field._type) ++ field.optDefault.map("= " + _).toVector
      // The comma is glued to whichever cell ends the declaration
      val terminated =
        if(pos == lastPos) decl
        else decl.init :+ (decl.last + ",")
      // A field without a default gets an empty cell in the default column
      // so its comment lands in the same column as for fields with defaults
      val filler =
        if(someFieldHasDefault && field.optDefault.isEmpty) Vector("")
        else Vector.empty[String]
      terminated ++ filler ++ field.optComment.map("// " + _).toVector
    }.printGrid(" ","\n")
  }
}
package s_mach.codetools.bigcaseclass
import s_mach.string._
object DdlToCaseClassPrinter {

  /**
   * Settings for [[print]].
   *
   * @param formatTableName converts a SQL table name to the name of the
   *                        generated case class
   * @param formatColumnName converts a SQL column name to the name of the
   *                         generated field
   * @param sqlToScalaTypeMap maps a SQL type to the Scala type to emit. Keys
   *                          are looked up as "[unsigned ]type" where
   *                          tinyint(1) is the only type that keeps its
   *                          length modifier; the type is tried as written
   *                          in the DDL first and lowercased second
   */
  case class Config(
    formatTableName : String => String = {
      import WordSplitter.Underscore
      _.toCamelCase
    },
    formatColumnName : String => String = {
      import WordSplitter.Underscore
      _.toCamelCase
    },
    sqlToScalaTypeMap: Map[String, String] = stdSqlToScalaTypeMap
  )

  val scala_String = "String"
  val scala_ArrayByte = "Array[Byte]"
  val scala_Boolean = "Boolean"
  val scala_Byte = "Byte"
  val scala_Short = "Short"
  val scala_Int = "Int"
  val scala_Long = "Long"
  val scala_BigInt = "BigInt"
  val scala_Float = "Float"
  val scala_Double = "Double"
  val scala_BigDecimal = "BigDecimal"
  val java_util_Date = "java.util.Date"

  /** Default mapping of (lowercase) SQL types to Scala types */
  val stdSqlToScalaTypeMap : Map[String, String] = Map(
    "char" -> scala_String,
    "varchar" -> scala_String,
    "tinytext" -> scala_String,
    "text" -> scala_String,
    "mediumtext" -> scala_String,
    "longtext" -> scala_String,
    "clob" -> scala_String,
    "set" -> scala_String,
    "enum" -> scala_String,
    "blob" -> scala_ArrayByte,
    "mediumblob" -> scala_ArrayByte,
    "bit" -> scala_Boolean,
    "tinyint(1)" -> scala_Boolean,
    "unsigned tinyint(1)" -> scala_Boolean,
    "tinyint" -> scala_Byte,
    "unsigned tinyint" -> scala_Short, // scala has no concept of unsigned so need to promote
    "smallint" -> scala_Short,
    "unsigned smallint" -> scala_Int, // scala has no concept of unsigned so need to promote
    "mediumint" -> scala_Int,
    "unsigned mediumint" -> scala_Int, // scala has no concept of unsigned so need to promote
    "int" -> scala_Int,
    "unsigned int" -> scala_Long, // scala has no concept of unsigned so need to promote
    "bigint" -> scala_BigInt,
    "unsigned bigint" -> scala_BigInt,
    "float" -> scala_Float,
    "double" -> scala_Double,
    "unsigned double" -> scala_Double,
    "decimal" -> scala_BigDecimal,
    "date" -> java_util_Date,
    "datetime" -> java_util_Date,
    "timestamp" -> scala_Long,
    "time" -> java_util_Date,
    "year" -> scala_Int
  )

  // TODO: replace with real DDL parser
  /** Captures (1) the table name and (2) everything between the outermost parentheses */
  val parseCreateTableRegex = "(?i)CREATE TABLE [`]?(\\w+)[`]?\\s*\\((.+)\\)".r
  /**
   * Captures (1) column name, (2) column type, (3) optional parenthesized
   * type modifier and (4) the rest of the declaration up to the next comma.
   * A declaration is only matched when it is followed by a comma.
   */
  val parseColumnDeclRegex = "(?i)[`]?(\\w+)[`]?\\s+(\\w+)\\s*(\\(.+?\\))?([^,]*)[,]".r
  /** Keywords that mark a match of parseColumnDeclRegex as not being a column declaration */
  val parseColumnDeclFilter = "(?i)(?<=(\\s|^))(PRIMARY|KEY|CONSTRAINT|FOREIGN|USING)(?=(\\s|$))".r
  /** Captures (1) the default value: either NULL or a single-quoted literal (quotes included) */
  val parseDefaultRegex = "(?i)DEFAULT (NULL|'.*?')".r

  /**
   * Resolve the Scala type for a SQL column type. The type is tried exactly
   * as written first (so custom maps with mixed-case keys keep working) and
   * lowercased second (so "VARCHAR" resolves against lowercase keys).
   */
  private def lookupScalaType(
    sqlToScalaTypeMap: Map[String, String],
    columnType: String,
    columnDecl: String
  ) : String =
    sqlToScalaTypeMap.get(columnType)
      .orElse(sqlToScalaTypeMap.get(columnType.toLowerCase(java.util.Locale.ROOT)))
      .getOrElse(
        throw new RuntimeException(s"Unmapped SQL type: $columnType! $columnDecl")
      )

  /**
   * Generate case class source for every CREATE TABLE statement matched in
   * the DDL. Columns that are not declared NOT NULL become Option fields.
   * DEFAULT NULL and single-quoted DEFAULT literals become default values.
   * Each field carries a comment with its column index and original
   * declaration.
   *
   * @param ddl SQL DDL text
   * @param cfg name formatters and SQL to Scala type mapping
   * @return a case class for the given SQL DDL
   * @throws RuntimeException if a column's SQL type is missing from
   *                          cfg.sqlToScalaTypeMap
   */
  def print(
    ddl: String,
    cfg: Config = Config()
  ) : String = {
    import cfg._
    // Collapse every whitespace run, newlines included, into a single space
    val tidyDdl = ddl.replaceAll("\\s+"," ")
    parseCreateTableRegex.findAllMatchIn(tidyDdl).map { tblMatch =>
      val tableName = tblMatch.group(1)
      val columns = tblMatch.group(2)
      val fields : Vector[CaseClassField] = {
        // Parse out column declarations and filter unintentional matches to lines like "PRIMARY KEY"
        val columnDecls =
          parseColumnDeclRegex
            .findAllMatchIn(columns)
            .filter(m =>
              parseColumnDeclFilter.findFirstIn(m.group(0)).isEmpty
            )
        columnDecls.zipWithIndex.map { case (columnMatch,i) =>
          val comment = columnMatch.group(0)
          val columnName = columnMatch.group(1)
          val rawColumnType = columnMatch.group(2)
          // null when the column type has no parenthesized modifier
          val columnTypeMod = columnMatch.group(3)
          val suffix = columnMatch.group(4)
          // Locale.ROOT so that e.g. "UNSIGNED" lowercases the same under every default locale
          val lcSuffix = suffix.toLowerCase(java.util.Locale.ROOT)
          val isNullable = !lcSuffix.contains("not null")
          val columnType = {
            val unsignedPrefix = if(lcSuffix.contains("unsigned")) "unsigned " else ""
            // tinyint(1) is the only type whose modifier takes part in the lookup
            val baseType =
              if(rawColumnType.equalsIgnoreCase("tinyint") && columnTypeMod == "(1)") {
                "tinyint(1)"
              } else {
                rawColumnType
              }
            unsignedPrefix + baseType
          }
          val baseScalaType = lookupScalaType(sqlToScalaTypeMap, columnType, columnMatch.group(0))
          // Parse the column default - some tricky logic here
          val optDefault =
            parseDefaultRegex.findFirstMatchIn(suffix).map { m =>
              // Translate NULL to None
              if(m.group(1).equalsIgnoreCase("NULL")) {
                "None"
              } else {
                // Strip quotes
                val rawDefault = m.group(1).tail.init
                // Adjust the scala value based on the baseScalaType
                val baseDefault =
                  baseScalaType match {
                    // Strings need to be double-quoted in scala
                    case "String" => "\"" + rawDefault + "\""
                    // Chars need to be single-quoted in scala
                    case "Char" => s"'$rawDefault'"
                    // SQL uses 0 for false and any other value as true
                    case "Boolean" => rawDefault match {
                      case "0" => "false"
                      case _ => "true"
                    }
                    case _ => rawDefault
                  }
                // If the column is nullable then wrap the default value in Some
                if(isNullable) {
                  s"Some($baseDefault)"
                } else {
                  baseDefault
                }
              }
            }
          val scalaType = if(isNullable) {
            s"Option[$baseScalaType]"
          } else {
            baseScalaType
          }
          CaseClassField(
            // Column names go through the column formatter (was formatTableName)
            name = formatColumnName(columnName),
            _type = scalaType,
            optDefault = optDefault,
            optComment = Some(s"$i $comment")
          )
        }
      }.toVector
      // The table name goes through the table formatter (was formatColumnName)
      val caseClassName = formatTableName(tableName)
      val caseClassStr = CaseClassPrinter.print(caseClassName, fields)
      val now = new java.util.Date()
      s"""
         |/**
         | * Case class for a row in table $tableName
         | * WARN: auto-generated using net.tstllc.codegen.DdlToCaseClassPrinter
         | * WARN: field order MUST correspond to SQL column order
         | * Regex for quick find/replace:
         | * ${"(\\w+)\\s*:\\s*(.+?)(\\s*=\\s*(.+?))*[,]*\\s* // (\\d+)"}
         | * $now
         | **/
         |$caseClassStr
         |/* Auto-generated from:
         |$ddl
         |*/
       """.stripMargin.trim
    }.mkString("\n")
  }
}
package s_mach.codetools
package object bigcaseclass {
/**
 * Maximum number of fields printed into a single case class. CaseClassPrinter
 * switches to BigCaseClassPrinter above this count and BigCaseClassPrinter
 * cuts the fields into groups of this size.
 * NOTE(review): 22 presumably mirrors the Scala 2 limit on case class /
 * tuple / function arity - confirm against the targeted Scala version.
 */
val CASECLASS_MAX_FIELDS = 22
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment