Skip to content

Instantly share code, notes, and snippets.

@chuwy
Last active August 22, 2023 19:54
Show Gist options
  • Save chuwy/8f664c5e2ac57d4513702beb9e5f261e to your computer and use it in GitHub Desktop.
Save chuwy/8f664c5e2ac57d4513702beb9e5f261e to your computer and use it in GitHub Desktop.
Scala macro for building type-safe SQL Fragments
import java.time.LocalDateTime
import scala.quoted.*
import scala.deriving.Mirror
import quotidian.*
import quotidian.syntax.*
import io.github.iltotore.iron.*
import cats.data.NonEmptyList
import skunk.*
import skunk.codec.numeric.int8
import skunk.implicits.*
import io.foldables.ratio.common.Primitive
/** A database table whose columns are exposed as structurally-typed members.
  *
  * Instances are produced by the `Table.build` macro, which casts an anonymous `Table` to a
  * refinement with one `TypedColumn[A]` member per column plus `except`, `select` and `all`.
  * Because this trait is `Selectable`, those refined members are resolved by `selectDynamic`.
  *
  * @param tableName SQL name of the table
  * @param columns   every column of the table
  */
trait Table(val tableName: Table.Name, val columns: Table.Columns) extends Selectable:

  /** Union type of all column names */
  type Names

  /** Tuple of all column names */
  type Columns

  /** Tuple of all column names; its precise type comes from the refinement applied by `Table.build` */
  val all: Any

  /** Runtime dispatch for the structural members added by `Table.build`.
    *
    *   - `except`: a function dropping one column name, or a list of names
    *   - `select`: a function keeping only the listed column names (unknown names are skipped)
    *   - `all`: the tuple of all column names
    *   - any other name: the `TypedColumn` with that name
    *
    * @throws NoSuchElementException if `name` is none of the above and not a column of this table
    */
  transparent inline def selectDynamic(name: String): Any =
    if name == Table.ExceptMethodName then
      (toExclude: List[String] | String) =>
        toExclude match
          case single: String     => Table.Columns(columns.get.filterNot(c => c.n == single))
          case many: List[String] => Table.Columns(columns.get.filterNot(c => many.contains(c.n)))
    else if name == Table.SelectMethodName then
      (toInclude: List[String]) =>
        // `wanted` rather than `name`: do not shadow the method parameter
        Table.Columns(toInclude.flatMap(wanted => columns.get.find(tc => tc.n == wanted)))
    else if name == Table.AllMethodName then
      all
    else
      columns.get
        .find(_.n == name)
        .getOrElse(throw NoSuchElementException(s"Table ${tableName.unbox} has no column named $name"))

  /** Namespace for table-wide queries, e.g. `table.in.count.q`. */
  object in:
    object count:
      /** `SELECT COUNT(*) FROM <table>` as a parameterless fragment. */
      def f: Fragment[Void] =
        sql"SELECT COUNT(*) FROM ${tableName.f}"

      /** The row-count query, decoding its single result column with `int8`. */
      def q: Query[Void, Long] =
        f.query(int8)
/** Companion of `Table`: column and name helpers, naming strategies and the `build` macro. */
object Table:

  /** A non-empty selection of columns. */
  case class Columns(get: NonEmptyList[TypedColumn[?]]):

    /** The column names joined with `", "`, spliced as raw SQL text (`#$` interpolation). */
    def f: Fragment[Void] =
      sql"#${get.toList.map(_.n).mkString(", ")}"

    /** Every column name prefixed with the alias `short`, e.g. `u.id`. */
    def as(short: String): List[String] =
      get.toList.map(column => s"$short.${column.n}")

  object Columns:
    /** Builds a selection from a plain list.
      *
      * @throws IllegalArgumentException if `list` is empty, e.g. when `except` removed every
      *   column or `select` matched none
      */
    def apply(list: List[TypedColumn[?]]): Columns =
      new Columns(
        NonEmptyList
          .fromList(list)
          .getOrElse(throw IllegalArgumentException("Table.Columns requires at least one column"))
      )

  /** Name of the structural member that drops columns. */
  val ExceptMethodName: "except" = "except"

  /** Type of `except`: drops a single column name or a list of names. */
  type Except[A] = (A | List[A]) => Columns

  /** Name of the structural member holding the tuple of all column names. */
  val AllMethodName: "all" = "all"

  /** Name of the structural member that keeps only the listed columns. */
  val SelectMethodName: "select" = "select"

  /** Type of `select`: keeps only the listed column names. */
  type Select[A] = List[A] => Columns

  /** Where a table's SQL name comes from.
    *
    * @param service optional prefix placed before the table name
    * @param table   explicit table name; the case-class name is used when absent
    */
  final case class Config(service: Option[String], table: Option[String]):
    /** `[service_]table` in snake_case, with `typeName` standing in for a missing `table`. */
    def getTableName(typeName: String): String =
      val t = table.getOrElse(typeName)
      val path = service.fold(NonEmptyList.one(t))(s => NonEmptyList.of(s, t))
      NameStrategy.FullSnake.transform(path)

  object Config:
    /** No service prefix and no explicit table name. */
    inline def default: Config = Config(None, None)

    def apply(service: String, table: String): Config =
      Config(Some(service), Some(table))

  /** How a path of identifier segments (outermost first) becomes a single SQL name. */
  enum NameStrategy:
    /** snake_case of the last segment only */
    case Snake
    /** every segment snake_cased, joined with `_` */
    case FullSnake
    /** the last segment, unchanged */
    case Camel
    /** all segments concatenated, each one after the first capitalized */
    case FullCamel

    def transform(path: NonEmptyList[String]): String =
      this match
        case Snake     => snakeCase(path.last)
        case FullSnake => path.toList.map(snakeCase).mkString("_")
        case Camel     => path.last
        case FullCamel => (path.head :: path.tail.map(_.capitalize)).mkString("")

  /** A column named `n` whose Scala type is `C`. */
  final case class TypedColumn[C](n: String):
    /** `n = <param>`, with `c` encoding the bound parameter. */
    infix def eql(c: Encoder[C]): Fragment[C] = sql"#$n = $c"

    /** The same column qualified by a table name or alias, e.g. `t.n`. */
    def as(table: String): TypedColumn[C] =
      this.copy(n = s"$table.$n")

    /** The bare column name as a fragment. */
    def f: Fragment[Void] =
      sql"#$n"

    /** `n = current_timestamp`; only callable when `C` is `Option[LocalDateTime]`. */
    def currentTimestamp(using C =:= Option[LocalDateTime]): Fragment[Void] =
      sql"#$n = current_timestamp"

    /** `n = n + 1`; only callable when `C` is an Iron-refined `Int`. */
    def increment[A](using C =:= IronType[Int, A]): Fragment[Void] =
      sql"#$n = #$n + 1"

  /** SQL name of a table. */
  opaque type Name = String

  object Name:
    def apply(str: String): Name = str

    extension (name: Name)
      /** The underlying string. */
      def unbox: String = name
      /** The name spliced as raw SQL text. */
      def f: Fragment[Void] = sql"#${name}"
      /** `name AS short`. */
      def as(short: String): Fragment[Void] = sql"#${name} AS #${short}"

  /** Derives a `Table` for the case class `T`, refined with one typed member per column.
    *
    * Fields whose type has a `Primitive` instance become columns. Nested case classes are
    * flattened, and their leaf fields become columns. Compilation fails on unsupported field
    * types, duplicate column names, or a table without columns.
    */
  transparent inline def build[T](inline config: Config)(using Mirror.ProductOf[T]): Table =
    ${ buildFromBlockImpl[T]('config) }

  /** Macro implementation behind `build`. */
  def buildFromBlockImpl[T: Type](using quotes: Quotes)(configExpr: Expr[Config]): Expr[Table] =
    import quotes.reflect.*

    val mirror = MacroMirror.summonProduct[T]
    val tableName = '{ ${ configExpr }.getTableName(${ Expr(mirror.label) }) }

    // (column name, column type) for every leaf field, in declaration order
    val nameTypeMap: NonEmptyList[(String, TypeRepr)] =
      // Flattens fields into (label path, type) pairs; `root` holds enclosing labels, innermost first.
      def go(root: List[String])(ls: List[(String, TypeRepr)]): List[(NonEmptyList[String], TypeRepr)] =
        ls.flatMap { (label, tpe) =>
          tpe.asType match
            case '[t] =>
              Expr.summon[Primitive[t]] match
                case Some(_) =>
                  List(NonEmptyList(label, root).reverse -> tpe)
                case None =>
                  MacroMirror.summon[t] match
                    case Right(pm: MacroMirror.ProductMacroMirror[quotes.type, t]) =>
                      go(label :: root)(pm.elemLabels.zip(pm.elemTypes))
                    case _ =>
                      report.errorAndAbort(s"Couldn't synthesize instance neither for Primitive nor Table for ${tpe.show}")
        }

      val zipped = mirror.elemLabels.zip(mirror.elemTypes)
      NonEmptyList.fromList(go(Nil)(zipped).map((path, tpe) => NameStrategy.Snake.transform(path) -> tpe)) match
        case Some(nel) => nel
        case None      => report.errorAndAbort("Could not derive columns. A Table must contain at least one Column")

    val columnNames = nameTypeMap.map(_._1)

    // Snake keeps only the last path segment, so fields of different nested records can collide.
    val duplicates = columnNames.toList.diff(columnNames.toList.distinct).distinct
    if duplicates.nonEmpty then
      report.errorAndAbort(
        s"Not all column names are unique (${columnNames.toList.mkString(", ")}); duplicated: ${duplicates.mkString(", ")}"
      )

    val columns = Expr.ofList(nameTypeMap.toList.map { (n, t) =>
      t.asType match
        case '[tpe] =>
          val name = Expr(n)
          '{ new TypedColumn[tpe](${ name }) }
    })

    // Union of the column-name literal types: "a" | "b" | ...
    val columnName: TypeRepr =
      columnNames
        .map[TypeRepr](name => ConstantType(StringConstant(name)))
        .reduceLeft[TypeRepr]((acc, next) => OrType(acc, next))

    val columnNamesTuple = Expr.ofTupleFromSeq(columnNames.toList.map(name => Expr(name)))

    // Table { val <col>: TypedColumn[<type>] ... } — one refinement per column
    val refinementWithColumns: TypeRepr =
      nameTypeMap.foldLeft(TypeRepr.of[Table]) { case (acc, (name, tpr)) =>
        Refinement(parent = acc, name = name, info = TypeRepr.of[TypedColumn].appliedTo(tpr))
      }
    val refinementWithExcept =
      Refinement(parent = refinementWithColumns, name = ExceptMethodName, info = TypeRepr.of[Table.Except].appliedTo(columnName))
    val refinementWithSelect =
      Refinement(parent = refinementWithExcept, name = SelectMethodName, info = TypeRepr.of[Table.Select].appliedTo(columnName))
    val refinementFinal =
      Refinement(parent = refinementWithSelect, name = AllMethodName, info = columnNamesTuple.asTerm.tpe)

    (refinementFinal.asType, columnNamesTuple.asTerm.tpe.asType, columnName.asType) match
      case ('[refinementType], '[allTuple], '[union]) =>
        '{
          (new Table(Name(${ tableName }), Columns(NonEmptyList.fromListUnsafe(${ columns }))) {
            val all = ${ columnNamesTuple }
          }).asInstanceOf[Table { type Names = union; type Columns = allTuple } & refinementType]
        }

  /** Converts camelCase / PascalCase to snake_case, keeping acronyms together
    * (`HTTPServer` → `http_server`, `userId` → `user_id`).
    *
    * Lower-cases with `Locale.ROOT` so generated SQL names do not depend on the JVM's
    * default locale (e.g. a Turkish locale would turn `I` into a dotless `ı`).
    */
  def snakeCase(str: String): String =
    str
      .replaceAll("([A-Z]+)([A-Z][a-z])", "$1_$2")
      .replaceAll("([a-z\\d])([A-Z])", "$1_$2")
      .toLowerCase(java.util.Locale.ROOT)
@kitlangton
Copy link

https://gist.github.com/chuwy/8f664c5e2ac57d4513702beb9e5f261e#file-table-scala-L162
`MacroMirror` has an `elems` field that exposes each element's label and type together, along with some other conveniences, which might be useful here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment