Skip to content

Instantly share code, notes, and snippets.

@lukestewart13
Last active July 6, 2023 12:18
Show Gist options
  • Save lukestewart13/a50d1392f0233b831f7211b20c2cd682 to your computer and use it in GitHub Desktop.
Save lukestewart13/a50d1392f0233b831f7211b20c2cd682 to your computer and use it in GitHub Desktop.
Macro annotation that generates partial models (case classes in which every field of the annotated model is optional) inside the model's companion object. It also generates a default Play JSON deserializer (`Reads`) for the restricted partial model.
import scala.annotation.StaticAnnotation
import scala.language.experimental.macros
import scala.reflect.macros.whitebox
/**
 * Macro annotation for case classes. It adds two "partial" case classes, plus a Play JSON
 * `Reads` for the restricted one, to the companion object of the annotated class (the
 * companion is created when it does not exist yet).
 *
 * Every field of the annotated class is wrapped in `Option` and defaults to `None` in the
 * partial classes; `update` copies the model, replacing only the fields that are defined.
 *
 * Example:
 * {{{
 * @GeneratePartialModels(fieldsToRestrict = "id")
 * case class User(id: Int, name: String, email: Option[String])
 *
 * // the macro generates this (or adds it to the existing companion object):
 * object User {
 *
 *   case class PartialUser(id: Option[Int] = None, name: Option[String] = None, email: Option[Option[String]] = None) {
 *     def update(model: User): User = {
 *       model.copy(id = id.getOrElse(model.id), name = name.getOrElse(model.name), email = email.getOrElse(model.email))
 *     }
 *   }
 *
 *   case class RestrictedPartialUser(name: Option[String] = None, email: Option[Option[String]] = None) {
 *     def update(model: User): User = {
 *       model.copy(name = name.getOrElse(model.name), email = email.getOrElse(model.email))
 *     }
 *   }
 *
 *   implicit val restrictedPartialReads: Reads[RestrictedPartialUser] = Json.reads[RestrictedPartialUser]
 * }
 * }}}
 *
 * @param fieldsToRestrict names of the fields that are left out of the generated
 *                         `RestrictedPartial...` class; they must be given as string literals
 */
// Without macro-annotation support the annotation would otherwise be silently ignored;
// compileTimeOnly turns that situation into an explicit compile error.
@scala.annotation.compileTimeOnly("enable macro paradise (or -Ymacro-annotations) to expand @GeneratePartialModels")
class GeneratePartialModels(fieldsToRestrict: String*) extends StaticAnnotation {
  def macroTransform(annottees: Any*): Any = macro GeneratePartialModelsImpl.impl
}
/**
 * Implementation of the `GeneratePartialModels` macro annotation.
 *
 * For an annotated case class `X` it emits the class unchanged, followed by a companion object
 * that contains `PartialX`, `RestrictedPartialX` and an implicit Play JSON `Reads` for
 * `RestrictedPartialX`. Members of an already existing companion are kept, after the generated ones.
 */
object GeneratePartialModelsImpl {

  /**
   * Macro entry point.
   *
   * @param c         whitebox macro context; `c.prefix` holds the annotation instantiation
   * @param annottees the annotated case class, optionally followed by its companion object
   * @return the case class together with the (new or extended) companion object;
   *         compilation is aborted when the annottee is not a supported case class
   */
  def impl(c: whitebox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
    import c.universe._

    // Collects the string literals passed to the annotation, given positionally or as named
    // arguments. Arguments that are not string literals are ignored.
    def extractAnnotationParameters(tree: Tree): Seq[String] = tree match {
      case q"new $annotation( ..$params )" =>
        params.collect {
          case q"$argName = ${Literal(Constant(field: String))}" => field
          case Literal(Constant(field: String))                  => field
        }
      case _ => Nil
    }

    // Splits a case class declaration into its parts; aborts for anything the pattern does not cover.
    def extractCaseClassParts(classDecl: ClassDef) = classDecl match {
      case q"@..$annots case class $className(..$fields) extends ..$parents { ..$body }" =>
        (annots, className, fields, parents, body)
      case _ => c.abort(c.enclosingPosition, "Macro applies to case classes only")
    }

    // Turns every constructor parameter that is not excluded into `name: Option[T] = None`.
    // Matching on the plain ValDef shape (instead of separate quasiquote patterns for parameters
    // with and without a default value) makes the exclusion apply to parameters that declare a
    // default as well, and avoids a MatchError for `var` parameters.
    def extractOptionalParams(fields: Seq[Tree], excludeValueNames: Seq[String]): List[Tree] =
      fields.toList.collect {
        case ValDef(mods, name, tpt, _) if !excludeValueNames.contains(name.decodedName.toString) =>
          q"$mods val $name: Option[$tpt] = None"
      }

    // Builds `case class <partialClassName>(<optionalParams>)` with an `update` method that copies
    // the model and overrides only the fields that are defined on the partial instance.
    def buildPartialClass(optionalParams: List[Tree], className: TypeName, partialClassName: TypeName): Tree = {
      val copies = optionalParams.collect {
        case ValDef(_, name, _, _) => q"$name = $name.getOrElse(model.$name)"
      }
      q"""case class $partialClassName ( ..$optionalParams ) {
            def update(model: $className): $className = {
              model.copy(
                ..$copies
              )
            }
          }
       """
    }

    // Kept as two named entry points; both produce the same class shape.
    def extractPartialClass(optionalParams: List[Tree], className: TypeName, partialClassName: TypeName): Tree =
      buildPartialClass(optionalParams, className, partialClassName)

    def extractRestrictedClass(optionalParams: List[Tree], className: TypeName, restrictedClassName: TypeName): Tree =
      buildPartialClass(optionalParams, className, restrictedClassName)

    // Implicit Play JSON deserializer for the restricted partial class.
    def extractRestrictedClassReads(restrictedClassName: TypeName): Tree =
      q"""implicit val restrictedPartialReads: play.api.libs.json.Reads[$restrictedClassName] = play.api.libs.json.Json.reads[$restrictedClassName]"""

    // Re-creates the companion with the generated members placed in front of its own body.
    // The full object pattern is used so that modifiers, early definitions and the self-type
    // of an existing companion are carried over instead of being dropped.
    def extendCompanion(companionDecl: ModuleDef, generated: List[Tree]): Tree = companionDecl match {
      case q"$mods object $compName extends { ..$earlyDefs } with ..$compParents { $self => ..$compBody }" =>
        val members = generated ++ compBody
        q"""$mods object $compName extends { ..$earlyDefs } with ..$compParents { $self =>
              ..$members
            }
         """
      case _ => c.abort(c.enclosingPosition, "Unsupported companion object declaration")
    }

    def modifiedDeclaration(classDecl: ClassDef, optCompanionDecl: Option[ModuleDef]): c.Expr[Any] = {
      val excludeValueNames = extractAnnotationParameters(c.prefix.tree)
      val (_, className, fields, _, _) = extractCaseClassParts(classDecl)

      val partialClassName = TypeName(s"Partial$className")
      val restrictedClassName = TypeName(s"RestrictedPartial$className")

      val generated = List(
        extractPartialClass(extractOptionalParams(fields, Seq.empty), className, partialClassName),
        extractRestrictedClass(extractOptionalParams(fields, excludeValueNames), className, restrictedClassName),
        extractRestrictedClassReads(restrictedClassName)
      )

      val companion = optCompanionDecl match {
        case Some(companionDecl) => extendCompanion(companionDecl, generated)
        case None =>
          val compName = className.toTermName
          q"""object $compName {
                ..$generated
              }
           """
      }

      c.Expr[Any](q"$classDecl; $companion")
    }

    annottees.map(_.tree).toList match {
      case List(classDecl: ClassDef) => modifiedDeclaration(classDecl, None)
      case List(classDecl: ClassDef, companionDecl: ModuleDef) => modifiedDeclaration(classDecl, Some(companionDecl))
      case _ => c.abort(c.enclosingPosition, "Invalid annottee")
    }
  }
}
@Ivoyaa
Copy link

Ivoyaa commented Jul 6, 2023

Thank you!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment