Skip to content

Instantly share code, notes, and snippets.

@yilinwei
Last active June 15, 2016 01:14
Show Gist options
Save yilinwei/27a5b22c346c1b6957041957adf96d90 to your computer and use it in GitHub Desktop.
package iliad
package platform
import scala.annotation.StaticAnnotation
import scala.language.experimental.macros
/**
 * Macro annotation that derives a free-monad algebra from the type argument `T`.
 *
 * The annotated definition is rewritten at compile time by the macro bundle
 * `FreeMacro.mkTransform`.
 *
 * @tparam T the type whose methods are lifted into the generated algebra
 */
final class free[T] extends StaticAnnotation {

  /** Compile-time entry point; the implementation lives in `FreeMacro.mkTransform`. */
  def macroTransform(annottees: Any*): Any =
    macro FreeMacro.mkTransform
}
import scala.reflect.macros.whitebox
/**
 * Macro bundle implementing the `free` annotation.
 *
 * For a trait `Tpt[Next]` annotated with `@free[T]`, `mkTransform` returns a trait
 * `Tpt[Next]` plus a companion object which contain, for every method of `T` that
 * is not an `Object` member and has at least one parameter list:
 *  - a case class extending `Tpt[Next]` holding the method's arguments plus a
 *    continuation (`next: Next` for `Unit` methods, `onNext: R => Next` otherwise);
 *  - a case in `run`, which calls the method on a `T` instance and continues;
 * together with a `runner` transformation `Tpt ~> Id` and an implicit
 * `cats.Functor[Tpt]`. Any body of an existing companion object is kept.
 *
 * Overloaded methods get one case class per overload (`Foo0`, `Foo1`, ...) and an
 * object `Foo` with one `apply` per overload.
 *
 * NOTE(review): `annotatedType` and `methodSymbolParamTree` are inherited from
 * `SymbolMacro`, which is not visible here.
 */
final class FreeMacro(val c: whitebox.Context) extends SymbolMacro {
  import c.universe._

  /** Members of `Object`; these are never lifted into the algebra. */
  lazy val _objectMethods: Set[Symbol] = weakTypeOf[Object].members.toSet

  /**
   * One generated constructor of the algebra.
   *
   * @param uniqueName name the case class is derived from (suffixed with the
   *                   overload index for overloaded methods)
   * @param methodName name of the method on the annotated type
   * @param method     the method symbol being lifted
   * @param returnUnit whether `method` returns `Unit`
   */
  case class CaseMethod(uniqueName: String, methodName: TermName, method: MethodSymbol, returnUnit: Boolean)

  /**
   * Whether `m` returns `Unit`. Types are compared with `=:=`: `==` compares
   * `Type` instances and can miss equivalent types such as aliases of `Unit`.
   */
  private def returnsUnit(m: MethodSymbol): Boolean = m.returnType =:= typeOf[Unit]

  /** Continuation parameter: `next: Next` for `Unit` methods, `onNext: R => Next` otherwise. */
  def nextParam(m: MethodSymbol): Tree =
    if (returnsUnit(m)) q"val next: Next" else q"val onNext: ${m.returnType} => Next"

  /** References to all of `m`'s parameters, across every parameter list. */
  def asArgs(m: MethodSymbol): Seq[Tree] = m.paramLists.flatMap(_.map(p => q"${p.asTerm.name}"))

  /** Reference to the continuation declared by `nextParam`. */
  def nextArg(m: MethodSymbol): Tree = if (returnsUnit(m)) q"next" else q"onNext"

  /**
   * Pattern binders `arg0 @ _`, ..., `argN @ _` (n + 1 binders).
   *
   * We can't use `cq` here because each wildcard in the unapply must be bound to a name.
   */
  def unapplyParams(n: Int): List[Bind] =
    (0 to n).map(i => Bind(TermName("arg" + i), Ident(termNames.WILDCARD))).toList

  /** Reference to the binder `argN` created by `unapplyParams`. */
  def rangeParam(n: Int): Tree = {
    val name = TermName("arg" + n)
    q"$name"
  }

  /** References `arg0`, ..., `argN` (empty when `n < 0`). */
  def rangeParams(n: Int): List[Tree] = (0 to n).map(rangeParam).toList

  /**
   * Builds trees for every method group in `methodsByName`. Used to make the
   * case classes, the interpreter cases and the functor cases.
   *
   * A name with a single method yields `single(cm)`. An overloaded name yields
   * `polyAggregate(cms, cms.map(poly))`, where each overload's `uniqueName` is
   * suffixed with its index so the generated case classes don't clash. The
   * numbering depends only on `methodsByName`, so separate folds over the same
   * map agree on it. Empty groups produce nothing.
   */
  def caseMethodsFold(
      single: CaseMethod => Tree,
      poly: CaseMethod => Tree,
      polyAggregate: (Seq[CaseMethod], Seq[Tree]) => Seq[Tree]
  )(methodsByName: Map[TermName, List[MethodSymbol]]): Seq[Tree] =
    methodsByName.toList.flatMap { case (tn, ms) =>
      val mn = tn.decodedName.toString
      ms match {
        case Nil => Nil
        case m :: Nil =>
          List(single(CaseMethod(mn, TermName(mn), m, returnsUnit(m))))
        case overloads =>
          val cms = overloads.zipWithIndex.map { case (m, idx) =>
            CaseMethod(s"$mn$idx", TermName(mn), m, returnsUnit(m))
          }
          // Overloads are built with `poly`, not `single`.
          polyAggregate(cms, cms.map(poly))
      }
    }

  /** `str` with its first character upper-cased (empty input is returned unchanged). */
  def capitalize(str: String): String = editFirst(_.toUpper, str)

  /** `str` with its first character lower-cased (empty input is returned unchanged). */
  def lowerFirst(str: String): String = editFirst(_.toLower, str)

  /** Applies `f` to the first character of `str`; an empty string is returned as-is. */
  def editFirst(f: Char => Char, str: String): String =
    if (str.isEmpty) str else s"${f(str.head)}${str.tail}"

  /** `case class CmName[Next](<method params>, <continuation>) extends tpe[Next]`. */
  def caseDef(tpe: TypeName, cm: CaseMethod): Tree = {
    val params = methodSymbolParamTree(cm.method) :+ nextParam(cm.method)
    q"case class ${TypeName(capitalize(cm.uniqueName))}[Next](..$params) extends $tpe[Next]"
  }

  /**
   * Interpreter case for `cm`. It binds the case class fields, calls the method on
   * `runner`, then yields the `next` field (for `Unit` methods) or `onNext(result)`.
   *
   * NOTE(review): the arity comes from the first parameter list only, while
   * `asArgs` spans all lists. Methods with several parameter lists are presumably
   * unsupported; confirm against `methodSymbolParamTree`.
   */
  def mkExecStatement(tpe: TypeName, cm: CaseMethod): Tree = {
    val cs = TermName(capitalize(cm.uniqueName))
    val arity = cm.method.paramLists.head.length
    val unapply = pq"$cs(..${unapplyParams(arity)})"
    val args = rangeParams(arity - 1)
    // The binder after the method arguments is the continuation field.
    val last = rangeParam(arity)
    val next = if (cm.returnUnit) last else q"$last.apply(result)"
    cq"""$unapply =>
      val result = runner.${TermName(cm.methodName.decodedName.toString)}(..$args)
      $next
    """
  }

  /**
   * Case classes for every lifted method of the annotated type. For overloaded
   * methods, also an object named after the method with one `apply` per overload,
   * each delegating to the matching indexed case class.
   */
  def caseDefs(tpe: TypeName): Map[TermName, List[MethodSymbol]] => Seq[Tree] =
    caseMethodsFold(caseDef(tpe, _), caseDef(tpe, _), (cases, trees) => {
      val objectBody = cases.map { case CaseMethod(un, _, m, _) =>
        val params = methodSymbolParamTree(m) :+ nextParam(m)
        val cn = TypeName(capitalize(un))
        q"def apply[Next](..$params): $cn[Next] = ${tpe.toTermName}.${cn.toTermName}.apply(..${asArgs(m) :+ nextArg(m)})"
      }
      trees :+ q"""
        object ${TermName(capitalize(cases.head.methodName.decodedName.toString))} {
          ..$objectBody
        }
      """
    }) _

  /** Interpreter cases (see `mkExecStatement`) for every lifted method. */
  def execDefs(tpe: TypeName): Map[TermName, List[MethodSymbol]] => Seq[Tree] =
    caseMethodsFold(mkExecStatement(tpe, _), mkExecStatement(tpe, _), (_, trees) => trees) _

  /**
   * Functor `map` case for `cm`. It keeps the method arguments and post-composes `f`
   * onto the continuation: `f(next)` for `Unit` methods, `onNext.andThen(f)` otherwise.
   */
  def mkFunctorStatement(cm: CaseMethod): Tree = {
    val cs = TermName(capitalize(cm.uniqueName))
    val arity = cm.method.paramLists.head.length
    val unapply = pq"$cs(..${unapplyParams(arity)})"
    val last = rangeParam(arity)
    val args = rangeParams(arity - 1) :+ (if (cm.returnUnit) q"f($last)" else q"$last.andThen(f)")
    cq"$unapply => $cs(..$args)"
  }

  /** Functor `map` cases for every lifted method. */
  def functorCases: Map[TermName, List[MethodSymbol]] => Seq[Tree] =
    caseMethodsFold(mkFunctorStatement, mkFunctorStatement, (_, trees) => trees) _

  /**
   * Generates the algebra trait `tpt` and its companion object.
   *
   * @param tpt          name of the annotated trait
   * @param existingBody statements of an existing companion object, appended to the generated one
   */
  def mkFree(tpt: TypeName, existingBody: List[Tree]): Tree = {
    val tpe = annotatedType
    /* Not entirely sure why, but vals generated from the type annotation don't seem
       to have all the expected type information, which means we need to check the paramLists... */
    val methods = tpe.members
      .filter(t => t.isMethod && !_objectMethods.contains(t) && t.asMethod.paramLists.nonEmpty && t.name != TermName("$init$"))
      .groupBy(_.name)
      .map { case (mn, ms) => mn.toTermName -> ms.map(_.asMethod).toList }
    q"""
      trait $tpt[Next] {
        def run(runner: $tpe): Next = ${tpt.toTermName}.run(this, runner)
        def toFree: _root_.cats.free.Free[$tpt, Next] = _root_.cats.free.Free.liftF(this)
      }
      object ${tpt.toTermName} {
        ..${caseDefs(tpt)(methods)}
        def run[A](fa: $tpt[A], runner: $tpe): A = fa match {
          case ..${execDefs(tpt)(methods)}
        }
        def runner(runner: $tpe): _root_.cats.~>[$tpt, _root_.cats.Id] = new (_root_.cats.~>[$tpt, _root_.cats.Id]) {
          def apply[A](fa: $tpt[A]): _root_.cats.Id[A] = run(fa, runner)
        }
        implicit val ${TermName(s"${lowerFirst(tpt.decodedName.toString)}Functor")}: _root_.cats.Functor[$tpt] = new _root_.cats.Functor[$tpt] {
          def map[A, B](fa: $tpt[A])(f: A => B): $tpt[B] = fa match { case ..${functorCases(methods)} }
        }
        ..$existingBody
      }
    """
  }

  /**
   * Macro-annotation entry point. It accepts a single-type-parameter trait,
   * optionally followed by its companion object, and aborts otherwise.
   */
  def mkTransform(annottees: Expr[Any]*): Tree =
    // Match with `Seq` extractors: the varargs collection is not guaranteed to be a `List`.
    annottees.map(_.tree) match {
      case Seq(q"abstract trait $tpt[$_]") => mkFree(tpt, Nil)
      case Seq(q"abstract trait $tpt[$_]", q"object $_ { ..$body }") => mkFree(tpt, body)
      case _ => c.abort(c.enclosingPosition, "Cannot free trait as annotated type is not in the expected format")
    }
}
package iliad
import shapeless._
import scala.annotation.implicitNotFound
import scala.reflect.ClassTag
/**
 * Selects, at runtime, the typeclass instance for whichever member of the HList `L`
 * the value `a` actually is. It then hands the value and that instance to a callback.
 *
 * @tparam L  candidate concrete types (each of the form `H[A]`)
 * @tparam F  common supertype constructor of the candidates
 * @tparam TC typeclass resolved for each candidate
 * @tparam A  type argument shared by the candidates
 * @tparam B  result of the callback
 */
@implicitNotFound("Cannot find PolyTC typeclass for ${TC} in list ${L}.")
trait PolyTC[L <: HList, F[_], TC[_[_]], A, B] {

  /** Applies `f` to `a` together with the typeclass instance matching `a`'s runtime type. */
  def apply(a: F[A])(f: (F[A], TC[F]) => B): B
}
/** Implicit derivation of `PolyTC` by recursion over the HList of candidate types. */
object PolyTC {

  /** Alias for `PolyTC` with identical type parameters. */
  type Aux[L <: HList, F[_], TC[_[_]], A, B] = PolyTC[L, F, TC, A, B]

  /**
   * End of the candidate list: no candidate matched the value's runtime class.
   *
   * @throws IllegalStateException always; the message names the unmatched runtime class
   */
  implicit def base[F[_], TC[_[_]], A, B]: Aux[HNil, F, TC, A, B] = new Aux[HNil, F, TC, A, B] {
    def apply(a: F[A])(f: (F[A], TC[F]) => B): B = {
      // Report the value's class; interpolating the callback only printed a function's toString.
      val unmatched = Option(a).fold("null")(_.getClass.getName)
      throw new IllegalStateException(s"Could not find class matching $unmatched within poly list!")
    }
  }

  /**
   * Tries the head candidate `H`. If `fa`'s runtime class is exactly `H`'s erased
   * class, `g` receives `fa` with `H`'s instance. Otherwise the tail is tried.
   * A `null` value never matches and ends in `base`'s exception.
   */
  implicit def recurse[H[_] <: F[_], T <: HList, F[_], TC[_[_]], A, B](
      implicit tc: TC[H],
      tail: Aux[T, F, TC, A, B],
      ct: ClassTag[H[_]]
  ): Aux[H[A] :: T, F, TC, A, B] = new Aux[H[A] :: T, F, TC, A, B] {
    def apply(fa: F[A])(g: (F[A], TC[F]) => B): B =
      if (fa != null && fa.getClass == ct.runtimeClass)
        // Safe only because the runtime class check above established that fa is an H.
        g(fa, tc.asInstanceOf[TC[F]])
      else
        tail(fa)(g)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment