Created November 5, 2015 16:24
Macro that auto-creates lifted functions for Free monads.
package free
import scala.annotation.{StaticAnnotation, compileTimeOnly}
import scala.language.experimental.macros
import scala.language.higherKinds
import scala.reflect.macros.whitebox
/**
 * Macro annotation that adds the lifting functions for the operations of `Op` to the
 * annotated class, trait or object.
 *
 * Usage:
 * <pre>
 * sealed trait Op[+A]
 * case class MyOp(a: String) extends Op[Unit]
 * &commat;AddLiftingFunctions[Op]('Mon) object monadic
 * import monadic._
 * val a: Mon[Unit] = myOp("hello")
 * </pre>
 *
 * @tparam Op sealed trait of the Operations
 * @param typeName name of the generated type alias (`'Mon` in the example above)
 */
@compileTimeOnly("enable macro paradise to expand macro annotations")
class AddLiftingFunctions[Op[_]](typeName: Symbol) extends StaticAnnotation {
  /** Expansion hook of the annotation; implemented by `FreeMacro.addLiftFunctionsAnnotation_impl`. */
  def macroTransform(annottees: Any*): Any = macro FreeMacro.addLiftFunctionsAnnotation_impl
}
/**
 * Usage:
 * <pre>
 * sealed trait Op[+A]
 * case class MyOp(a: String) extends Op[Unit]
 * val monadic = FreeMacro.liftFunctions[Op]('Mon)
 * import monadic._
 * val a: Mon[Unit] = myOp("hello")
 * </pre>
 */
object FreeMacro {
// Public macro entry points; both take the operation type `F` and the name of the type alias to generate.
// NOTE(review): the closing brace of this object is not present in this copy of the file — confirm
// where the object ends when restoring it.
// Expands through `liftFunctions_impl` of the `FreeMacro` macro class.
def liftFunctions[F[_]](typeName: Symbol): Any = macro FreeMacro.liftFunctions_impl[F]
// Expands through `liftFunctionsVampire_impl` of the `FreeMacro` macro class.
def liftFunctionsVampire[F[_]](typeName: Symbol): Any = macro FreeMacro.liftFunctionsVampire_impl[F]
//Private stuff below
// Annotation used by the "vampire" macro bodies: it carries a tree as its argument, which
// `companionFromVampire` reads back from the annotation found by `macroAnnotation[vampire]`.
// NOTE(review): the original comment read "Vampire-body, see" and the reference that followed
// is missing in this copy — restore the link if it is known.
class vampire(tree: Any) extends StaticAnnotation
// Macro bundle holding the implementations referenced by `macro FreeMacro.<name>` in the
// annotation class and in the `FreeMacro` object.
// NOTE(review): the closing brace of this class is not present in this copy of the file.
class FreeMacro(val c: whitebox.Context) {
// Brings the context's universe (trees, names, the `q` interpolator) into scope for the methods below.
import c.universe._
/** Implementation of `FreeMacro.liftFunctions`: delegates to `generateAnonClass` with the vampire flag off. */
def liftFunctions_impl[F[_]](typeName: Expr[Any])(implicit t: c.WeakTypeTag[F[_]]) = {
  val useVampireBodies = false
  generateAnonClass[F](typeName, useVampireBodies)(t)
}
/** Implementation of `FreeMacro.liftFunctionsVampire`: delegates to `generateAnonClass` with the vampire flag on. */
def liftFunctionsVampire_impl[F[_]](typeName: Expr[Any])(implicit t: c.WeakTypeTag[F[_]]) = {
  val useVampireBodies = true
  generateAnonClass[F](typeName, useVampireBodies)(t)
}
/**
 * Expands to `new { ... }` whose members are the trees produced by `generate` for the
 * operation type `F`.
 *
 * @param typeNameExpr tree of the type-name argument; it must be an application whose single
 *                     argument is a string literal, otherwise the pattern below fails with a MatchError
 * @param vampire      passed on to `generate` as `useVampire`
 */
private def generateAnonClass[F[_]](typeNameExpr: Expr[Any], vampire: Boolean)(implicit t: c.WeakTypeTag[F[_]]) = {
  val Apply(_, Literal(Constant(typeName: String)) :: Nil) = typeNameExpr.tree
  // Forward the caller's flag. It used to be hard-coded to `false`, which made
  // liftFunctionsVampire_impl expand exactly like liftFunctions_impl.
  val mod = generate(TermName(typeName), t.tpe.typeSymbol, vampire)
  c.Expr(q"new { ..$mod }")
}
/**
 * Implementation of the `AddLiftingFunctions` macro annotation.
 *
 * Reads the operation type and the type-name argument from the macro application and adds the
 * trees produced by `generate` to the body of the annotated class, trait or object. Any other
 * kind of annottee aborts compilation.
 *
 * NOTE(review): this copy of the method is damaged — the scrutinee of the `match`, the closing
 * braces and whatever turned `mod` into the returned `Expr[Any]` are missing. Restore them from
 * the original source before compiling.
 */
def addLiftFunctionsAnnotation_impl(annottees: Expr[Any]*): Expr[Any] = {
// Take the annotation application apart: `new <annotation>[Op](typeName).macroTransform(...)`.
val q"new $_[$opIdent](${typeNameTree: Tree}).macroTransform(..$_)" = c.macroApplication
// Typecheck a dummy expression of type Op[Unit] to resolve the identifier to its type symbol.
val opBase = c.typecheck(q"???.asInstanceOf[$opIdent[Unit]]").tpe.typeSymbol
// The type-name argument must be an application with a single string literal argument.
val Apply(_, Literal(Constant(typeNameString: String)) :: Nil) = typeNameTree
val typeName = TermName(typeNameString)
// NOTE(review): the expression being matched is missing here; presumably the list of annottee trees — confirm.
val mod = match {
case ClassDef(mods, name, tparams, Template(parents, self, body)) :: rest ⇒ //class/trait
// Keep the first body element in front and insert the generated members right after it.
// NOTE(review): presumably that first element is the constructor — confirm.
val (initBody, restBody) = body.splitAt(1)
val t2 = Template(parents, self, initBody ++ generate(typeName, opBase) ++ restBody)
ClassDef(mods, name, tparams, t2) :: rest
case ModuleDef(mods, name, Template(parents, self, body)) :: rest ⇒ // object
// For objects the generated members are placed in front of the existing body.
val t2 = Template(parents, self, generate(typeName, opBase) ++ body)
ModuleDef(mods, name, t2) :: rest
case a :: rest ⇒
c.abort(c.enclosingPosition, "AddLiftingFunctions annotation only supported on classes and objects")
/**
 * Builds the members to splice into the target: a type alias named after `name`, an implicit
 * `monad` value, and one function per operation as produced by `forImplementation`.
 *
 * Aborts compilation when `opBase` is not sealed or has no known direct subclasses.
 *
 * @param name       name for the generated type alias (converted with `toTypeName`)
 * @param opBase     symbol of the operation base type
 * @param useVampire passed on to `forImplementation`
 * @return the alias tree, the monad tree, then the generated functions
 *
 * NOTE(review): this copy of the method is damaged, see the notes on the individual lines.
 */
private def generate(name: Name, opBase: Symbol, useVampire: Boolean = false): List[Tree] = {
val freeTypeName = name.toTypeName
// NOTE(review): the right-hand sides of both quasiquotes below start with "=[" — the type and
// value names that stood in front of each "[" are missing in this copy.
val freeTypeTree =
q"""type $freeTypeName[A] =[({ type λ[α] =[${opBase.asType}, α]})#λ, A]"""
val monadDeclTree =
q"""implicit val monad =[({ type λ[α] =[${opBase.asType}, α]})#λ]"""
val opClass = opBase.asClass
// NOTE(review): the `${}` interpolations in the two messages below are empty; presumably they
// named the base class — confirm.
if (!opClass.isSealed)
c.abort(c.enclosingPosition, s"The base class ${} of the free monad is not sealed")
// NOTE(review): the message below points to `FreeMonad.liftFunctions`, but the object defined in
// this file is `FreeMacro`; it also contains the typo "ans" (likely "and"). Fix both once the
// block is restored.
if (opClass.knownDirectSubclasses.isEmpty)
c.abort(c.enclosingPosition, s"The base class ${} of the free monad has no subclasses. " +
s"If you're sure you have subclasses ans use @AddLiftingFunctions then this is a compilation order problem. " +
s"In that case please use FreeMonad.liftFunctions.")
// NOTE(review): the collection this partial function is applied to is missing; presumably the
// known direct subclasses of `opClass` — confirm.
val functions = {
case s: ClassSymbol ⇒
forImplementation(opBase.asType.toType, freeTypeName, useVampire)(s)
freeTypeTree :: monadDeclTree :: functions
/**
 * Creates "myOp(text: String): FT[Unit]" from "case class MyOp(text: String)".
 *
 * @param base       type of the operation base; used to look up the type argument the
 *                   implementation passes to it
 * @param freeType   name of the generated type alias, used as the result type constructor
 * @param useVampire when true the generated method is macro-backed (`vampire<N>_impl`), otherwise
 *                   it gets an ordinary body built from the companion
 * @param opImpl     class symbol of one operation implementation
 *
 * NOTE(review): this copy of the method is damaged, see the notes on the individual lines.
 */
private def forImplementation(base: Type, freeType: TypeName, useVampire: Boolean)(opImpl: ClassSymbol): Tree = {
// inspired by
// NOTE(review): the reference after "inspired by" and the argument of classNameFunctionName
// below are missing; presumably the argument was the name of `opImpl` — confirm.
val name = TermName(classNameFunctionName(
//TODO handle MyOperation[A] extends Op[A]
// First type argument the implementation passes to the base type (Unit for `extends Op[Unit]`).
val A = opImpl.typeSignature.baseType(base.typeSymbol).typeArgs.head
val companion = opImpl.companion
val params = caseClassFields(opImpl.typeSignature)
if (useVampire) {
// Vampire variant: parameters are renamed positionally to in1, in2, ...
// NOTE(review): the collection the partial function below is applied to is missing; the
// `((_, tpe), index)` pattern suggests `params` zipped with its index — confirm.
val paramDefs = {
case ((_, tpe), index) ⇒
val name = TermName("in" + (index + 1))
q"""$name: $tpe"""
// Implementations exist only for 0 to 5 parameters (vampire0_impl .. vampire5_impl).
// NOTE(review): the `${}` in the message is empty; presumably it named the operation class.
if (paramDefs.size > 5) c.abort(c.enclosingPosition, s"More parameters in ${} than supported " +
"by the FreeMacro. Please tell the maintainer to extend it.")
// NOTE(review): the opening of the quasiquote that ends on the `def $name` line after the next
// one is missing, as is the text between `macro` and `$vampire`.
val vampire = TermName(s"vampire${paramDefs.size}_impl")
def $name(..$paramDefs): $freeType[$A] = macro$vampire"""
} else {
// Plain variant: parameters keep the names of the case class fields.
// NOTE(review): the right-hand side of `paramNames` and the receiver of the lambda for
// `paramDefs` are missing; the final quasiquote ends in an unmatched ")" because the call that
// wrapped `$companion(...)` was lost.
val paramNames =
val paramDefs = { p ⇒ q"""${p._1}: ${p._2}""" }
q"""def $name(..$paramDefs): $freeType[$A] =$companion(..$paramNames))"""
//Vampire Methods to avoid structural type warning
// One implementation per supported parameter count (0 to 5). `forImplementation` selects
// `vampire<N>_impl` by the number of parameters and aborts when there are more than 5.
// Each body applies the tree returned by `companionFromVampire` to the call's arguments.
// NOTE(review): the bodies of vampire0_impl and vampire1_impl are missing in this copy, and the
// remaining quasiquotes end in an unmatched ")" — the call that wrapped
// `$companionFromVampire(...)` was lost. Restore from the original source.
def vampire0_impl() =
def vampire1_impl(in1: Expr[Any]) =
def vampire2_impl(in1: Expr[Any], in2: Expr[Any]) =
q"$companionFromVampire($in1, $in2))"
def vampire3_impl(in1: Expr[Any], in2: Expr[Any], in3: Expr[Any]) =
q"$companionFromVampire($in1, $in2, $in3))"
def vampire4_impl(in1: Expr[Any], in2: Expr[Any], in3: Expr[Any], in4: Expr[Any]) =
q"$companionFromVampire($in1, $in2, $in3, $in4))"
def vampire5_impl(in1: Expr[Any], in2: Expr[Any], in3: Expr[Any], in4: Expr[Any], in5: Expr[Any]) =
q"$companionFromVampire($in1, $in2, $in3, $in4, $in5))"
/** Tree stored as the argument of the `vampire` annotation: the second child of the annotation's tree. */
private def companionFromVampire = {
  val vampireAnnotation = macroAnnotation[vampire]
  // The first child is skipped; the one after it is the annotation argument.
  val afterFirstChild = vampireAnnotation.tree.children.tail
  afterFirstChild.head
}
/**
 * Current macro Annotation.
 *
 * Returns the first annotation whose tree type conforms to `T`; aborts compilation when there
 * is none.
 *
 * NOTE(review): this copy of the method is damaged — the collection of annotations being
 * filtered (the expression opening the parenthesis closed below) is missing, and the `${}` in
 * the message is empty (presumably it named `T`). Restore from the original source.
 */
private def macroAnnotation[T](implicit t: WeakTypeTag[T]): Annotation = {
_.tree.tpe <:< t.tpe
).headOption.getOrElse(c.abort(c.enclosingPosition, s"Annotation ${} not found."))
/**
 * Converts MyOperation to myOperation.
 *
 * Only the first character is lower-cased. An empty name is returned unchanged instead of
 * failing on `head`/`tail`.
 */
private def classNameFunctionName(className: String): String =
  className.headOption.fold(className)(first ⇒ s"${first.toLower}${className.tail}")
/**
 * Extracts [(text, String), (number, Int)] from "case class MyClass(text: String, number: Int)".
 *
 * Collects the case accessors among the declarations of `tpe` and pairs each with its return type.
 *
 * NOTE(review): the first element of the tuple below is missing in this copy (presumably the
 * accessor's name as a TermName), as are the closing braces — restore from the original source.
 */
private def caseClassFields(tpe: Type): Iterable[(TermName, Type)] = {
tpe.decls.collect {
case accessor: MethodSymbol if accessor.isCaseAccessor ⇒
// Case accessors are parameterless methods, so their signature is a NullaryMethodType.
accessor.typeSignature match {
case NullaryMethodType(returnType) ⇒ (, returnType)
