Created
July 28, 2014 16:17
-
-
Save benjaminjackman/9746aa99d2c8350df493 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** Macro implementation that derives a `cgta.serland.SerClass[A]` for a case class `A`.
  *
  * The generated instance:
  *  - summons an implicit `cgta.serland.SerClass` for every case-accessor field and binds it to a
  *    member `serClass<i>`, where `<i>` is the field's zero-based position;
  *  - `schema`: an `XStruct` named after the type (or `None` when the type's string form contains
  *    `[`) listing one `XField(name, index, schemaRef)` per field;
  *  - `read` / `write`: a struct begin/end pair around per-field begin / value / end calls; non-fatal
  *    failures are wrapped via `READ_ERROR` / `WRITE_ERROR` with the failing field or type name;
  *  - `gen`: a ScalaCheck generator built from the per-field generators.
  *
  * Compilation is aborted with an error when `A` is not a case class.
  *
  * @tparam A the case class to derive a serializer for
  * @param c  the macro context
  * @return   an expression that constructs the `SerClass[A]` instance
  */
def forCaseClass[A <: Product : c.WeakTypeTag](c: Context): c.Expr[SerClass[A]] = {
  import c.universe._

  val tpe = weakTypeOf[A]
  val tpeStr = tpe.toString
  val sym = tpe.typeSymbol

  // Check isClass before asClass: asClass throws for non-class symbols (e.g. an abstract type
  // parameter), which would crash the expansion instead of reporting a normal compile error.
  if (!sym.isClass || !sym.asClass.isCaseClass) {
    c.abort(c.enclosingPosition, "Cannot create SerClass for non-case class")
  }

  // Types rendered with type arguments get no schema name.
  val schemaName = if (tpeStr.contains("[")) q"None" else q"Some($tpeStr)"

  val accessors: List[MethodSymbol] = tpe.declarations.collect {
    case acc: MethodSymbol if acc.isCaseAccessor => acc
  }.toList
  val indexed = accessors.zipWithIndex

  // Name of the generated member holding the SerClass for field `idx`.
  def serClassName(idx: Int): TermName = newTermName("serClass" + idx)
  // Name of the local holding the value read/generated for field `idx`.
  def tmpName(idx: Int): TermName = newTermName("tmp" + idx)

  val serClasses = indexed.map { case (accessor, idx) =>
    // Using parse here is really dirty, however I can't get it to work with quasiquotes.
    // Generic parameter names are treated as concrete unless c.universe.newTypeName(...) is used,
    // but that causes problems for something like Option[Int]: it doesn't seem able to make a
    // TypeName properly when the type has [ or ] in it (which is probably correct).
    val fieldTpeStr = accessor.returnType.toString
    c.parse(s"val ${serClassName(idx)} = implicitly[cgta.serland.SerClass[$fieldTpeStr]]")
  }

  val serWriteStmts = indexed.map { case (accessor, idx) =>
    val fieldNameString = accessor.name.toString
    val fieldNameTerm = newTermName(fieldNameString)
    val serClassRef = serClassName(idx)
    q"""
      try {
        out.writeFieldBegin($fieldNameString, $idx)
        $serClassRef.write(a.$fieldNameTerm, out)
        out.writeFieldEnd()
      } catch {
        case _root_.scala.util.control.NonFatal(e) =>
          cgta.serland.WRITE_ERROR("at Field[" + $fieldNameString + "]", e)
      }
    """
  }

  val serReadStmts = indexed.map { case (accessor, idx) =>
    val fieldNameString = accessor.name.toString
    val serClassRef = serClassName(idx)
    val tmp = tmpName(idx)
    q"""
      val $tmp = try {
        in.readFieldBegin($fieldNameString, $idx)
        val t = $serClassRef.read(in)
        in.readFieldEnd()
        t
      } catch {
        case _root_.scala.util.control.NonFatal(e) =>
          cgta.serland.READ_ERROR("at Field[" + $fieldNameString + "]", e)
      }
    """
  }

  // Constructor arguments for `new A(...)`: the per-field locals, in field order.
  val tmps: List[Tree] = indexed.map { case (_, idx) => Ident(tmpName(idx)) }

  val serGen: Tree =
    if (indexed.isEmpty) {
      // `for () yield ...` is not valid Scala, so a field-less case class gets a constant generator.
      q"org.scalacheck.Gen.oneOf(_root_.scala.Seq(new $tpe()))"
    } else {
      val enumerators = indexed.map { case (_, idx) => s"${tmpName(idx)} <- ${serClassName(idx)}.gen" }
      val ctorArgs = indexed.map { case (_, idx) => tmpName(idx).toString }
      c.parse(s"for(${enumerators.mkString(";")}) yield {new $tpeStr(${ctorArgs.mkString(",")})}")
    }

  val fieldSchemas = indexed.map { case (accessor, idx) =>
    val fieldNameString = accessor.name.toString
    val serClassRef = serClassName(idx)
    q"XField($fieldNameString, $idx, $serClassRef.schema.schemaRef)"
  }

  val result = q"""
    new cgta.serland.SerClass[$tpe] {
      ..$serClasses
      override def schema : cgta.serland.SerSchema = {
        import cgta.serland.SerSchemas.{XStruct, XField}
        XStruct($schemaName, Vector(..$fieldSchemas))
      }
      override def read(in: cgta.serland.SerInput) : $tpe = {
        try {
          in.readStructBegin()
          ..$serReadStmts
          in.readStructEnd()
          new $tpe(..$tmps)
        } catch {
          case _root_.scala.util.control.NonFatal(e) =>
            cgta.serland.READ_ERROR("at SerCaseClass[" + $tpeStr + "]", e)
        }
      }
      override def write(a: $tpe, out: cgta.serland.SerOutput) : Unit = {
        try {
          out.writeStructBegin()
          ..$serWriteStmts
          out.writeStructEnd()
        } catch {
          case _root_.scala.util.control.NonFatal(e) =>
            cgta.serland.WRITE_ERROR("at SerCaseClass[" + $tpeStr + "]", e)
        }
      }
      override def gen : org.scalacheck.Gen[$tpe] = $serGen
    }
  """
  c.Expr[SerClass[A]](result)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment