Created
February 17, 2016 17:49
-
-
Save lambdaknight/2517ef99445578b388d1 to your computer and use it in GitHub Desktop.
Enumish Macros: Macro Annotation For Generating an Enum-ish Object
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.annotation.{StaticAnnotation, compileTimeOnly} | |
import scala.language.experimental.macros | |
import scala.reflect.macros.whitebox | |
/** Base trait for the value type of an `@enumish` object.
  *
  * Implementors supply an integer `id`; values are ordered by that `id`.
  */
trait EnumishValue extends Ordered[EnumishValue] with java.lang.Comparable[EnumishValue] {
  def id: Int

  /** Orders by `id`.
    *
    * Uses `Integer.compare` rather than `this.id - that.id`: the subtraction overflows when the
    * ids are far apart (e.g. `Int.MinValue` vs `1`) and would report the wrong order.
    */
  def compare(that: EnumishValue): Int = java.lang.Integer.compare(this.id, that.id)
}
/** Macro annotation that rewrites an annotated `object` into an enum-like container.
  *
  * The actual expansion is performed by `enumishMacro.impl`.
  * The `@compileTimeOnly` message below indicates that macro paradise must be enabled for the
  * annotation to expand; without expansion, any remaining reference is a compile error.
  */
@compileTimeOnly("enable macro paradise to expand macro annotations")
class enumish extends StaticAnnotation {
  // Entry point invoked by the macro engine; the annotated definition(s) arrive as `annottees`.
  def macroTransform(annottees: Any*): Any = macro enumishMacro.impl
}
/** Implementation of the `@enumish` macro annotation.
  *
  * Given an annotated `object`, it looks in the object's body for a sealed class/trait that
  * directly extends `EnumishValue`, collects every nested `object` that directly extends that
  * class, and appends to the object:
  *   - `values`: an immutable `TreeSet` of those objects, ordered via their `compare` method;
  *   - `withName(s)`: the object whose (decoded) name is `s`, or `NoSuchElementException`;
  *   - `apply(id)`: the object whose `id` is `id`, or `NoSuchElementException`.
  *
  * Aborts compilation if no suitable sealed class is found or if the annotation is placed on
  * anything other than a single object; warns if no enumerant objects are found.
  */
object enumishMacro {
  def impl(c: whitebox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
    import c.universe._

    // Simple name of a parent tree, unwrapping constructor applications such as
    // `Base(1)`, `pkg.Base(1)` or `Base(1)(2)`.
    def parentName(parent: Tree): Option[Name] = parent match {
      case Ident(name)     => Some(name)
      case Select(_, name) => Some(name)
      case Apply(fun, _)   => parentName(fun)
      case _               => None
    }

    // Simple names of all direct parents of a class/object definition.
    def getParentNames(tree: ImplDef): List[Name] = tree.impl.parents.flatMap(parentName)

    // Name of the first sealed class/trait in `trees` that directly extends EnumishValue.
    def findEnumishValueClass(trees: Seq[Tree]): Option[TypeName] = trees.collectFirst {
      case classDef: ClassDef
          if classDef.mods.hasFlag(Flag.SEALED) &&
            getParentNames(classDef).contains(TypeName("EnumishValue")) =>
        classDef.name
    }

    // Names of every object in `trees` that directly extends `valueTypeName`.
    def findValues(trees: Seq[Tree], valueTypeName: TypeName): List[TermName] =
      trees.toList.collect {
        case modDef: ModuleDef if getParentNames(modDef).contains(valueTypeName) => modDef.name
      }

    def enumerifyObject(enumObj: ModuleDef): Tree = {
      // Full object form so modifiers, early definitions and the self type are preserved.
      val q"$mods object $enumName extends { ..$earlyDefs } with ..$bases { $self => ..$body }" =
        enumObj

      val enumerantType = findEnumishValueClass(body).getOrElse(
        c.abort(c.enclosingPosition, "Sealed class implementing EnumishValue not found."))

      val enumerants = findValues(body, enumerantType)
      if (enumerants.isEmpty)
        c.warning(c.enclosingPosition, "No enumerants found.")

      val enumerantValue = Ident(enumerantType)
      // Fresh name so the generated implicit cannot collide with a user-defined member.
      val orderingName = TermName(c.freshName("enumishOrdering"))

      // Match on the decoded name so it agrees with a case object's `toString`
      // (e.g. backticked names); the trailing default keeps the match non-empty and total.
      val nameCases =
        enumerants.map(e => cq"${e.decodedName.toString} => ${Ident(e)}") :+
          cq"""_ => throw new NoSuchElementException("No value found for '" + s + "'")"""

      q"""
        $mods object $enumName extends { ..$earlyDefs } with ..$bases { $self =>
          ..$body

          private implicit val $orderingName: scala.math.Ordering[$enumerantValue] =
            new scala.math.Ordering[$enumerantValue] {
              override def compare(x: $enumerantValue, y: $enumerantValue): Int = x.compare(y)
            }

          val values: scala.collection.immutable.TreeSet[$enumerantValue] =
            scala.collection.immutable.TreeSet.apply(..${enumerants.map(Ident(_))})

          def withName(s: String): $enumerantValue = s match { case ..$nameCases }

          def apply(id: Int): $enumerantValue =
            values.find(x => x.id == id).getOrElse(throw new NoSuchElementException("Key not found: " + id))
        }
      """
    }

    // `.toList` so the `::` patterns also match when the varargs Seq is not a List.
    annottees.map(_.tree).toList match {
      case (modDef: ModuleDef) :: Nil => c.Expr[Any](enumerifyObject(modDef))
      case _ => c.abort(c.enclosingPosition, "@enumish annotation can only be applied to objects.")
    }
  }
}
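/*
Usage:

@enumish
object EnumishEnum {
  sealed abstract class EnumishEnum(val id: Int) extends EnumishValue
  case object Foo extends EnumishEnum(1)
  case object Bar extends EnumishEnum(2)
  case object Baz extends EnumishEnum(3)
}

The annotation adds lookup members to the object, for example:

EnumishEnum.values          // TreeSet(Foo, Bar, Baz), ordered by id
EnumishEnum.withName("Bar") // Bar
EnumishEnum(3)              // Baz
*/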
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment