Skip to content

Instantly share code, notes, and snippets.

@lambdaknight
Created February 17, 2016 17:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lambdaknight/2517ef99445578b388d1 to your computer and use it in GitHub Desktop.
Save lambdaknight/2517ef99445578b388d1 to your computer and use it in GitHub Desktop.
Enumish Macros: Macro Annotation For Generating an Enum-ish Object
import scala.annotation.{StaticAnnotation, compileTimeOnly}
import scala.language.experimental.macros
import scala.reflect.macros.whitebox
/**
 * Base trait for the values of an `@enumish` object.
 *
 * Values are ordered by their integer [[id]].
 */
trait EnumishValue extends Ordered[EnumishValue] with java.lang.Comparable[EnumishValue] {

  /** Integer identifier of this value; it defines the ordering between values. */
  def id: Int

  /**
   * Compares this value with `that` by `id`.
   *
   * Uses `Integer.compare` instead of `this.id - that.id`: the subtraction
   * overflows when the ids are far apart (e.g. `Int.MinValue` and `1`) and then
   * reports the wrong sign.
   *
   * @param that the value to compare against
   * @return a negative number, zero, or a positive number when this id is
   *         respectively smaller than, equal to, or greater than `that.id`
   */
  def compare(that: EnumishValue): Int = java.lang.Integer.compare(this.id, that.id)
}
/**
 * Macro annotation whose expansion is implemented by `enumishMacro.impl`.
 *
 * It is marked `@compileTimeOnly`, so a reference that survives to later
 * compiler phases (i.e. the annotation was not expanded) is reported with the
 * message below asking for macro paradise to be enabled.
 */
@compileTimeOnly("enable macro paradise to expand macro annotations")
class enumish extends StaticAnnotation {
// Entry point for annotation expansion; delegates to enumishMacro.impl.
def macroTransform(annottees: Any*): Any = macro enumishMacro.impl
}
/** Implementation of the `@enumish` macro annotation. */
object enumishMacro {

  /**
   * Rewrites an annotated `object` so that, in addition to its original body, it
   * contains:
   *
   *  - a private implicit `Ordering` for the enum's value type,
   *  - `values`: an immutable `TreeSet` of every value object found in the body,
   *  - `withName(s)`: looks a value up by its object name,
   *  - `apply(id)`: looks a value up by its `id`.
   *
   * The value type is the first `sealed` class in the body that names
   * `EnumishValue` among its parents; the values are the objects in the body that
   * name that class among their parents. Detection is purely syntactic (by simple
   * name), because macro annotations run on untyped trees.
   *
   * Both generated lookups throw `NoSuchElementException` when nothing matches.
   *
   * @param c         the macro context
   * @param annottees the annotated definitions; exactly one `object` is accepted
   * @return the rewritten object definition
   */
  def impl(c: whitebox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
    import c.universe._

    // Simple name a parent tree refers to. Constructor applications are looked
    // through recursively so `Base(1)`, `pkg.Base(1)` and `Base(1)(2)` all yield `Base`.
    def parentName(parent: Tree): Option[Name] = parent match {
      case Apply(fun, _)   => parentName(fun)
      case Select(_, name) => Some(name)
      case Ident(name)     => Some(name)
      case _               => None
    }

    // Simple names of all parents of a class or object definition.
    def getParentNames(tree: ImplDef): List[Name] =
      tree.impl.parents.flatMap(parentName(_).toList)

    // Name of the first sealed class that lists `EnumishValue` among its parents.
    def findEnumishValueClass(trees: Seq[Tree]): Option[TypeName] = trees.collectFirst {
      case classDef: ClassDef
          if classDef.mods.hasFlag(Flag.SEALED) &&
            getParentNames(classDef).contains(TypeName("EnumishValue")) =>
        classDef.name
    }

    // Names of the objects that list `valueTypeName` among their parents, in source order.
    def findValues(trees: Seq[Tree], valueTypeName: TypeName): Seq[TermName] = trees.collect {
      case modDef: ModuleDef if getParentNames(modDef).contains(valueTypeName) => modDef.name
    }

    def enumerifyObject(enumObj: ModuleDef): Tree = {
      // Full object pattern so modifiers, early definitions and the self-type of
      // the annotated object are carried over to the generated one.
      val q"$mods object $enumName extends { ..$earlydefns } with ..$bases { $self => ..$body }" =
        enumObj

      val enumerantType = findEnumishValueClass(body).getOrElse(
        c.abort(c.enclosingPosition, "Sealed class implementing EnumishValue not found."))

      val enumerants = findValues(body, enumerantType)
      if (enumerants.isEmpty) c.warning(c.enclosingPosition, "No enumerants found.")

      // A fresh tree per splice, so no tree instance is shared between positions.
      def enumerantValue = Ident(enumerantType)

      // One case per value, keyed by the name as written in source, plus a default
      // case so unknown names (or an enum with no values) fail with
      // NoSuchElementException instead of MatchError.
      val nameCases = enumerants.map(e => cq"${e.decodedName.toString} => ${Ident(e)}")
      val fallbackCase =
        cq"""_ => throw new _root_.java.util.NoSuchElementException("No value found for '" + s + "'")"""
      val withNameCases = nameCases :+ fallbackCase

      q"""
        $mods object $enumName extends { ..$earlydefns } with ..$bases { $self =>
          ..$body
          private implicit val ordering: scala.math.Ordering[$enumerantValue] = new scala.math.Ordering[$enumerantValue] {
            override def compare(x: $enumerantValue, y: $enumerantValue): Int = x.compare(y)
          }
          val values: scala.collection.immutable.TreeSet[$enumerantValue] = scala.collection.immutable.TreeSet.apply(..${enumerants.map(Ident(_))})
          def withName(s: String): $enumerantValue = s match {
            case ..$withNameCases
          }
          def apply(id: Int): $enumerantValue = values.find(x => x.id == id).getOrElse(throw new _root_.java.util.NoSuchElementException("Key not found: " + id))
        }
      """
    }

    // `annottees` is a varargs Seq; convert to List before matching with `::`.
    annottees.map(_.tree).toList match {
      case (modDef: ModuleDef) :: Nil =>
        c.Expr[Any](enumerifyObject(modDef))
      case _ =>
        c.abort(c.enclosingPosition, "@enumish annotation can only be applied to objects.")
    }
  }
}
/*
Usage:
@enumish
object EnumishEnum {
sealed abstract class EnumishEnum(val id: Int) extends EnumishValue
case object Foo extends EnumishEnum(1)
case object Bar extends EnumishEnum(2)
case object Baz extends EnumishEnum(3)
}
*/
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment