Skip to content

Instantly share code, notes, and snippets.

@calvinlfer
Last active November 29, 2021 21:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save calvinlfer/023cbf481a51790b9c279fb5e298e5b7 to your computer and use it in GitHub Desktop.
Save calvinlfer/023cbf481a51790b9c279fb5e298e5b7 to your computer and use it in GitHub Desktop.
ZIO Schema: high-speed decoder (pattern-matches on the Schema itself and handles each subtype)
package dev.zio.schema.example.example8.advanced
import dev.zio.schema.example.example8.Json
import zio.Chunk
import zio.schema.{Schema, StandardType}
import java.util.UUID
import scala.collection.immutable.ListMap
/** Turns a [[Json]] value into an `A`, describing any failure as a `Left` message.
  *
  * Has a single abstract method so decoders can be written as lambdas or `{ case ... }` literals.
  */
trait FasterDecoder[A] { self =>

  /** Attempts to decode `in`, returning `Left(errorMessage)` when it does not fit `A`. */
  def decode(in: Json): Either[String, A]

  /** Applies `f` to every successful result; failures are passed through untouched. */
  def map[B](f: A => B): FasterDecoder[B] =
    new FasterDecoder[B] {
      override def decode(in: Json): Either[String, B] = {
        val decoded = self.decode(in)
        decoded.map(f)
      }
    }
}
/** Builds [[FasterDecoder]]s by pattern matching directly on a [[Schema]] and handling each
  * subtype, instead of interpreting the schema generically on every decode.
  *
  * Nested decoders are created lazily (on first decode) and then reused. Building them eagerly
  * would recurse forever on recursive schemas (e.g. `case class Node(next: Option[Node])`).
  */
object FasterDecoder {

  /** Derives a decoder for `A` by dispatching on the shape of `schema`.
    *
    * Meta schemas, failed schemas and any shape not listed below (e.g. `CaseClass6`, `Enum6`)
    * yield a decoder that always returns `Left`.
    */
  def schemaDecoder[A](schema: Schema[A]): FasterDecoder[A] = schema match {
    case p @ Schema.Primitive(standardType, _) => primitiveDecoder(standardType, p)
    case s @ Schema.Sequence(_, _, _, _)       => sequenceDecoder(s)
    case Schema.Optional(inner, _)             => optional(schemaDecoder(inner))
    case l @ Schema.Lazy(_)                    => schemaDecoder(l.schema)
    case m @ Schema.Meta(_, _)                 => notSupported("meta", m)
    case cc1 @ Schema.CaseClass1(_, _, _, _)   => caseClass1Decoder(cc1)
    case cc2 @ Schema.CaseClass2(_, _, _, _, _, _) => caseClass2Decoder(cc2)
    case cc3 @ Schema.CaseClass3(_, _, _, _, _, _, _, _) => caseClass3Decoder(cc3)
    case cc4 @ Schema.CaseClass4(_, _, _, _, _, _, _, _, _, _) => caseClass4Decoder(cc4)
    case cc5 @ Schema.CaseClass5(_, _, _, _, _, _, _, _, _, _, _, _) => caseClass5Decoder(cc5)
    // Higher case-class arities follow the same pattern; add them as needed.
    case gcc @ Schema.GenericRecord(_, _)      => caseClassNDecoder(gcc)
    case e1 @ Schema.Enum1(_, _)               => enumDecoder(e1.case1)
    case e2 @ Schema.Enum2(_, _, _)            => enumDecoder(e2.case1, e2.case2)
    case e3 @ Schema.Enum3(_, _, _, _)         => enumDecoder(e3.case1, e3.case2, e3.case3)
    case e4 @ Schema.Enum4(_, _, _, _, _)      => enumDecoder(e4.case1, e4.case2, e4.case3, e4.case4)
    case e5 @ Schema.Enum5(_, _, _, _, _, _) =>
      enumDecoder(e5.case1, e5.case2, e5.case3, e5.case4, e5.case5)
    // Higher enum arities follow the same pattern; add them as needed.
    case eN: Schema.EnumN[sealedTrait, subType] => enumDecoder(eN.caseSet.toSeq: _*)
    case t: Schema.Transform[a, b]             => transformDecoder(t)
    case t @ Schema.Tuple(_, _, _)             => tupleDecoder(t)
    case e @ Schema.EitherSchema(_, _, _)      => eitherDecoder(e)
    case f @ Schema.Fail(_, _)                 => notSupported("failed schema", f)
    case other                                 => notSupported(s"$other", other)
  }

  /** Decodes the underlying representation with `schema.codec`, then converts it with
    * `schema.f`, which may itself reject the value with a `Left`.
    */
  def transformDecoder[A, B](schema: Schema.Transform[A, B]): FasterDecoder[B] = {
    val decoderForA = schemaDecoder(schema.codec)
    (in: Json) => decoderForA.decode(in).flatMap(schema.f)
  }

  /** Decodes a one-field case class from a JSON object keyed by the field's label. */
  def caseClass1Decoder[FieldOfCaseClass, CaseClass](
      schema: Schema.CaseClass1[FieldOfCaseClass, CaseClass]
  ): FasterDecoder[CaseClass] = {
    lazy val fieldDecoder = schemaDecoder(schema.field.schema)
    objectDecoder[CaseClass] { obj =>
      caseClassDecodeHelper(obj)(schema.field.label, fieldDecoder).map(schema.construct)
    }
  }

  /** Decodes a two-field case class; fails on the first missing or malformed field. */
  def caseClass2Decoder[F1, F2, CC](
      schema: Schema.CaseClass2[F1, F2, CC]
  ): FasterDecoder[CC] = {
    lazy val decoder1 = schemaDecoder(schema.field1.schema)
    lazy val decoder2 = schemaDecoder(schema.field2.schema)
    objectDecoder[CC] { obj =>
      for {
        f1 <- caseClassDecodeHelper(obj)(schema.field1.label, decoder1)
        f2 <- caseClassDecodeHelper(obj)(schema.field2.label, decoder2)
      } yield schema.construct(f1, f2)
    }
  }

  /** Decodes a three-field case class; fails on the first missing or malformed field. */
  def caseClass3Decoder[F1, F2, F3, CC](
      schema: Schema.CaseClass3[F1, F2, F3, CC]
  ): FasterDecoder[CC] = {
    lazy val decoder1 = schemaDecoder(schema.field1.schema)
    lazy val decoder2 = schemaDecoder(schema.field2.schema)
    lazy val decoder3 = schemaDecoder(schema.field3.schema)
    objectDecoder[CC] { obj =>
      for {
        f1 <- caseClassDecodeHelper(obj)(schema.field1.label, decoder1)
        f2 <- caseClassDecodeHelper(obj)(schema.field2.label, decoder2)
        f3 <- caseClassDecodeHelper(obj)(schema.field3.label, decoder3)
      } yield schema.construct(f1, f2, f3)
    }
  }

  /** Decodes a four-field case class; fails on the first missing or malformed field. */
  def caseClass4Decoder[F1, F2, F3, F4, CC](
      schema: Schema.CaseClass4[F1, F2, F3, F4, CC]
  ): FasterDecoder[CC] = {
    lazy val decoder1 = schemaDecoder(schema.field1.schema)
    lazy val decoder2 = schemaDecoder(schema.field2.schema)
    lazy val decoder3 = schemaDecoder(schema.field3.schema)
    lazy val decoder4 = schemaDecoder(schema.field4.schema)
    objectDecoder[CC] { obj =>
      for {
        f1 <- caseClassDecodeHelper(obj)(schema.field1.label, decoder1)
        f2 <- caseClassDecodeHelper(obj)(schema.field2.label, decoder2)
        f3 <- caseClassDecodeHelper(obj)(schema.field3.label, decoder3)
        f4 <- caseClassDecodeHelper(obj)(schema.field4.label, decoder4)
      } yield schema.construct(f1, f2, f3, f4)
    }
  }

  /** Decodes a five-field case class; fails on the first missing or malformed field. */
  def caseClass5Decoder[F1, F2, F3, F4, F5, CC](
      schema: Schema.CaseClass5[F1, F2, F3, F4, F5, CC]
  ): FasterDecoder[CC] = {
    lazy val decoder1 = schemaDecoder(schema.field1.schema)
    lazy val decoder2 = schemaDecoder(schema.field2.schema)
    lazy val decoder3 = schemaDecoder(schema.field3.schema)
    lazy val decoder4 = schemaDecoder(schema.field4.schema)
    lazy val decoder5 = schemaDecoder(schema.field5.schema)
    objectDecoder[CC] { obj =>
      for {
        f1 <- caseClassDecodeHelper(obj)(schema.field1.label, decoder1)
        f2 <- caseClassDecodeHelper(obj)(schema.field2.label, decoder2)
        f3 <- caseClassDecodeHelper(obj)(schema.field3.label, decoder3)
        f4 <- caseClassDecodeHelper(obj)(schema.field4.label, decoder4)
        f5 <- caseClassDecodeHelper(obj)(schema.field5.label, decoder5)
      } yield schema.construct(f1, f2, f3, f4, f5)
    }
  }

  /** Decodes a generic record into a `ListMap` keyed by field label, in field-set order.
    * Stops decoding at the first missing or malformed field.
    */
  def caseClassNDecoder(schema: Schema.GenericRecord): FasterDecoder[ListMap[String, _]] =
    objectDecoder[ListMap[String, _]] { obj =>
      schema.fieldSet.toChunk.foldLeft(
        Right(ListMap.empty): Either[String, ListMap[String, _]]
      ) { (acc, nextField) =>
        for {
          decoded <- acc
          value   <- caseClassDecodeHelper(obj)(nextField.label, schemaDecoder(nextField.schema))
        } yield decoded.updated(nextField.label, value)
      }
    }

  /** Looks up `key` in `json` and decodes its value with `decoder`.
    *
    * A missing key is a `Left`, even when `decoder` is an `optional` decoder.
    */
  def caseClassDecodeHelper[A](
      json: Json.JObj
  )(key: String, decoder: FasterDecoder[A]): Either[String, A] =
    json.map.get(key) match {
      case Some(value) => decoder.decode(value)
      case None        => Left(s"Failed to find $key in $json")
    }

  /** Decodes a pair from `{"_1": ..., "_2": ...}`, the layout `FasterEncoder.tupleEncoder`
    * writes.
    */
  def tupleDecoder[A, B](schema: Schema.Tuple[A, B]): FasterDecoder[(A, B)] = {
    lazy val leftDecoder  = schemaDecoder(schema.left)
    lazy val rightDecoder = schemaDecoder(schema.right)
    objectDecoder[(A, B)] { obj =>
      for {
        l <- caseClassDecodeHelper(obj)("_1", leftDecoder)
        r <- caseClassDecodeHelper(obj)("_2", rightDecoder)
      } yield (l, r)
    }
  }

  /** Tries the left schema first and falls back to the right schema.
    *
    * The JSON carries no Left/Right tag, so input valid for both sides always decodes as `Left`.
    */
  def eitherDecoder[A, B](schema: Schema.EitherSchema[A, B]): FasterDecoder[Either[A, B]] = {
    lazy val leftDecoder  = schemaDecoder(schema.left)
    lazy val rightDecoder = schemaDecoder(schema.right)
    (json: Json) =>
      leftDecoder
        .decode(json)
        .map(Left(_))
        .orElse(rightDecoder.decode(json).map(Right(_)))
  }

  /** Decodes a sealed-trait value by selecting the case whose `id` equals the object's
    * `"type"` string, then decoding the same object with that case's codec.
    */
  def enumDecoder[SealedTrait](
      schemas: Schema.Case[_, SealedTrait]*
  ): FasterDecoder[SealedTrait] =
    objectDecoder[SealedTrait] { obj =>
      val typeName = obj.map.get("type").collect { case Json.JStr(name) => name }
      schemas.find(subtypeCase => typeName.contains(subtypeCase.id)) match {
        case Some(subtypeCase) =>
          // The case's codec produces the concrete subtype, which is a SealedTrait by construction.
          schemaDecoder(subtypeCase.codec)
            .decode(obj)
            .asInstanceOf[Either[String, SealedTrait]]
        case None =>
          Left(
            s"type did not match one of the subtypes of the sealed trait ${schemas.map(_.id).mkString("(", ", ", ")")}"
          )
      }
    }

  /** Decodes a JSON array element by element and stops at, and reports, the first failure. */
  def sequenceDecoder[Collection, Element](
      schema: Schema.Sequence[Collection, Element]
  ): FasterDecoder[Collection] = {
    lazy val elementDecoder = schemaDecoder(schema.schemaA)
    (json: Json) =>
      json match {
        case Json.JArr(xs) =>
          val zero: Either[String, Chunk[Element]] = Right(Chunk.empty)
          val decodedElements = xs.foldLeft(zero) { (acc, next) =>
            // Check `acc` first so no further elements are decoded once one has failed.
            for {
              decoded <- acc
              elem    <- elementDecoder.decode(next)
            } yield decoded :+ elem
          }
          decodedElements.map(schema.fromChunk)
        case other =>
          Left(s"Expected array but got $other")
      }
  }

  /** Decoder for a primitive `standardType`. Integral types truncate the JSON number, and
    * java.time types are not supported yet.
    */
  def primitiveDecoder[A](
      standardType: StandardType[A],
      schema: Schema.Primitive[A]
  ): FasterDecoder[A] =
    standardType match {
      case StandardType.UnitType       => unitDecoder
      case StandardType.StringType     => stringDecoder
      case StandardType.BoolType       => boolDecoder
      case StandardType.ShortType      => numDecoder.map(_.toShort)
      case StandardType.IntType        => numDecoder.map(_.toInt)
      case StandardType.LongType       => numDecoder.map(_.toLong)
      case StandardType.FloatType      => numDecoder.map(_.toFloat)
      case StandardType.DoubleType     => numDecoder
      case StandardType.BinaryType =>
        // Explicit UTF-8 rather than the platform default charset.
        stringDecoder.map(s => Chunk.fromArray(s.getBytes(java.nio.charset.StandardCharsets.UTF_8)))
      case StandardType.CharType =>
        emap[String, Char](stringDecoder)(
          _.headOption.toRight("Expected a single character but got an empty string")
        )
      case StandardType.UUIDType =>
        emap[String, UUID](stringDecoder) { s =>
          try Right(UUID.fromString(s))
          catch { case _: IllegalArgumentException => Left(s"Expected UUID but got $s") }
        }
      case StandardType.BigDecimalType => bigDecimalDecoder
      case StandardType.BigIntegerType => bigIntegerDecoder
      case StandardType.DayOfWeekType  => notSupported("day-of-week", schema)
      case StandardType.Month          => notSupported("month", schema)
      case StandardType.MonthDay       => notSupported("month-day", schema)
      case StandardType.Period         => notSupported("period", schema)
      case StandardType.Year           => notSupported("year", schema)
      case StandardType.YearMonth      => notSupported("year-month", schema)
      case StandardType.ZoneId         => notSupported("zone-id", schema)
      case StandardType.ZoneOffset     => notSupported("zone-offset", schema)
      case StandardType.Duration(_)    => notSupported("duration", schema)
      case StandardType.Instant(_)     => notSupported("instant", schema)
      case StandardType.LocalDate(_)   => notSupported("local-date", schema)
      case StandardType.LocalTime(_)   => notSupported("local-time", schema)
      case StandardType.LocalDateTime(_)  => notSupported("local-date-time", schema)
      case StandardType.OffsetTime(_)     => notSupported("offset-time", schema)
      case StandardType.OffsetDateTime(_) => notSupported("offset-date-time", schema)
      case StandardType.ZonedDateTime(_)  => notSupported("zoned-date-time", schema)
    }

  val stringDecoder: FasterDecoder[String] = {
    case Json.JStr(s) => Right(s)
    case other        => Left(s"Expected string but got $other")
  }

  /** Accepts any JSON value. */
  val unitDecoder: FasterDecoder[Unit] =
    _ => Right(())

  val boolDecoder: FasterDecoder[Boolean] = {
    case Json.JBool(b) => Right(b)
    case other         => Left(s"Expected boolean but got $other")
  }

  val numDecoder: FasterDecoder[Double] = {
    case Json.JNum(num) => Right(num)
    case other          => Left(s"Expected number but got $other")
  }

  /** Accepts a JSON number, or a numeric string as written by `FasterEncoder`. */
  private val bigDecimalDecoder: FasterDecoder[java.math.BigDecimal] = {
    case Json.JNum(num) => Right(BigDecimal.decimal(num).bigDecimal)
    case Json.JStr(s) =>
      try Right(new java.math.BigDecimal(s))
      catch { case _: NumberFormatException => Left(s"Expected number but got $s") }
    case other => Left(s"Expected number but got $other")
  }

  /** Accepts a JSON number (truncated via `Long`), or an exact integer string as written by
    * `FasterEncoder`.
    */
  private val bigIntegerDecoder: FasterDecoder[java.math.BigInteger] = {
    case Json.JNum(num) => Right(BigInt(num.toLong).bigInteger)
    case Json.JStr(s) =>
      try Right(new java.math.BigInteger(s))
      catch { case _: NumberFormatException => Left(s"Expected number but got $s") }
    case other => Left(s"Expected number but got $other")
  }

  /** Lifts `in` to an optional decoder.
    *
    * NOTE: any inner failure becomes `None`, so malformed values are silently treated as absent.
    */
  def optional[A](in: FasterDecoder[A]): FasterDecoder[Option[A]] =
    json => Right(in.decode(json).toOption)

  /** A decoder that rejects every input, naming the unsupported `typeName`. */
  def notSupported[A](typeName: String, schema: Schema[A]): FasterDecoder[A] = { json =>
    Left(s"Sorry but $json cannot be converted into $typeName with schema $schema")
  }

  /** Runs `f` on JSON objects and rejects every other JSON value. */
  private def objectDecoder[A](f: Json.JObj => Either[String, A]): FasterDecoder[A] = {
    case obj @ Json.JObj(_) => f(obj)
    case other              => Left(s"Expected a JSON object but got $other instead")
  }

  /** Like `map`, but `f` may reject the decoded value. */
  private def emap[A, B](decoder: FasterDecoder[A])(f: A => Either[String, B]): FasterDecoder[B] =
    (in: Json) => decoder.decode(in).flatMap(f)
}
package dev.zio.schema.example.example8.advanced
import dev.zio.schema.example.example8.Json
import zio.schema.{FieldSet, Schema, StandardType}
import scala.collection.immutable.ListMap
/** Renders an `A` as a [[Json]] value. Encoding has no failure channel. */
sealed trait FasterEncoder[A] { self =>

  /** Produces the JSON representation of `in`. */
  def encode(in: A): Json

  /** Adapts this encoder to `B` by converting each input with `f` before encoding it. */
  def contramap[B](f: B => A): FasterEncoder[B] = {
    val underlying = self
    new FasterEncoder[B] {
      override def encode(in: B): Json = {
        val converted = f(in)
        underlying.encode(converted)
      }
    }
  }
}
/** Builds [[FasterEncoder]]s by pattern matching directly on a [[Schema]] and handling each
  * subtype.
  *
  * Nested encoders are built lazily on first use and then cached. Building them eagerly would
  * loop forever on recursive schemas.
  */
object FasterEncoder {

  /** Derives an encoder for `A` by dispatching on the shape of `schema`.
    *
    * Shapes that cannot be encoded (failed schemas, meta schemas, and arities not listed below)
    * yield an encoder that throws `UnsupportedOperationException` when used.
    */
  def schemaEncoder[A](implicit schema: Schema[A]): FasterEncoder[A] = schema match {
    case e1 @ Schema.Enum1(_, _)             => enum1Encoder(e1)
    case Schema.Enum2(c1, c2, _)             => enumEncoder(c1, c2)
    case Schema.Enum3(c1, c2, c3, _)         => enumEncoder(c1, c2, c3)
    case Schema.Enum4(c1, c2, c3, c4, _)     => enumEncoder(c1, c2, c3, c4)
    case Schema.Enum5(c1, c2, c3, c4, c5, _) => enumEncoder(c1, c2, c3, c4, c5)
    // Higher enum arities follow the same pattern; add them as needed.
    case e @ Schema.EnumN(_, _) => enumEncoder(e.caseSet.toSeq: _*)
    case Schema.CaseClass1(_, f, _, e) =>
      caseClassEncoder((f, e))
    case Schema.CaseClass2(_, f1, f2, _, e1, e2) =>
      caseClassEncoder((f1, e1), (f2, e2))
    case Schema.CaseClass3(_, f1, f2, f3, _, e1, e2, e3) =>
      caseClassEncoder((f1, e1), (f2, e2), (f3, e3))
    case Schema.CaseClass4(_, f1, f2, f3, f4, _, e1, e2, e3, e4) =>
      caseClassEncoder((f1, e1), (f2, e2), (f3, e3), (f4, e4))
    case Schema.CaseClass5(_, f1, f2, f3, f4, f5, _, e1, e2, e3, e4, e5) =>
      caseClassEncoder((f1, e1), (f2, e2), (f3, e3), (f4, e4), (f5, e5))
    case Schema.GenericRecord(fieldSet, _) =>
      genericRecordEncoder(fieldSet)
    case seq @ Schema.Sequence(_, _, _, _) =>
      sequenceEncoder(seq)
    case t @ Schema.Transform(_, _, _, _) =>
      transformEncoder(t)
    case Schema.Primitive(standardType, _)   => primitiveEncoder(standardType)
    case Schema.Lazy(s)                      => schemaEncoder(s())
    case Schema.Tuple(left, right, _)        => tupleEncoder(left, right)
    case Schema.EitherSchema(left, right, _) => eitherEncoder(left, right)
    case Schema.Optional(codec, _)           => optionalEncoder(codec)
    case f @ Schema.Fail(_, _)               => notSupported("failed schema", f)
    case m @ Schema.Meta(_, _)               => notSupported("meta", m)
    case other                               => notSupported(s"$other", other)
  }

  /** Encoder for a primitive `standardType`. Numbers go through `Double`, so `Long` values
    * beyond 2^53 lose precision. Most other types use their `toString`.
    */
  def primitiveEncoder[A](standardType: StandardType[A]): FasterEncoder[A] =
    standardType match {
      case StandardType.UnitType   => unitEncoder
      case StandardType.StringType => stringEncoder
      case StandardType.BoolType   => boolEncoder
      case StandardType.ShortType  => doubleEncoder.contramap[Short](_.toDouble)
      case StandardType.IntType    => doubleEncoder.contramap[Int](_.toDouble)
      case StandardType.LongType   => doubleEncoder.contramap(_.toDouble)
      case StandardType.FloatType  => doubleEncoder.contramap(_.toDouble)
      case StandardType.DoubleType => doubleEncoder
      case StandardType.BinaryType =>
        // Emit the bytes as a UTF-8 string (the decoder reads the string's UTF-8 bytes back).
        // `Chunk#toString` would produce "Chunk(...)" instead of the data.
        // NOTE: byte sequences that are not valid UTF-8 do not survive a round-trip.
        stringEncoder.contramap[zio.Chunk[Byte]](bytes =>
          new String(bytes.toArray[Byte], java.nio.charset.StandardCharsets.UTF_8)
        )
      case StandardType.CharType          => stringEncoder.contramap(_.toString)
      case StandardType.UUIDType          => stringEncoder.contramap(_.toString)
      case StandardType.BigDecimalType    => stringEncoder.contramap(_.toString)
      case StandardType.BigIntegerType    => stringEncoder.contramap(_.toString)
      case StandardType.DayOfWeekType     => stringEncoder.contramap(_.toString)
      case StandardType.Month             => stringEncoder.contramap(_.toString)
      case StandardType.MonthDay          => stringEncoder.contramap(_.toString)
      case StandardType.Period            => stringEncoder.contramap(_.toString)
      case StandardType.Year              => stringEncoder.contramap(_.toString)
      case StandardType.YearMonth         => stringEncoder.contramap(_.toString)
      case StandardType.ZoneId            => stringEncoder.contramap(_.toString)
      case StandardType.ZoneOffset        => stringEncoder.contramap(_.toString)
      case StandardType.Duration(_)       => stringEncoder.contramap(_.toString)
      case StandardType.Instant(_)        => stringEncoder.contramap(_.toString)
      case StandardType.LocalDate(_)      => stringEncoder.contramap(_.toString)
      case StandardType.LocalTime(_)      => stringEncoder.contramap(_.toString)
      case StandardType.LocalDateTime(_)  => stringEncoder.contramap(_.toString)
      case StandardType.OffsetTime(_)     => stringEncoder.contramap(_.toString)
      case StandardType.OffsetDateTime(_) => stringEncoder.contramap(_.toString)
      case StandardType.ZonedDateTime(_)  => stringEncoder.contramap(_.toString)
    }

  /** Encodes a pair as `{"_1": left, "_2": right}`. */
  def tupleEncoder[A, B](left: Schema[A], right: Schema[B]): FasterEncoder[(A, B)] =
    new FasterEncoder[(A, B)] {
      private lazy val leftEncoder  = schemaEncoder(left)
      private lazy val rightEncoder = schemaEncoder(right)

      override def encode(in: (A, B)): Json =
        Json.JObj(
          Map(
            "_1" -> leftEncoder.encode(in._1),
            "_2" -> rightEncoder.encode(in._2)
          )
        )
    }

  /** Encodes whichever side is present, without a Left/Right tag. */
  def eitherEncoder[A, B](left: Schema[A], right: Schema[B]): FasterEncoder[Either[A, B]] =
    new FasterEncoder[Either[A, B]] {
      private lazy val leftEncoder  = schemaEncoder(left)
      private lazy val rightEncoder = schemaEncoder(right)

      override def encode(in: Either[A, B]): Json = in match {
        case Left(value)  => leftEncoder.encode(value)
        case Right(value) => rightEncoder.encode(value)
      }
    }

  /** Encodes `Some(value)` with `schema`, and `None` as an empty JSON object.
    *
    * NOTE: `unitEncoder` and records with no fields also produce an empty object, so for those
    * inner schemas `None` and `Some` are indistinguishable after encoding.
    */
  def optionalEncoder[A](schema: Schema[A]): FasterEncoder[Option[A]] =
    new FasterEncoder[Option[A]] {
      private lazy val someEncoder = schemaEncoder(schema)

      override def encode(in: Option[A]): Json = in match {
        case Some(value) => someEncoder.encode(value)
        case None        => Json.JObj(Map.empty)
      }
    }

  /** An encoder that throws `UnsupportedOperationException` for every input, naming the
    * unsupported `typeName`.
    */
  def notSupported[A](typeName: String, schema: Schema[A]): FasterEncoder[A] =
    new FasterEncoder[A] {
      override def encode(in: A): Json =
        throw new UnsupportedOperationException(
          s"Sorry but $in cannot be encoded as $typeName with schema $schema"
        )
    }

  val unitEncoder: FasterEncoder[Unit] = new FasterEncoder[Unit] {
    override def encode(in: Unit): Json = Json.JObj(Map.empty)
  }
  val stringEncoder: FasterEncoder[String] = new FasterEncoder[String] {
    override def encode(in: String): Json = Json.JStr(in)
  }
  val boolEncoder: FasterEncoder[Boolean] = new FasterEncoder[Boolean] {
    override def encode(in: Boolean): Json = Json.JBool(in)
  }
  val doubleEncoder: FasterEncoder[Double] = new FasterEncoder[Double] {
    override def encode(in: Double): Json = Json.JNum(in)
  }

  /** Encodes the single case of a one-case enum and adds its `"type"` discriminator. */
  def enum1Encoder[SealedTrait, Subtype <: SealedTrait](
      schema: Schema.Enum1[Subtype, SealedTrait]
  ): FasterEncoder[SealedTrait] =
    new FasterEncoder[SealedTrait] {
      override def encode(in: SealedTrait): Json = {
        val typeInfo = schema.case1.id
        schema.case1
          .deconstruct(in)
          .map { value =>
            schemaEncoder(schema.case1.codec).encode(value)
          }
          .collect { case Json.JObj(map) => Json.JObj(map + ("type" -> Json.JStr(typeInfo))) }
          .getOrElse(Json.JObj(Map("type" -> Json.JStr(typeInfo))))
      }
    }

  /** Encodes a sealed-trait value with the first case that can deconstruct it.
    *
    * Adds that case's `id` as `"type"`. Non-object encodings are replaced by just the tag, and a
    * value matching no case encodes as an empty object.
    */
  def enumEncoder[General](
      cases: Schema.Case[_, General]*
  ): FasterEncoder[General] =
    new FasterEncoder[General] {
      override def encode(in: General): Json =
        cases.iterator
          .map(subtypeCase => (subtypeCase, subtypeCase.deconstruct(in): Option[Any]))
          .collectFirst { case (subtypeCase, Some(value)) =>
            val typeTag = "type" -> Json.JStr(subtypeCase.id)
            // Encode with the case's *codec*: the Case itself is not a Schema.
            schemaEncoder(subtypeCase.codec.asInstanceOf[Schema[Any]]).encode(value) match {
              case Json.JObj(fields) => Json.JObj(fields + typeTag)
              case _                 => Json.JObj(Map(typeTag))
            }
          }
          .getOrElse(Json.JObj(Map.empty))
    }

  /** Encodes a case class as a JSON object keyed by field label. If labels repeat, the last
    * field wins.
    */
  def caseClassEncoder[A](fieldAndExtractors: (Schema.Field[_], A => Any)*): FasterEncoder[A] =
    new FasterEncoder[A] {
      private lazy val fieldEncoders: Seq[(String, FasterEncoder[Any], A => Any)] =
        fieldAndExtractors.map { case (field, extractor) =>
          (field.label, schemaEncoder(field.schema.asInstanceOf[Schema[Any]]), extractor)
        }

      override def encode(in: A): Json =
        Json.JObj(
          fieldEncoders.foldLeft(Map.empty[String, Json]) { case (acc, (label, encoder, extractor)) =>
            acc + (label -> encoder.encode(extractor(in)))
          }
        )
    }

  /** Encodes the fields of `fieldSet` that are present in the record; absent ones are omitted. */
  def genericRecordEncoder(fieldSet: FieldSet): FasterEncoder[ListMap[String, _]] =
    new FasterEncoder[ListMap[String, _]] {
      override def encode(in: ListMap[String, _]): Json =
        Json.JObj(
          fieldSet.toChunk.foldLeft(Map.empty[String, Json]) { (acc, next) =>
            in.get(next.label) match {
              case Some(value) =>
                acc + (next.label -> schemaEncoder(next.schema.asInstanceOf[Schema[Any]])
                  .encode(value))
              case None =>
                acc
            }
          }
        )
    }

  /** Encodes every element of the collection into a JSON array, in order. */
  def sequenceEncoder[Collection, Element](
      schema: Schema.Sequence[Collection, Element]
  ): FasterEncoder[Collection] =
    new FasterEncoder[Collection] {
      private lazy val elementEncoder = schemaEncoder(schema.schemaA)

      override def encode(in: Collection): Json = {
        val jsonChunk = schema.toChunk(in).map(elementEncoder.encode)
        Json.JArr(jsonChunk.toList)
      }
    }

  /** Converts with `schema.g` and encodes the result with `schema.codec`.
    *
    * NOTE: when `g` returns `Left`, the error is dropped and an empty object is emitted.
    */
  def transformEncoder[A, B](schema: Schema.Transform[A, B]): FasterEncoder[B] =
    new FasterEncoder[B] {
      private lazy val codecEncoder = schemaEncoder(schema.codec)

      override def encode(in: B): Json =
        schema
          .g(in)
          .map(a => codecEncoder.encode(a))
          .getOrElse(Json.JObj(Map.empty))
    }
}
package dev.zio.schema.example.example8.advanced
import dev.zio.schema.example.example8.Json
import zio.schema.{DeriveSchema, Schema}
/** Example sum type for the demo: three fixed colors plus an arbitrary RGB value. */
sealed trait Color

object Color {
  case object Red   extends Color
  case object Green extends Color
  case object Blue  extends Color

  /** A color given by its red, green and blue components. */
  final case class Custom(r: Int, g: Int, b: Int) extends Color

  /** Schema derived from the sealed hierarchy above. */
  implicit val schemaColor: Schema[Color] = DeriveSchema.gen[Color]

  /** Decoder built once from `schemaColor`; it dispatches on the JSON `"type"` field. */
  val colorDecoder: FasterDecoder[Color] = FasterDecoder.schemaDecoder[Color](schemaColor)
}
/** Example product type for the demo. */
final case class Person(name: String, age: Int, preferences: List[String])

object Person {
  /** Schema derived from the case class fields. */
  implicit val schemaPerson: Schema[Person] = DeriveSchema.gen[Person]

  val personDecoder: FasterDecoder[Person] = FasterDecoder.schemaDecoder[Person](schemaPerson)
  val personEncoder: FasterEncoder[Person] = FasterEncoder.schemaEncoder[Person](schemaPerson)

  /** Decoder for a JSON array of people; each element is decoded with `schemaPerson`. */
  val listPersonDecoder: FasterDecoder[List[Person]] = {
    val listSchema = Schema.list(schemaPerson)
    FasterDecoder.schemaDecoder(listSchema)
  }
}
/** Demo: decodes and round-trips `Person`, `List[Person]`, `Color` and a `(Person, Color)`
  * tuple, printing each result.
  */
object FasterUsage extends App {

  /** JSON object in the field layout the derived `Person` decoder reads. */
  private def personJson(name: String, age: Double, preferences: List[String]): Json.JObj =
    Json.JObj(
      Map(
        "name"        -> Json.JStr(name),
        "age"         -> Json.JNum(age),
        "preferences" -> Json.JArr(preferences.map(Json.JStr))
      )
    )

  /** JSON object for `Color.Custom`, tagged with its `"type"` discriminator. */
  private def customColorJson(r: Double, g: Double, b: Double): Json.JObj =
    Json.JObj(
      Map(
        "type" -> Json.JStr("Custom"),
        "r"    -> Json.JNum(r),
        "g"    -> Json.JNum(g),
        "b"    -> Json.JNum(b)
      )
    )

  val p = Person("cal", 30, List("silent", "consistent", "dependable"))

  val sameP =
    Person.personDecoder.decode(personJson("cal", 30, List("silent", "consistent", "dependable")))
  println(sameP == Right(p))

  val jsonPersons =
    Json.JArr(
      List(
        personJson("cal", 30, List("silent", "consistent", "dependable", "technical")),
        personJson("char", 28, List("silent", "consistent", "dependable", "creative"))
      )
    )
  println(Person.listPersonDecoder.decode(jsonPersons))

  println(Color.colorDecoder.decode(customColorJson(0, 0, 0)))

  println {
    val person = Person("cal", 30, List("polite", "kind"))
    Person.personDecoder.decode(Person.personEncoder.encode(person)) == Right(person)
  }

  println(Color.colorDecoder.decode(Json.JObj(Map("type" -> Json.JStr("Red")))))

  println {
    val pairDecoder = FasterDecoder.schemaDecoder(DeriveSchema.gen[(Person, Color)])
    pairDecoder.decode(
      Json.JObj(
        Map(
          "_1" -> personJson("char", 28, List("silent", "consistent", "dependable", "creative")),
          "_2" -> customColorJson(255, 255, 255)
        )
      )
    )
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment