Skip to content

Instantly share code, notes, and snippets.

@nafg
Last active August 17, 2020 21:59
Show Gist options
  • Save nafg/883814df176e0cec495429806a1e01f2 to your computer and use it in GitHub Desktop.
A simple Slick code generator, packaged as an sbt plugin, that builds the generated source as Scalameta trees rather than by assembling strings.
import java.sql.Types
import scala.annotation.tailrec
import scala.concurrent.Await
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration
import scala.meta._
import slick.dbio.DBIO
import slick.jdbc.meta.{MColumn, MQName, MTable}
import slick.jdbc.{JdbcBackend, JdbcProfile}
import sbt.Keys._
import sbt._
object SlickMetaGenPlugin extends AutoPlugin {

  /** Depends on SlickConfigPlugin, which provides the `slickConfig` database configuration read below. */
  override def requires = SlickConfigPlugin

  object autoImport {

    /** Everything the generators need to know about one database column.
      *
      * @param columnName     column name exactly as reported by the database metadata
      * @param tableFieldTerm name of the `column[...]` member in the generated Table class
      * @param rowFieldTerm   name of the corresponding field in the generated row case class
      * @param scalaType      Scala type of the column (already wrapped in `Option` when nullable)
      * @param scalaDefault   default-argument expression for the row field, if any
      */
    case class ColumnInfo(columnName: String,
                          tableFieldTerm: Term.Name,
                          rowFieldTerm: Term.Name,
                          scalaType: Type,
                          scalaDefault: Option[Term])

    /** Everything the generators need to know about one table: its qualified database name,
      * the names of the generated Table class and row case class, and its columns.
      */
    case class SchemaInfo(tableName: MQName, tableClassName: String, rowClassName: String, columns: List[ColumnInfo])

    val slickPackage = settingKey[String]("The package to put the definitions in")
    val slickContainer =
      settingKey[String]("The container object / trait to put the table definitions in (also determines the filename)")
    val slickProfileClass = settingKey[String]("The slick profile class")
    val slickTables = settingKey[MTable => Boolean]("Which tables to codegen")
    val slickSchemaInfo = taskKey[List[SchemaInfo]]("Generate the SchemaInfo")
    val slickMetaGenModels = taskKey[Seq[File]]("Generate type definitions from the database")
    val slickMetaGenTables = taskKey[Seq[File]]("Generate schema definitions from the database")
    val slickMetaGenExtraImports = settingKey[List[String]]("Extra imports")
  }

  import SlickConfigPlugin.autoImport._
  import autoImport._

  /** Converts snake_case to camelCase: each `_` that is followed by a character is dropped and
    * that character is upper-cased (`user_id` becomes `userId`). A trailing `_` is kept as is.
    */
  def toCamel(s: String) = {
    def loop(cs: List[Char]): List[Char] =
      cs match {
        case '_' :: c :: rest => c.toUpper :: loop(rest)
        case c :: rest        => c :: loop(rest)
        case Nil              => Nil
      }
    loop(s.toList).mkString
  }

  /** Matches a default expression followed by a trailing type cast such as `::text`,
    * `::character varying` or `::integer[]`, capturing the expression before the cast.
    */
  private val CastSuffix = """(?s)(.*?)::[A-Za-z_][A-Za-z0-9_ .]*(?:\[\])?""".r

  /** Normalizes a raw column default into the bare literal text.
    *
    * Strips a trailing `::type` cast (PostgreSQL reports e.g. `'abc'::character varying` or
    * `'-1'::integer`), then removes surrounding single quotes and un-escapes doubled quotes (`''`).
    * Text that is neither cast nor quoted is returned unchanged.
    */
  private def unquoteDefault(raw: String): String = {
    val withoutCast = raw match {
      case CastSuffix(value) => value
      case other             => other
    }
    if (withoutCast.length >= 2 && withoutCast.startsWith("'") && withoutCast.endsWith("'"))
      withoutCast.substring(1, withoutCast.length - 1).replace("''", "'")
    else
      withoutCast
  }

  /** Parses a raw column default with `parse` after [[unquoteDefault]].
    *
    * Defaults that are not literals of the expected type (e.g. function calls) yield `None` and a
    * warning on stderr, instead of aborting the whole code generation with an exception.
    */
  private def parseDefault[A](columnName: String, raw: String)(parse: String => A): Option[A] = {
    val parsed = scala.util.Try(parse(unquoteDefault(raw))).toOption
    if (parsed.isEmpty)
      System.err.println(s"Ignoring default value $raw of column $columnName: not a literal of the column's type")
    parsed
  }

  /** Maps JDBC column metadata to the Scala type and default value used in generated code.
    *
    * Unknown SQL types are reported on stderr and mapped to `Nothing`. Nullable columns become
    * `Option[...]` and always get a default: the database default wrapped in `Some`, or `None`.
    */
  def getColumnInfo: MColumn => ColumnInfo = {
    case c @ MColumn(_, name, sqlType, typeName, _, _, _, nullable, _, columnDef, _, _, _, _, _, isAutoInc) =>
      // The database fills auto-increment columns itself, so their default expression
      // (e.g. a sequence call) must not become a Scala default argument.
      val defaultNotAuto = if (isAutoInc.contains(true)) None else columnDef

      def parsedDefault[A](parse: String => A): Option[A] =
        defaultNotAuto.flatMap(raw => parseDefault(name, raw)(parse))

      val (typ0, default0): (Type, Option[Term]) =
        (sqlType, typeName) match {
          case (_, "lo") => t"java.sql.Blob" -> None
          case (Types.NUMERIC, "numeric") =>
            // Validated exactly, then emitted as a string so the decimal digits are kept verbatim.
            t"BigDecimal" ->
              parsedDefault(s => new java.math.BigDecimal(s).toString).map(v => q"BigDecimal(${Lit.String(v)})")
          case (Types.DOUBLE, "float8") => t"Double" -> parsedDefault(_.toDouble).map(Lit.Double(_))
          case (Types.BIT, "bool")      => t"Boolean" -> parsedDefault(_.toBoolean).map(Lit.Boolean(_))
          case (Types.INTEGER, _)       => t"Int" -> parsedDefault(_.toInt).map(Lit.Int(_))
          case (Types.BIGINT, _)        => t"Long" -> parsedDefault(_.toLong).map(Lit.Long(_))
          case (Types.VARCHAR, "varchar" | "text") =>
            t"String" -> defaultNotAuto.map(raw => Lit.String(unquoteDefault(raw)))
          case (Types.DATE, "date") =>
            t"java.time.LocalDate" ->
              defaultNotAuto.collect { case "now()" | "CURRENT_DATE" => q"java.time.LocalDate.now()" }
          case (_, _) =>
            System.err.println("Don't know how to handle " + c)
            t"Nothing" -> None
        }

      val (typ, default): (Type, Option[Term]) =
        if (nullable.contains(true))
          t"Option[$typ0]" -> Some(default0.fold[Term](q"None")(d => q"Some($d)"))
        else
          typ0 -> default0

      val ident = Term.Name(toCamel(name))
      ColumnInfo(name, ident, ident, typ, default)
  }

  /** Builds a DBIO action that reads the columns of `table` and turns them into a [[SchemaInfo]].
    * The Table class is named after the camel-cased table name; the row class adds a `Row` suffix.
    */
  def getSchemaInfo(table: MTable) =
    table.getColumns
      .map(_.toList.map(getColumnInfo))
      .map { colInfos =>
        val ident = toCamel(table.name.name.capitalize)
        SchemaInfo(table.name, ident, ident + "Row", colInfos)
      }

  /** Generates the `@JsonCodec` row case class for one table, one parameter per column. */
  def rowStats(schemaInfo: SchemaInfo): List[Stat] = {
    val params = schemaInfo.columns.map { col =>
      Term.Param(Nil, col.rowFieldTerm, Some(col.scalaType), col.scalaDefault)
    }
    List(
      q"""
      @JsonCodec
      case class ${Type.Name(schemaInfo.rowClassName)}(..$params)
      """
    )
  }

  /** True for the schema that generated Table classes leave implicit ("public", PostgreSQL's default). */
  def isDefaultSchema(schema: String) = schema == "public"

  /** Generates the Table's `*` projection, mapping the columns to and from the row case class.
    *
    * A single column is mapped directly, since it cannot form a tuple. Up to 22 columns use a flat
    * tuple with the companion's `apply`/`unapply`. Wider tables exceed Scala 2's 22-element
    * tuple/function limit, so their columns are nested into tuples of at most 22 and mapped
    * with an explicit pattern-matching factory and extractor.
    *
    * @throws RuntimeException if the table has no columns
    */
  def mkStar(rowClassName: String, columns: List[ColumnInfo]) = {
    val companion = Term.Name(rowClassName)
    val terms = columns.map(_.tableFieldTerm)
    val numCols = columns.length
    val (tuple, factory, extractor): (Term, Term, Term) =
      if (numCols == 0)
        sys.error(s"Cannot generate a * projection for $rowClassName: the table has no columns")
      else if (numCols == 1)
        // Scalameta rejects one-element tuples and Function1 has no `tupled`.
        (terms.head, q"$companion.apply _", q"$companion.unapply")
      else if (numCols <= 22)
        (Term.Tuple(terms), q"($companion.apply _).tupled", q"$companion.unapply")
      else {
        // Repeatedly replaces the first 22 values with one value built by `group`, until one remains.
        @tailrec
        def group22[A](values: List[A])(group: List[A] => A): A = values match {
          case List(one) => one
          case _ =>
            val (first, second) = values.splitAt(22)
            group22(group(first) +: second)(group)
        }
        (group22[Term](terms)(Term.Tuple(_)),
         Term.PartialFunction(
           List(
             p"""
             case ${group22[Pat](terms.map(Pat.Var(_)))(Pat.Tuple(_))} =>
               $companion(..$terms)
             """
           )
         ),
         q"(rec: ${Type.Name(rowClassName)}) => Some(${group22[Term](terms.map(t => q"rec.$t"))(Term.Tuple(_))})")
      }
    q"def * = $tuple.<>({$factory}, $extractor)"
  }

  /** Generates the Slick `Table` class for one table plus a `lazy val` holding its `TableQuery`.
    * The schema name is passed to the Table constructor only when it is not the default schema;
    * tables qualified by a catalog are rejected.
    */
  def tableStats: SchemaInfo => List[Stat] = {
    case SchemaInfo(tableName, tableClassName, rowClassName, columns) =>
      val fields = columns.map {
        case ColumnInfo(columnName, tableFieldName, _, scalaType, _) =>
          q"""
          val ${Pat.Var(tableFieldName)} = column[$scalaType]($columnName)
          """
      }
      val star = mkStar(rowClassName, columns)
      val params = tableName match {
        case MQName(None, Some(schema), name) if !isDefaultSchema(schema) => List(q"Some($schema)", Lit.String(name))
        case MQName(None, _, name)                                         => List(Lit.String(name))
        case MQName(Some(_), _, _)                                         => sys.error("catalog not supported")
      }
      List(
        q"""
        class ${Type.Name(tableClassName)}(_tableTag: Tag)
            extends Table[${Type.Name(rowClassName)}](_tableTag, ..$params) {
          $star
          ..$fields
        }
        """,
        q"""
        lazy val ${Pat.Var(Term.Name(tableClassName))} = TableQuery[${Type.Name(tableClassName)}]
        """
      )
  }

  /** Turns a dotted path such as `a.b.c` into the term reference `a.b.c`. */
  def toRef(s: String): Term.Ref = {
    val parts = s.split('.').toList
    parts.tail.foldLeft[Term.Ref](Term.Name(parts.head)) { (qualifier, part) =>
      Term.Select(qualifier, Term.Name(part))
    }
  }

  /** Parses import clauses (e.g. `"java.time._"`) into a single import statement, or none if empty. */
  def imports(strings: List[String]): List[Stat] =
    if (strings.isEmpty)
      Nil
    else
      List(q"import ..${strings.map(_.parse[Importer].get)}")

  override def projectSettings =
    Seq(
      slickContainer := "Tables",
      slickProfileClass := slickConfig.value.getString("profile"),
      // Compare only the bare table name: the metadata's MQName carries a schema (see tableStats),
      // so comparing against MQName(None, None, ...) never excluded anything.
      slickTables := (_.name.name != "flyway_schema_history"),
      slickMetaGenExtraImports := Nil,
      slickSchemaInfo := {
        val config = slickConfig.value
        val profileName = slickProfileClass.value
        val tablesPred = slickTables.value
        // A Scala object's JVM class name ends in `$`; accept the profile with or without it.
        val moduleClassName = if (profileName.endsWith("$")) profileName else profileName + "$"
        val slickProfile =
          Class.forName(moduleClassName).getField("MODULE$").get(null).asInstanceOf[JdbcProfile]
        val db = JdbcBackend.Database.forConfig("", config)
        try {
          val tablesAction = slickProfile.defaultTables.map(_.filter(tablesPred))
          val infoAction = tablesAction.flatMap(tables => DBIO.sequence(tables.toList.map(getSchemaInfo)))
          Await.result(db.run(infoAction), Duration.Inf)
        } finally db.close()
      },
      slickMetaGenModels := {
        val container = slickContainer.value
        val outputDir = (Compile / sourceManaged).value
        val pkg = slickPackage.value
        val extraImports = slickMetaGenExtraImports.value
        val schemaInfos = slickSchemaInfo.value
        // Must differ from the file written by slickMetaGenTables, so that both generators
        // can run in the same project without overwriting each other's output.
        val filename = container + "Models.scala"
        val file = outputDir / pkg.replace(".", "/") / filename
        IO.write(
          file,
          q"""
          package ${toRef(pkg)} {
            import io.circe.generic.JsonCodec
            ..${imports(extraImports)}
            ..${schemaInfos.flatMap(rowStats)}
          }
          """.syntax
        )
        Seq[File](file)
      },
      slickMetaGenTables := {
        val container = slickContainer.value
        val outputDir = (Compile / sourceManaged).value
        val pkg = slickPackage.value
        val extraImports = slickMetaGenExtraImports.value
        // The generated import refers to the object by its Scala name, without the JVM `$`.
        val slickProfileName = toRef(slickProfileClass.value.stripSuffix("$"))
        val schemaInfos = slickSchemaInfo.value
        val filename = container + ".scala"
        val file = outputDir / pkg.replace(".", "/") / filename
        IO.write(
          file,
          q"""
          package ${toRef(pkg)} {
            import $slickProfileName.api._
            ..${imports(extraImports)}
            object ${Term.Name(container)} {
              ..${schemaInfos.flatMap(tableStats)}
            }
          }
          """.syntax
        )
        Seq[File](file)
      }
    )
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment