Skip to content

Instantly share code, notes, and snippets.

@casualjim
Last active December 10, 2015 02:09
// Build settings for the reflection playground.
// Settings are separated by blank lines, as .sbt files require in sbt versions before 0.13.7.

// Compiler version; the runtime reflection API used by Reflective (scala.reflect.runtime) needs 2.10+.
scalaVersion := "2.10.0"

name := "reflect"

organization := "com.github.casualjim"

// specs2 is only needed to compile and run ReflectionSpec.
libraryDependencies += "org.specs2" %% "specs2" % "1.13" % "test"

// scala-reflect ships as a separate artifact; its version is derived from scalaVersion so the two always match.
libraryDependencies <+= scalaVersion("org.scala-lang" % "scala-reflect" % _)

// -language:implicitConversions silences the feature warning for the map2valueProvider conversion.
scalacOptions ++= Seq("-unchecked", "-deprecation", "-feature", "-language:implicitConversions")
package object playground {

  /**
   * Lets a plain `Map[String, Any]` be passed wherever a [[ValueProvider]] is expected
   * (for example as the argument of `Reflective.bind`). The map is wrapped in a root-level
   * [[MapValueReader]] that uses the reader's default prefix and separator.
   */
  implicit def map2valueProvider(m: Map[String, Any]): ValueProvider[Map[String, Any]] =
    new MapValueReader(data = m)

  // Separator flavours; each pair is (segment opener, segment closer).

  /** Dotted keys: `a` + `b` is written `a.b`. */
  object DotSeparator extends Separator(".", "")

  /** Square-bracket keys: `a` + `b` is written `a[b]`. */
  object SquareBracketsSeparator extends Separator("[", "]")

  /** Parenthesised keys: `a` + `b` is written `a(b)`. */
  object BracketsSeparator extends Separator("(", ")")

  /** Path-style keys: `a` + `b` is written `a/b`. */
  object ForwardSlashSeparator extends Separator("/", "")

  /** Windows-path-style keys: `a` + `b` is written `a\b`. */
  object BackslashSeparator extends Separator("\\", "")

  /** No delimiters at all: `a` + `b` is simply concatenated. */
  object PassThroughSeparator extends Separator("", "")
}
package playground
import org.specs2.mutable.Specification
import java.util.{Calendar, Date}
/**
 * Mutable test fixture with only a no-arg constructor; the spec expects `Reflective.bind`
 * to populate it through its `var`s.
 *
 * Equality is structural over all three fields, and `hashCode` is kept consistent with it.
 * Because the fields are mutable, the hash changes when they do, so do not mutate an
 * instance while it sits in a hash-based collection.
 */
class MutablePerson {
  var name: String = _
  var age: Int = _
  var dateOfBirth: Date = _

  override def toString() = "MutablePerson(name: %s, age: %d, dateOfBirth: %s)".format(name, age, dateOfBirth)

  override def equals(obj: Any): Boolean = obj match {
    case o: MutablePerson => name == o.name && age == o.age && dateOfBirth == o.dateOfBirth
    case _ => false
  }

  // equals is overridden, so hashCode must be as well: equal instances have to hash equally.
  // Tuple hashing goes through `##`, which is null-safe for fields that were never set.
  override def hashCode(): Int = (name, age, dateOfBirth).##
}
/** Fixture whose constructor parameters are all required. */
case class Person(
  name: String,
  age: Int,
  dateOfBirth: Date)

/** Fixture with an optional field and a defaulted trailing field. */
case class Thing(
  name: String,
  age: Int,
  dateOfBirth: Option[Date],
  createdAt: Date = new Date())

/** Same fields as [[Thing]], but the optional field sits between two required ones. */
case class AnotherThing(
  name: String,
  dateOfBirth: Option[Date],
  age: Int,
  createdAt: Date = new Date())

/** Fixture with a single defaulted field. */
case class Record(
  id: Int,
  data: String,
  createdAt: Date = new Date())

/** Like [[Record]], but also offers auxiliary constructors with fewer parameters. */
case class OtherRecord(id: Int, data: String, createdAt: Date = new Date()) {
  /** Record with the given id and empty data. */
  def this(id: Int) = this(id, "")

  /** Record with id 0 and empty data (delegates to the single-argument constructor). */
  def this() = this(0)
}

/** Fixture containing a nested case class. */
case class PersonWithThing(
  name: String,
  age: Int,
  thing: Thing)
/**
 * Behavioural spec for `Reflective.bind`. It covers case classes (with required, optional and
 * defaulted parameters), classes with auxiliary constructors, and a mutable class populated
 * through its vars.
 */
class ReflectionSpec extends Specification {

  "Reflective access" should {

    "create a person when all the fields are provided" in {
      val born = yearsAgo(33)
      val expected = Person("Timmy", 33, born)
      Reflective.bind[Person](Map("name" -> "Timmy", "age" -> 33, "dateOfBirth" -> born)) must_== expected
    }

    "create a mutable person when all the fields are provided" in {
      val born = yearsAgo(33)
      val expected = new MutablePerson()
      expected.name = "Timmy"
      expected.age = 33
      expected.dateOfBirth = born
      Reflective.bind[MutablePerson](Map("name" -> "Timmy", "age" -> 33, "dateOfBirth" -> born)) must_== expected
    }

    "create a record when all the fields are provided" in {
      val expected = record(225)
      val fields = Map("id" -> expected.id, "data" -> expected.data, "createdAt" -> expected.createdAt)
      Reflective.bind[Record](fields) must_== expected
    }

    "create a record when createdAt is not provided" in {
      val expected = record(225)
      val actual = Reflective.bind[Record](Map("id" -> expected.id, "data" -> expected.data))
      actual.id must_== expected.id
      actual.data must_== expected.data
    }

    "create a thing when only name and age are provided" in {
      val expected = Thing("a thing", 2, None)
      val actual = Reflective.bind[Thing](Map("name" -> expected.name, "age" -> expected.age))
      actual.name must_== expected.name
      actual.age must_== expected.age
      actual.dateOfBirth must beNone
    }

    "create a thing when all fields are provided" in {
      val born = yearsAgo(2)
      val createdAt = new Date
      val expected = Thing("a thing", 2, None)
      val actual = Reflective.bind[Thing](Map(
        "name" -> expected.name,
        "age" -> expected.age,
        "dateOfBirth" -> born,
        "createdAt" -> createdAt))
      actual.name must_== expected.name
      actual.age must_== expected.age
      actual.dateOfBirth must beSome(born)
      actual.createdAt must_== createdAt
    }

    "create another thing when only name and age are provided" in {
      val expected = AnotherThing("another thing", None, 2)
      val actual = Reflective.bind[AnotherThing](Map("name" -> expected.name, "age" -> expected.age))
      actual.name must_== expected.name
      actual.age must_== expected.age
      actual.dateOfBirth must beNone
    }

    "create an other record when only id is provided" in {
      val expected = new OtherRecord(303)
      val actual = Reflective.bind[OtherRecord](Map("id" -> 303))
      actual.id must_== expected.id
      actual.data must_== expected.data
    }
  }

  /** Record fixture whose data is derived from its id. */
  private[this] def record(id: Int) = Record(id, s"A $id record")

  /**
   * February 1st, 00:00:00 local time, `years` years before the current year.
   * Milliseconds keep whatever value the calendar was created with; each test compares a date
   * against itself, so that does not matter.
   */
  private[this] def yearsAgo(years: Int): Date = {
    val cal = Calendar.getInstance()
    cal.set(cal.get(Calendar.YEAR) - years, Calendar.FEBRUARY, 1, 0, 0, 0)
    cal.getTime
  }
}
package playground
import scala.reflect.runtime.{currentMirror => cm, universe}
import scala.reflect.runtime.universe._
import scala.reflect._
object Reflective {

  /**
   * Instantiates `T` from the entries in `values`.
   *
   * A constructor is chosen by `pickConstructor`. Its parameters are then filled as follows:
   *  - parameters with a default: the value from `values` if present; otherwise the compiler-generated
   *    default getter on the companion object is invoked;
   *  - `Option`-typed parameters: `values.get(name)`, so `None` when the key is absent;
   *  - every other parameter: `values(name)`, which throws when the key is missing.
   *
   * Keys that `values.isComplex` reports as nested are NOT bound recursively; `null` is passed instead.
   * Entries not consumed by the constructor are afterwards assigned to matching `var`s via `setFields`.
   *
   * The whole call is `synchronized` on this object. This is presumably because Scala 2.10 runtime
   * reflection is not thread-safe (NOTE(review): confirm).
   *
   * @throws RuntimeException         if no constructor fits the available keys, or a required key is missing
   * @throws IllegalArgumentException if a defaulted parameter is absent and no default getter can be found
   */
  def bind[T](values: ValueProvider[_])(implicit ct: ClassTag[T]): T = synchronized {
    val klazz = cm.reflectClass(cm.classSymbol(ct.runtimeClass))
    val csym = klazz.symbol

    // Companion instance that holds the default-argument getters. It is resolved lazily, so classes that
    // need no defaults (or have no companion) never reflect one. Previously it was null for non-case
    // classes, which caused an NPE as soon as a default was needed.
    lazy val companion: Option[InstanceMirror] = {
      val modul = csym.companionSymbol
      if (modul != NoSymbol && modul.isModule) Some(cm reflect (cm reflectModule modul.asModule).instance)
      else None
    }

    // Name used to look a parameter up in `values`.
    def paramName(sym: Symbol): String = sym.name.decoded.trim

    // Invokes the default getter for the parameter at 0-based `index`. Constructor defaults live on the
    // companion as `<init>$default$N` (encoded `$lessinit$greater$default$N`), which covers plain classes
    // as well. Case classes additionally expose `apply$default$N`, which is kept as a fallback.
    def defaultArg(param: Symbol, index: Int): Any = {
      val n = index + 1
      val getterNames =
        ("$lessinit$greater$default$" + n) :: (if (csym.isCaseClass) List("apply$default$" + n) else Nil)
      val found = companion flatMap { im =>
        val ts = im.symbol.typeSignature
        getterNames
          .map(getterName => ts member newTermName(getterName))
          .find(_ != NoSymbol)
          .map(getter => (im reflectMethod getter.asMethod)())
      }
      found getOrElse (throw new IllegalArgumentException(s"${param.name.decoded}: ${param.typeSignature.toString}"))
    }

    val (ctor, ctorParams) = pickConstructor(csym, values.keySet)
    val (defaults, probablyRequired) = ctorParams.zipWithIndex partition (_._1.asTerm.isParamWithDefault)
    val (options, required) = probablyRequired partition (s => s._1.asTerm.typeSignature <:< typeOf[Option[_]])

    def valueFor(sym: (Symbol, Int)) = {
      val decName = paramName(sym._1)
      if (values.isComplex(decName)) (null, sym._2)
      else (values(decName), sym._2)
    }
    def optionalValueFor(sym: (Symbol, Int)) = (values.get(paramName(sym._1)), sym._2)
    def defaultValueFor(sym: (Symbol, Int)) =
      (values.get(paramName(sym._1)) getOrElse defaultArg(sym._1, sym._2), sym._2)

    // Entries the constructor does not consume may still match settable vars. Names are trimmed here
    // exactly as for every other lookup.
    val remainingValues = values -- ctorParams.map(paramName)
    val toset = (required map valueFor) ::: (options map optionalValueFor) ::: (defaults map defaultValueFor)
    // The three groups were built separately; restore declaration order before invoking the constructor.
    val obj = klazz.reflectConstructor(ctor)(toset.sortBy(_._2).map(_._1): _*).asInstanceOf[T]
    setFields(obj, remainingValues)
  }

  /**
   * Chooses the constructor of `clazz` to invoke for the given keys.
   *
   * Constructors are tried widest-parameter-list first. The first one with a parameter list whose
   * required parameters (no default, not `Option`-typed) all appear in `argNames` wins.
   * Only that single parameter list is returned. NOTE(review): constructors with several parameter
   * lists are presumably not supported.
   *
   * @return the constructor together with the parameter list that matched
   * @throws RuntimeException if no constructor matches
   */
  private[this] def pickConstructor(clazz: ClassSymbol, argNames: Set[String]): (MethodSymbol, List[Symbol]) = {
    val ctors = clazz.typeSignature.member(nme.CONSTRUCTOR).asTerm.alternatives.map(_.asMethod).sortBy(-_.paramss.sortBy(-_.size).headOption.getOrElse(Nil).size)
    val zipped = ctors zip (ctors map (ctor => pickConstructorArgs(ctor.paramss, argNames)))
    zipped collectFirst {
      case (m: MethodSymbol, Some(args)) => (m, args)
    } getOrElse (throw new RuntimeException(s"Couldn't find a constructor for ${clazz.name.decoded} and args: [${argNames.mkString(", ")}]"))
  }

  /**
   * Returns the largest parameter list in `candidates` whose required parameters
   * (neither defaulted nor `Option`-typed) are all named in `argNames`, if any.
   */
  private[this] def pickConstructorArgs(candidates: List[List[Symbol]], argNames: Set[String]): Option[List[Symbol]] = {
    val ctors = candidates.sortBy(-_.size)
    def isRequired(item: Symbol) = {
      val sym = item.asTerm
      !(sym.isParamWithDefault || sym.typeSignature <:< typeOf[Option[_]])
    }
    def matchingRequired(plist: List[Symbol]) = {
      val required = plist filter isRequired
      required.size <= argNames.size && required.forall(s => argNames.contains(s.name.decoded))
    }
    ctors find matchingRequired
  }

  // `var` fields declared directly on `tpe`. `declarations` also lists methods and nested types, and
  // calling `asTerm` on a type symbol throws, so non-terms are filtered out first.
  private[this] def declaredVars(tpe: Type): Iterable[TermSymbol] =
    tpe.declarations collect { case d if d.isTerm && d.asTerm.isVar => d.asTerm }

  /**
   * Reads every `var` declared directly on `obj`'s class.
   *
   * @return (trimmed decoded field name, current value) pairs
   */
  def getFields[T](obj: T)(implicit mf: ClassTag[T]): Seq[(String, Any)] = {
    val im = cm.reflect(obj)
    declaredVars(im.symbol.typeSignature).map(v => (v.name.decoded.trim, im.reflectField(v).get)).toSeq
  }

  /**
   * Assigns to each `var` declared directly on `obj`'s class the value stored in `values` under the
   * field's trimmed decoded name. Fields without an entry are left untouched.
   *
   * @return `obj` itself, mutated in place
   */
  def setFields[S, T : ClassTag](obj: T, values: ValueProvider[S]): T = {
    val im = cm.reflect(obj)
    declaredVars(im.symbol.typeSignature) foreach { f =>
      val fm = im.reflectField(f)
      values get f.name.decoded.trim foreach fm.set
    }
    obj
  }
}
package playground
/**
 * Describes how nested key segments are delimited in a flat key.
 * For example, `DotSeparator` writes `a.b` and `SquareBracketsSeparator` writes `a[b]`.
 *
 * @param beginning token that opens a segment; null or blank means "no opener"
 * @param end       token that closes a segment; null or blank means "no closer"
 */
abstract class Separator(val beginning: String, end: String) {
  val hasBeginning: Boolean = beginning != null && beginning.trim.nonEmpty
  val hasEnd: Boolean = end != null && end.trim.nonEmpty

  /** Appends `part` to `prefix` as a delimited segment; returns `part` unchanged when there is no prefix. */
  def wrap(part: String, prefix: String = ""): String = {
    val hasPrefix = prefix != null && prefix.trim.nonEmpty
    if (hasPrefix) prefix + wrapped(part)
    else part
  }

  /** Surrounds `part` with the opener/closer, skipping either one if `part` already carries it. */
  def wrapped(part: String): String = {
    val sb = new StringBuilder
    if (hasBeginning && !part.startsWith(beginning))
      sb.append(beginning)
    sb.append(part)
    if (hasEnd && !part.endsWith(end))
      sb.append(end)
    sb.toString()
  }

  /**
   * Removes the delimiters of the first segment of `key`.
   * Examples with `[`/`]`: `[a][b]` becomes `a[b]`, `[a]b` becomes `ab`, and `a][b]` becomes `a[b]`.
   */
  def stripFirst(key: String): String = {
    val endIndex = if (hasEnd) key.indexOf(end) else -1
    // Everything after the first closing token. Only evaluated when endIndex > -1, so the index is
    // always <= key.length. The old `key.size > realEnd + 1` check dropped a lone trailing character.
    def rest = key.substring(endIndex + end.length)
    if (hasBeginning && key.startsWith(beginning)) {
      if (hasEnd && endIndex > -1) {
        key.substring(beginning.length, endIndex) + rest
      } else key.substring(beginning.length)
    } else if (hasBeginning && hasEnd && endIndex > -1 && endIndex < key.indexOf(beginning)) {
      key.substring(0, endIndex) + rest
    } else key
  }

  /** First segment of `key` after removing `prefix` (e.g. `a.b.c` gives `a` with `DotSeparator`). */
  def topLevelOnly(key: String, prefix: String = ""): String = {
    val path = stripPrefix(key, prefix)
    // indexOf("") is always 0, which used to truncate every key to "" for separators without an opener.
    val startIndex = if (hasBeginning) path.indexOf(beginning) else -1
    if (startIndex > -1)
      path.substring(0, startIndex)
    else {
      val endIndex = if (hasEnd) path.indexOf(end) else -1
      if (endIndex > -1) path.substring(0, endIndex)
      else path
    }
  }

  /** Drops `prefix` from the front of `path` (when present), then strips the delimiters of the first remaining segment. */
  def stripPrefix(path: String, prefix: String): String = {
    val hasPrefix = prefix != null && prefix.trim.nonEmpty
    if (hasPrefix && path.startsWith(prefix)) stripFirst(path.substring(prefix.length))
    else stripFirst(path)
  }
}
package playground
import util.control.Exception._
/**
 * Read-only, prefix-scoped view over a source of named values of underlying type `S`.
 * Keys passed to the lookup methods are relative to `prefix` and are joined with `separator`.
 */
trait ValueProvider[S] {
  /** Path of the scope this provider reads from ("" for the root). */
  def prefix: String

  /** How key segments are joined and split. */
  def separator: Separator

  /** The underlying, un-scoped source. */
  protected def data: S

  /** Looks up `key`: `Left` on failure, otherwise `Right(Some(value))` or `Right(None)` when absent. */
  def read(key: String): Either[Throwable, Option[Any]]

  /** Best-effort lookup: failures and missing keys both yield `None`. */
  def get(key: String): Option[Any] = read(key) match {
    case Right(found) => found
    case Left(_)      => None
  }

  /** Strict lookup: rethrows a read failure, and throws a `RuntimeException` when `key` is absent. */
  def apply(key: String): Any = read(key) match {
    case Left(failure)      => throw failure
    case Right(Some(value)) => value
    case Right(None)        => throw new RuntimeException(s"No entry found for $key")
  }

  /** Provider scoped to the child object named `key`. */
  def forPrefix(key: String): ValueProvider[S]

  /** The entries visible in this scope. */
  def values: S

  /** Top-level key names visible in this scope. */
  def keySet: Set[String]

  /** Provider without the given (scope-relative) keys. */
  def --(keys: Iterable[String]): ValueProvider[S]

  /** Whether `key` names a nested object rather than a plain value. */
  def isComplex(key: String): Boolean
}
/**
 * [[ValueProvider]] backed by a flat `Map[String, Any]` whose keys encode nesting with `separator`.
 * With the default `DotSeparator`, the key `thing.name` is the `name` entry of `thing`.
 *
 * @param data      the complete, un-prefixed map; readers created by `forPrefix` share it
 * @param prefix    path of the object this reader is scoped to ("" for the root)
 * @param separator how path segments are joined and split
 */
class MapValueReader(protected val data: Map[String, Any], val prefix: String = "", val separator: Separator = DotSeparator) extends ValueProvider[Map[String, Any]] {

  /** Looks up `key` under the current prefix; any exception raised by the lookup is returned as `Left`. */
  def read(key: String): Either[Throwable, Option[Any]] = allCatch either { data get separator.wrap(key, prefix) }

  /**
   * Reader scoped to the child object `key`. The new prefix extends the current one, so nested scopes
   * compose (`a` then `b` yields prefix `a.b` with `DotSeparator`). Previously the parent prefix was
   * dropped; at the root (empty prefix) the result is unchanged.
   */
  def forPrefix(key: String): ValueProvider[Map[String, Any]] = new MapValueReader(data, separator.wrap(key, prefix), separator)

  /** Entries under the current prefix, with that prefix removed from their keys. */
  lazy val values: Map[String, Any] = stripPrefix(data)

  /**
   * First path segment of every key under the current prefix.
   * `values` is already prefix-stripped, so the prefix must not be stripped again. Doing so mangled keys
   * that happen to start with the prefix text, e.g. `a.ab` under prefix `a` came out as `b`.
   */
  def keySet: Set[String] = values.keySet map (k => separator.topLevelOnly(k))

  /** Reader over the same scope without the given scope-relative keys. */
  def --(keys: Iterable[String]): MapValueReader = new MapValueReader(data -- keys.map(separator.wrap(_, prefix)), prefix, separator)

  /** True when some data key continues past `key` with a further delimited segment, i.e. `key` names a nested object. */
  def isComplex(key: String): Boolean = {
    val pref = separator.wrap(key, prefix)
    if (pref != null && pref.trim.nonEmpty) {
      data exists {
        case (k, _) =>
          separator.stripPrefix(k, prefix).contains(separator.beginning) && k.startsWith(pref + separator.beginning)
      }
    } else false
  }

  // Keeps only the keys under the current prefix and strips that prefix from them; the root scope sees everything.
  private[this] def stripPrefix(d: Map[String, Any]): Map[String, Any] = {
    if (prefix != null && prefix.trim.nonEmpty) {
      d collect {
        case (k, v) if k startsWith (prefix + separator.beginning) => separator.stripPrefix(k, prefix) -> v
      }
    } else d
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment