Last active
August 30, 2016 16:36
-
-
Save schleumer/bbce61cce86bbf398b5624bd36658a68 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import slick.model.Model | |
import slick.driver.PostgresDriver | |
import Config._ | |
import scala.concurrent._ | |
import scala.concurrent.duration._ | |
import scala.concurrent.ExecutionContext.Implicits.global | |
import slick.codegen.SourceCodeGenerator | |
import slick.jdbc.meta.MTable | |
import slick.model.Column | |
/**
 * Customized Slick code generator for the PostgreSQL database configured in `Config`.
 *
 * Compared to the stock `SourceCodeGenerator`, every table is mapped to an entity type:
 *  - the `*` projection always uses `<> (factory, extractor)`;
 *  - tables with more than 22 columns (Scala's tuple/function arity limit) use an
 *    HList-based factory/extractor, backed by a generated companion object that converts
 *    the entity from/to its HList;
 *  - the plain-SQL `GetResult` mapper builds rows through the same factory.
 *
 * Usage: `CustomizedCodeGenerator <outputFolder>`. The program calls `writeToFile` with
 * profile `slick.driver.PostgresDriver`, folder `<outputFolder>`, package `db`, container
 * `Models` and file name `Models.scala`.
 *
 * For a more advanced example see https://github.com/cvogt/slick-presentation/tree/scala-exchange-2013
 */
object CustomizedCodeGenerator {

  /**
   * Entry point: generates the model source and blocks until it has been written.
   *
   * @param args `args(0)` is the output folder passed to `writeToFile`.
   * @throws IllegalArgumentException if no output folder is given.
   */
  def main(args: Array[String]): Unit =
    try {
      require(args.nonEmpty, "usage: CustomizedCodeGenerator <outputFolder>")
      // Wait for generation to finish. The global ExecutionContext runs on daemon threads,
      // so returning early could let the JVM exit before the file is written, and a failed
      // Future would otherwise be dropped silently instead of failing the run.
      Await.result(
        codegen.map(_.writeToFile(
          "slick.driver.PostgresDriver",
          args(0),
          "db",
          "Models",
          "Models.scala"
        )),
        Duration.Inf
      )
    } finally {
      // Release the connection resources held by `db`, whether generation succeeded or not.
      db.close()
    }

  /** Database whose schema is read (`url` / `jdbcDriver` presumably come from `Config._`). */
  val db = PostgresDriver.api.Database.forURL(url, driver = jdbcDriver)

  /**
   * Reads the model of the default tables and wraps it in the customized generator.
   * As an object-level `val`, the schema query starts when this object is initialized.
   */
  val codegen = db.run {
    // NOTE(review): the `true` argument is presumably `ignoreInvalidDefaults` — confirm
    // against the Slick version in use.
    PostgresDriver.defaultTables.flatMap(PostgresDriver.createModelBuilder(_, true).buildModel)
  }.map(model => new SourceCodeGenerator(model) {
    // add custom import for added data types
    //override def code = "import my.package.Java8DateTypes._" + "\n" + super.code

    override def Table = new Table(_) {

      // More than 22 columns exceeds Scala's TupleN/FunctionN arity limit, so such rows are
      // mapped through the HList companion emitted by `EntityType` instead of `tupled`/`unapply`.
      // NOTE(review): that companion is only emitted when `classEnabled` is false — confirm
      // this holds for every table wider than 22 columns.
      private def hlistMapped: Boolean = columns.size > 22

      // Use different factory and extractor functions for tables with > 22 columns
      override def factory =
        if (columns.size == 1) TableClass.elementType
        else if (!hlistMapped) s"${TableClass.elementType}.tupled"
        else s"${EntityType.name}.apply"

      override def extractor =
        if (!hlistMapped) s"${TableClass.elementType}.unapply"
        else s"${EntityType.name}.unapply"

      override def EntityType = new EntityTypeDef {
        override def code = {
          // Constructor parameters, carrying the schema default when the column has one.
          val args = columns.map { c =>
            c.default.fold(s"${c.name}: ${c.exposedType}")(v => s"${c.name}: ${c.exposedType} = $v")
          }
          val callArgs = columns.map(_.name)
          val types = columns.map(_.exposedType)
          if (classEnabled) {
            val prns = (parents.take(1).map(" extends " + _) ++ parents.drop(1).map(" with " + _)).mkString("")
            s"""case class $name(${args.mkString(", ")})$prns"""
          } else {
            // Case class plus a companion converting from/to the row's HList, used as the
            // factory/extractor for wide tables.
            s"""
/** Constructor for $name providing default values if available in the database schema. */
case class $name(${args.mkString(", ")})
type ${name}List = ${compoundType(types)}
object $name {
def apply(hList: ${name}List): $name = hList match {
case ${compoundValue(callArgs)} => new $name(${callArgs.mkString(", ")})
case _ => throw new Exception("malformed HList")
}
def unapply(row: $name) = Some(${compoundValue(callArgs.map(a => s"row.$a"))})
}
""".trim
          }
        }
      }

      override def PlainSqlMapper = new PlainSqlMapperDef {
        override def code = {
          val positional = compoundValue(columnsPositional.map { c =>
            if (c.fakeNullable || c.model.nullable) s"<<?[${c.rawType}]" else s"<<[${c.rawType}]"
          })
          // One implicit GR parameter per distinct exposed column type.
          val dependencies = columns.map(_.exposedType).distinct.zipWithIndex
            .map { case (t, i) => s"""e$i: GR[$t]""" }
            .mkString(", ")
          val rearranged = compoundValue(desiredColumnOrder.map(i => if (hlistMapped) s"r($i)" else tuple(i)))
          // Rows are always built through the table's factory (HList-aware for wide tables).
          def result(args: String) = s"$factory($args)"
          val body =
            if (autoIncLastAsOption && columns.size > 1) {
              s"""
val r = $positional
import r._
${result(rearranged)} // putting AutoInc last
""".trim
            } else {
              result(positional)
            }
          s"""
implicit def $name(implicit $dependencies): GR[${TableClass.elementType}] = GR{
prs => import prs._
${indent(body)}
}
""".trim
        }
      }

      override def TableClass = new TableClassDef {
        // Unconditionally map `*` through factory/extractor, including tables wider than 22 columns.
        override def star = {
          val struct = compoundValue(columns.map(c => if (c.fakeNullable) s"Rep.Some(${c.name})" else c.name))
          s"def * = $struct <> ($factory, $extractor)"
        }
      }

      // override column generator to add additional types
      // override def Column = new Column(_) {
      //   override def rawType = {
      //     typeMapper(model).getOrElse(super.rawType)
      //   }
      // }
    }
  })

  // def typeMapper(column: Column): Option[String] = {
  //   column.tpe match {
  //     case "java.sql.Date" => Some("java.time.LocalDate")
  //     case "java.sql.Timestamp" => Some("java.time.LocalDateTime")
  //     case _ => None
  //   }
  // }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment