Skip to content

Instantly share code, notes, and snippets.

@jestan
Forked from dadrox/ScalaEnum.scala
Created February 10, 2012 05:53
Show Gist options
  • Save jestan/1787035 to your computer and use it in GitHub Desktop.
Save jestan/1787035 to your computer and use it in GitHub Desktop.
DIY Scala Enums (with optional exhaustiveness checking, name introspection, and boilerplate reverse lookup)
/**
 * Do-it-yourself enumeration base.
 *
 * An implementing object defines the `EnumVal` type as a subtype of [[Value]] and
 * creates its values as public members. Every created value registers itself in
 * creation order, receives an ordinal, and can report the name of the member
 * that holds it.
 */
trait Enum {
  import java.util.concurrent.atomic.AtomicReference
  import scala.annotation.tailrec
  import scala.reflect.NameTransformer

  /** The concrete value type; the implementing class must define it as a subtype of [[Value]]. */
  type EnumVal <: Value

  /** All values registered so far, in registration order; swapped atomically on every addition. */
  private val _values = new AtomicReference[Vector[EnumVal]](Vector.empty[EnumVal])

  /** Returns every enum value that has been created for this type so far, in ordinal order. */
  def values: Vector[EnumVal] = _values.get()

  /** Looks up a value by its exact member name. */
  def withName(name: String): Option[EnumVal] = values.find(_.name == name)

  /** Looks up a value by its member name, ignoring case. */
  def withNameIgnoringCase(name: String): Option[EnumVal] =
    values.find(_.name.equalsIgnoreCase(name))

  /**
   * Finds the name of the public zero-argument member of this enum that returns
   * `requestingInstance`, by reflectively invoking each candidate member.
   *
   * The JVM method name is decoded back to its Scala spelling, so a member such
   * as `val +` is reported as `+` rather than `$plus`.
   *
   * @throws java.lang.Error if no such member currently returns the instance
   */
  private def lazyName(requestingInstance: EnumVal): String =
    getClass().getMethods()
      .find { method =>
        // Cheap signature checks come first so only plausible accessors are invoked.
        classOf[Value].isAssignableFrom(method.getReturnType()) &&
        method.getParameterTypes().isEmpty &&
        (method.invoke(this) eq requestingInstance)
      }
      .map(method => NameTransformer.decode(method.getName))
      .getOrElse(throw new Error("Yup, this Enum stuff is jacked... this is a bug. Please issue a ticket :("))

  /**
   * Appends `newVal` to the registry and returns its ordinal.
   *
   * Retries the compare-and-set until it succeeds. The ordinal is the size of the
   * vector that was replaced, i.e. the index at which `newVal` was appended.
   */
  @tailrec
  private final def addEnumVal(newVal: EnumVal): Int = {
    val current = _values.get()
    if (_values.compareAndSet(current, current :+ newVal)) current.length
    else addEnumVal(newVal)
  }

  /**
   * Base trait for the implementing class's `EnumVal` type; it does the
   * book-keeping (registration, ordinal, name, equality).
   *
   * The self-type enforces that `Value` is only mixed into an `EnumVal`.
   */
  protected trait Value { self: EnumVal =>

    /** Position of this value in [[values]]; assigned when the value is constructed. */
    final val ordinal: Int = addEnumVal(this)

    /** Name of the enum member holding this value; resolved reflectively on first access. */
    lazy val name: String = lazyName(this)

    /** The member name, followed by the product rendering in brackets when this value is a `Product`. */
    override def toString: String = (this: Any) match {
      case product: Product => s"$name[${scala.runtime.ScalaRunTime._toString(product)}]"
      case _                => name
    }

    /** Enum values are equal only to themselves (reference equality). */
    override def equals(other: Any): Boolean = this eq other.asInstanceOf[AnyRef]

    /** Hash derived from the runtime class, the member name and the ordinal. */
    override def hashCode: Int = 31337 * (this.getClass.## + name.## + ordinal)
  }
}
//And here's how to use it, if you want compiler exhaustiveness checking
/** Example enum with two values, `F` and `X`, each an anonymous instance of a sealed `EnumVal` trait. */
object Foos extends Enum {

  /** Value type of this enum; custom methods shared by all values can be declared in its body. */
  sealed trait EnumVal extends Value

  val F: EnumVal = new EnumVal {}
  val X: EnumVal = new EnumVal {}
}
/**
scala> Foos.values.find(_.name == "F")
res3: Option[Foos.EnumVal] = Some(F)
scala> Foos.X.ordinal
res4: Int = 1
scala> def doSmth(foo: Foos.EnumVal) = foo match {
case Foos.X => println("pigdog")
}
<console>:10: warning: match is not exhaustive!
missing combination $anon$1
missing combination $anon$2
scala> def doSmth(foo: Foos.EnumVal) = foo match {
case Foos.X => println("pigdog")
case Foos.F => println("dogpig")
}
doSmth: (foo: Foos.EnumVal)Unit
**/
//If you don't need exhaustiveness warnings, you can make EnumVal a case class instead,
//which also lets each value carry additional static data:
//object Foos extends Enum {
// case class EnumVal private[Foos](code: Int) extends Value /* { you can define your own methods and stuff here } */
//
// val F = EnumVal(123)
// val X = EnumVal(456)
//}
/**
This variant needs a bit less boilerplate.
Cheers,
**/
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment