Skip to content

Instantly share code, notes, and snippets.

@Topher-the-Geek
Created June 25, 2024 14:34
Show Gist options
  • Save Topher-the-Geek/284790fad0ddb182f996c883b19ccdfc to your computer and use it in GitHub Desktop.
Save Topher-the-Geek/284790fad0ddb182f996c883b19ccdfc to your computer and use it in GitHub Desktop.
// Adapted from https://github.com/sangria-graphql/sangria-circe/blob/e70826e6bb152c65b27e4b51958fc593d9434b40/src/main/scala/sangria/marshalling/circe.scala
// Licensed under Apache 2.0
// This version supports Json as a scalar type
import io.circe._
import sangria.ast
import sangria.marshalling._
import sangria.schema.ScalarType
import sangria.validation.ValueCoercionViolation
/** Circe marshalling support for Sangria, plus a `Json` GraphQL scalar.
  *
  * Contains the implicit result marshaller, input unmarshaller and To/FromInput instances
  * (adapted from sangria-circe) and [[JsonType]], a scalar whose values are arbitrary circe
  * `Json` documents — including JSON objects and arrays.
  */
object marshalling {

  /** Converts a Scala value produced by Sangria (or by [[CirceInputUnmarshaller.getScalarValue]])
    * into circe `Json`.
    *
    * `Float`/`Double` go through `Json.fromDoubleOrNull`, so non-finite values become JSON null.
    * `Json` values are passed through unchanged, which is what lets [[JsonType]] carry whole
    * documents.
    *
    * @throws IllegalArgumentException if `value` is not one of the supported types
    */
  private def coerceUserInput(value: Any): Json =
    value match {
      case v: String     => Json.fromString(v)
      case v: Boolean    => Json.fromBoolean(v)
      case v: Int        => Json.fromInt(v)
      case v: Long       => Json.fromLong(v)
      case v: Float      => Json.fromDoubleOrNull(v)
      case v: Double     => Json.fromDoubleOrNull(v)
      case v: BigInt     => Json.fromBigInt(v)
      case v: BigDecimal => Json.fromBigDecimal(v)
      case v: Json       => v
      case _             => throw new IllegalArgumentException(s"Unsupported scalar value: $value")
    }

  /** Builds circe `Json` trees from Sangria execution results. */
  implicit object CirceResultMarshaller extends ResultMarshaller {
    type Node = Json
    type MapBuilder = ArrayMapBuilder[Node]

    def emptyMapNode(keys: Seq[String]): MapBuilder = new ArrayMapBuilder[Node](keys)

    // `optional` is not needed here: the builder records every field it is given.
    def addMapNodeElem(builder: MapBuilder, key: String, value: Node, optional: Boolean): ArrayMapBuilder[Json] =
      builder.add(key, value)

    def mapNode(builder: MapBuilder): Json = Json.fromFields(builder)

    def mapNode(keyValues: Seq[(String, Json)]): Json = Json.fromFields(keyValues)

    def arrayNode(values: Vector[Json]): Json = Json.fromValues(values)

    /** Absent list elements are rendered as JSON null. */
    def optionalArrayNodeValue(value: Option[Json]): Json = value.getOrElse(nullNode)

    /** Scalars (including `Json` values emitted by [[JsonType]]) go through `coerceUserInput`. */
    def scalarNode(value: Any, typeName: String, info: Set[ScalarValueInfo]): Json = coerceUserInput(value)

    def enumNode(value: String, typeName: String): Json = Json.fromString(value)

    def nullNode: Json = Json.Null

    def renderCompact(node: Json): String = node.noSpaces

    def renderPretty(node: Json): String = node.spaces2
  }

  /** Makes [[CirceResultMarshaller]] available when a `Json` result is requested. */
  implicit object CirceMarshallerForType extends ResultMarshallerForType[Json] {
    val marshaller: CirceResultMarshaller.type = CirceResultMarshaller
  }

  /** Reads Sangria input (variables, arguments) from circe `Json`. */
  implicit object CirceInputUnmarshaller extends InputUnmarshaller[Json] {

    /** Looks up `key` in a root object; a non-object root yields `None` rather than throwing. */
    def getRootMapValue(node: Json, key: String): Option[Json] = node.asObject.flatMap(obj => obj(key))

    def isMapNode(node: Json): Boolean = node.isObject

    /** Looks up `key` in an object node; a non-object node yields `None` rather than throwing. */
    def getMapValue(node: Json, key: String): Option[Json] = node.asObject.flatMap(obj => obj(key))

    // NOTE: throws NoSuchElementException if `node` is not an object; presumably callers
    // check isMapNode first.
    def getMapKeys(node: Json): Iterable[String] = node.asObject.get.keys

    def isListNode(node: Json): Boolean = node.isArray

    // NOTE: throws NoSuchElementException if `node` is not an array; presumably callers
    // check isListNode first.
    def getListValue(node: Json): Vector[Json] = node.asArray.get

    def isDefined(node: Json): Boolean = !node.isNull

    /** Extracts the value handed to a scalar type's `coerceUserInput`.
      *
      * Booleans and strings are unwrapped; numbers become `BigInt` when integral, otherwise
      * `BigDecimal`. Arrays and objects are returned as the `Json` node itself so that
      * `coerceUserInput` (which matches on `Json`) accepts them for [[JsonType]]; other
      * scalar types reject them with their own coercion violation. JSON null yields Scala
      * `null`.
      *
      * @throws IllegalStateException if a number fits neither `BigInt` nor `BigDecimal`
      */
    def getScalarValue(node: Json): Any = {
      def invalidScalar = throw new IllegalStateException(s"$node is not a scalar value")

      node.fold(
        jsonNull = null,
        jsonBoolean = identity,
        jsonNumber = num => num.toBigInt.orElse(num.toBigDecimal).getOrElse(invalidScalar),
        jsonString = identity,
        // Returning the unwrapped Vector[Json] / JsonObject here would make every composite
        // value fail coercion, since neither is a `Json`.
        jsonArray = _ => node,
        jsonObject = _ => node
      )
    }

    def getScalaScalarValue(node: Json): Any = getScalarValue(node)

    def isEnumNode(node: Json): Boolean = node.isString

    /** Arrays and objects count as scalar nodes so that variables bound to the `Json` scalar
      * may carry composite values.
      *
      * List and input-object types are assumed to be dispatched via isListNode / isMapNode
      * before scalar handling — verify against the Sangria version in use.
      */
    def isScalarNode(node: Json): Boolean =
      node.isBoolean || node.isNumber || node.isString || node.isArray || node.isObject

    def isVariableNode(node: Json): Boolean = false

    def getVariableName(node: Json): String = throw new IllegalArgumentException("variables are not supported")

    def render(node: Json): String = node.noSpaces
  }

  /** Lets a raw `Json` value be used directly as Sangria input. */
  implicit object circeToInput extends ToInput[Json, Json] {
    def toInput(value: Json): (Json, CirceInputUnmarshaller.type) = (value, CirceInputUnmarshaller)
  }

  /** Lets coerced input be received as a raw `Json` value. */
  implicit object circeFromInput extends FromInput[Json] {
    val marshaller: CirceResultMarshaller.type = CirceResultMarshaller

    def fromResult(node: marshaller.Node): marshaller.Node = node
  }

  /** Any type with a circe `Encoder` can be supplied as Sangria input by encoding it to `Json`. */
  implicit def circeEncoderToInput[T: Encoder]: ToInput[T, Json] =
    (value: T) => implicitly[Encoder[T]].apply(value) -> CirceInputUnmarshaller

  /** Any type with a circe `Decoder` can be read from coerced input.
    *
    * @throws InputParsingError when decoding fails
    */
  implicit def circeDecoderFromInput[T: Decoder]: FromInput[T] =
    new FromInput[T] {
      val marshaller: CirceResultMarshaller.type = CirceResultMarshaller

      def fromResult(node: marshaller.Node): T =
        implicitly[Decoder[T]].decodeJson(node) match {
          case Right(obj)  => obj
          case Left(error) => throw InputParsingError(Vector(error.getMessage))
        }
    }

  /** Converts a GraphQL literal (query AST value) into `Json`, recursing into lists and objects.
    *
    * Enum literals become JSON strings.
    *
    * @throws IllegalArgumentException for unsupported nodes (e.g. variable references)
    */
  private def coerceInput(value: ast.Value): Json =
    value match {
      case ast.NullValue(_, _)             => Json.Null
      case ast.IntValue(i, _, _)           => Json.fromLong(i)
      case ast.BigIntValue(i, _, _)        => Json.fromBigInt(i)
      case ast.FloatValue(f, _, _)         => Json.fromDoubleOrNull(f)
      case ast.BigDecimalValue(d, _, _)    => Json.fromBigDecimal(d)
      case ast.BooleanValue(b, _, _)       => Json.fromBoolean(b)
      case ast.StringValue(s, _, _, _, _)  => Json.fromString(s)
      case ast.EnumValue(s, _, _)          => Json.fromString(s)
      case ast.ListValue(values, _, _)     => Json.fromValues(values.map(coerceInput))
      case ast.ObjectValue(fields, _, _)   => Json.fromFields(fields.map(f => f.name -> coerceInput(f.value)))
      case _                               => throw new IllegalArgumentException(s"Unsupported Json value: $value")
    }

  /** Violation reported when a value cannot be coerced into [[JsonType]]. */
  object JsonCoercionViolation extends ValueCoercionViolation("Json value expected")

  /** GraphQL scalar `Json`: accepts and emits arbitrary JSON.
    *
    * Output values are returned as-is. Query literals go through the AST converter, and
    * variable values go through `coerceUserInput`. Any conversion failure is reported as
    * [[JsonCoercionViolation]].
    */
  implicit val JsonType: ScalarType[Json] =
    ScalarType[Json](
      name = "Json",
      coerceOutput = (value, _) => value,
      coerceInput = { value =>
        try Right(coerceInput(value))
        catch { case _: IllegalArgumentException => Left(JsonCoercionViolation) }
      },
      coerceUserInput = { value =>
        try Right(coerceUserInput(value))
        catch { case _: IllegalArgumentException => Left(JsonCoercionViolation) }
      }
    )
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment