Skip to content

Instantly share code, notes, and snippets.

@chnlkw
Created September 30, 2022 15:26
Show Gist options
  • Save chnlkw/3627a61df0b5e7e0fc266430ef380f51 to your computer and use it in GitHub Desktop.
Save chnlkw/3627a61df0b5e7e0fc266430ef380f51 to your computer and use it in GitHub Desktop.
Simple graph execution plan, with an evaluator and type inference
object ExamplePlan extends App {

  /** A logical query plan over a property graph.
    *
    * Table operators (`ScanNode`, `GetEdge`, `GetNodeProp`, `Projection`) evaluate to rows of type
    * `Seq[Map[String, Any]]`; expression operators (`GetField`, `Ref`, literals) evaluate to a value.
    */
  sealed trait Plan
  /** Rows binding `nodeVar` to each `label` node whose properties equal every `filter` entry. */
  case class ScanNode(label: String, filter: Map[String, Plan], nodeVar: String) extends Plan
  /** Extends each `prev` row with one row per outgoing `label` edge of the node bound to `srcRef`
    * whose properties equal every `filter` entry; binds the edge properties to `edgeVar` and the
    * destination node (id only) to `dstVar`.
    */
  case class GetEdge(
      label: String,
      filter: Map[String, Plan],
      srcRef: String,
      edgeVar: String,
      dstVar: String,
      prev: Plan
  ) extends Plan
  /** Merges the stored properties of the `label` node bound to `nodeVar` into that column. */
  case class GetNodeProp(label: String, nodeVar: String, prev: Plan) extends Plan
  /** Maps each `prev` row to a new row of named expressions evaluated with the row in scope. */
  case class Projection(columns: Map[String, Plan], prev: Plan) extends Plan
  /** Reads `fieldName` from the record produced by `x`. */
  case class GetField(fieldName: String, x: Plan) extends Plan
  /** Looks up variable `x` in the current context. */
  case class Ref(x: String) extends Plan
  /** A string literal. */
  case class VString(value: String) extends Plan

  /** In-memory property graph, keyed by label. */
  trait Graph {
    type NodeKey = Int
    type NodeValue = Map[String, Any]
    type EdgeValue = Map[String, Any]

    /** node label -> (node key -> node properties) */
    val v: Map[String, Map[NodeKey, NodeValue]]
    /** edge label -> (source key -> outgoing (destination key, edge properties)) */
    val e: Map[String, Map[NodeKey, List[(NodeKey, EdgeValue)]]]

    /** All nodes with `label`; throws `NoSuchElementException` for an unknown label. */
    def nodes(label: String): Seq[(NodeKey, NodeValue)] = v(label).toSeq

    /** Node `key` under `label`, if present; throws for an unknown label. */
    def node(label: String, key: NodeKey): Option[NodeValue] = v(label).get(key)

    /** All `label` edges as (source, properties, destination) triples. */
    def edges(label: String): Seq[(NodeKey, EdgeValue, NodeKey)] =
      for {
        (src, outgoing) <- e(label).toSeq
        (dst, props)    <- outgoing
      } yield (src, props, dst)

    /** Outgoing `label` edges of `src`; empty when `src` has none (e.g. a leaf node).
      * Throws `NoSuchElementException` for an unknown label.
      */
    def neighbours(label: String, src: NodeKey): Seq[(NodeKey, EdgeValue)] =
      e(label).getOrElse(src, Nil)
  }

  val myGraph: Graph = new Graph {
    val v: Map[String, Map[NodeKey, NodeValue]] = Map(
      "Person" -> Map(
        1 -> Map("name" -> "Alice"),
        2 -> Map("name" -> "Bob"),
        3 -> Map("name" -> "Charlie"),
        4 -> Map("name" -> "David")
      )
    )
    // Only Alice (1) and Bob (2) have outgoing Friend edges; Charlie and David are leaves.
    val e: Map[String, Map[NodeKey, List[(NodeKey, EdgeValue)]]] = Map(
      "Friend" -> Map(
        1 -> List(2 -> Map()),
        2 -> List(3 -> Map(), 4 -> Map())
      )
    )
  }

  /** Evaluates `p` and views the result as a table of rows (`p` must be a table operator). */
  private def evalTable(g: Graph, p: Plan, context: Map[String, Any]): Seq[Map[String, Any]] =
    evaluator(g, p, context).asInstanceOf[Seq[Map[String, Any]]]

  /** Evaluates every filter expression, keeping the property names. */
  private def evalProps(g: Graph, filter: Map[String, Plan], context: Map[String, Any]): Map[String, Any] =
    filter.map { case (name, expr) => name -> evaluator(g, expr, context) }

  /** True when `props` contains every expected property with an equal (`==`) value. */
  private def matchesProps(props: Map[String, Any], expected: Map[String, Any]): Boolean =
    expected.forall { case (name, value) => props.get(name).contains(value) }

  /** Evaluates plan `p` against graph `g`.
    *
    * @param g       graph to read nodes and edges from
    * @param p       plan to evaluate
    * @param context variable bindings visible to `Ref`
    * @return for table operators a `Seq[Map[String, Any]]` of rows, otherwise a single value.
    *         Node columns are property maps that always contain `"id"`.
    * @throws UnsupportedOperationException for plan nodes that are not implemented yet
    */
  def evaluator(g: Graph, p: Plan, context: Map[String, Any]): Any = p match {
    case ScanNode(label, filter, nodeVar) =>
      // Filter values do not depend on the scanned node, so evaluate them once.
      val expected = evalProps(g, filter, context)
      g.nodes(label)
        .filter { case (_, nodeValue) => matchesProps(nodeValue, expected) }
        .map { case (nodeKey, nodeValue) => Map(nodeVar -> (nodeValue + ("id" -> nodeKey))) }

    case GetEdge(label, filter, srcRef, edgeVar, dstVar, prev) =>
      evalTable(g, prev, context).flatMap { row =>
        val srcKey = row(srcRef).asInstanceOf[Map[String, Any]]("id").asInstanceOf[g.NodeKey]
        // Edge filters see the current row's bindings, as Projection columns do.
        val expected = evalProps(g, filter, context ++ row)
        g.neighbours(label, srcKey).collect {
          case (dstKey, edgeValue) if matchesProps(edgeValue, expected) =>
            row + (dstVar -> Map("id" -> dstKey)) + (edgeVar -> edgeValue)
        }
      }

    case GetNodeProp(label, nodeVar, prev) =>
      evalTable(g, prev, context).map { row =>
        val nodeWithId = row(nodeVar).asInstanceOf[Map[String, Any]]
        val nodeKey = nodeWithId("id").asInstanceOf[g.NodeKey]
        val nodeValue = g.node(label, nodeKey).getOrElse(
          throw new NoSuchElementException(s"node $nodeKey not found under label $label")
        )
        row + (nodeVar -> (nodeWithId ++ nodeValue))
      }

    case Projection(columns, prev) =>
      evalTable(g, prev, context).map { row =>
        val rowContext = context ++ row
        columns.map { case (colName, expr) => colName -> evaluator(g, expr, rowContext) }
      }

    case GetField(fieldName, x) => evaluator(g, x, context).asInstanceOf[Map[String, Any]](fieldName)
    case Ref(x)                 => context(x)
    case VString(value)         => value
    case VInt(n)                => n
    case VBool(value)           => value

    // Listed explicitly (not `case _`) so adding a new Plan subtype still triggers the
    // exhaustiveness warning here.
    case other @ (_: LetRec | _: LetRec2 | _: Lambda | _: Apply | _: If | _: PrimOp | _: EmptyTable |
        _: ConcatTable) =>
      throw new UnsupportedOperationException(s"evaluator does not support $other yet")
  }

  /** Reference Cypher query for the hand-built plan `p1`..`p5` below. */
  val getTwoHopNeighbourCypher: String =
    "MATCH (a:Person {name:'Alice'})-[e:Friend*2]->(b:Person) RETURN b.name"
  val p1 = ScanNode(label = "Person", filter = Map("name" -> VString("Alice")), nodeVar = "a")
  val p2 = GetEdge(label = "Friend", filter = Map(), srcRef = "a", edgeVar = "e", dstVar = "b1", prev = p1)
  val p3 = GetEdge(label = "Friend", filter = Map(), srcRef = "b1", edgeVar = "e2", dstVar = "b", prev = p2)
  val p4 = GetNodeProp(label = "Person", nodeVar = "b", prev = p3)
  val p5 = Projection(columns = Map("twoHopFriendName" -> GetField("name", Ref("b"))), prev = p4)
  println(evaluator(myGraph, p1, Map()))
  println(evaluator(myGraph, p2, Map()))
  println(evaluator(myGraph, p3, Map()))
  println(evaluator(myGraph, p4, Map()))
  println(evaluator(myGraph, p5, Map()))

  /** Types inferred for plans. */
  sealed trait Ty
  case class Record(fields: Map[String, Ty]) extends Ty
  case object TString extends Ty
  case object TInt extends Ty
  case object TBool extends Ty
  case class Table(row: Ty) extends Ty

  trait GraphSchema {
    /** node label -> (key type, property name -> property type) */
    val nodeTypes: Map[String, (Ty, Map[String, Ty])]
    /** edge label -> (source node label, destination node label, property name -> property type) */
    val edgeTypes: Map[String, (String, String, Map[String, Ty])]
  }

  val graphSchema: GraphSchema = new GraphSchema {
    override val nodeTypes: Map[String, (Ty, Map[String, Ty])] = Map("Person" -> (TInt, Map("name" -> TString)))
    override val edgeTypes: Map[String, (String, String, Map[String, Ty])] = Map("Friend" -> ("Person", "Person", Map()))
  }

  /** Column types of a table-typed plan; fails with a clear message for any other type. */
  private def tableColumns(ty: Ty, source: Plan): Map[String, Ty] = ty match {
    case Table(Record(cols)) => cols
    case other               => throw new Exception(s"expected a table of records from $source, got $other")
  }

  /** Checks that each filter property is declared in `propTys` and that its expression has
    * exactly that type (the evaluator compares with `==`, so a mismatch would never match).
    */
  private def checkFilter(
      filter: Map[String, Plan],
      propTys: Map[String, Ty],
      graphSchema: GraphSchema,
      ctx: Map[String, Ty],
      owner: String
  ): Unit =
    filter.foreach { case (prop, expr) =>
      val expected = propTys.getOrElse(prop, throw new Exception(s"property $prop not defined on $owner"))
      val actual = inferType(expr, graphSchema, ctx)
      if (actual != expected)
        throw new Exception(s"property $prop on $owner has type $expected, but the filter has type $actual")
    }

  /** Infers the type of `plan`, mirroring the shapes that `evaluator` produces.
    *
    * @param plan        plan to type
    * @param graphSchema node and edge property types
    * @param ctx         types of variables visible to `Ref`
    * @return `Table(Record(...))` for table operators, otherwise the expression's type
    * @throws Exception on unbound variables, missing fields/properties or type mismatches
    * @throws UnsupportedOperationException for plan nodes that are not implemented yet
    */
  def inferType(plan: Plan, graphSchema: GraphSchema, ctx: Map[String, Ty]): Ty = plan match {
    case Ref(x)     => ctx.getOrElse(x, throw new Exception(s"variable not found $x"))
    case VString(_) => TString
    case VInt(_)    => TInt
    case VBool(_)   => TBool

    case ScanNode(label, filter, nodeVar) =>
      val (keyTy, propTys) = graphSchema.nodeTypes(label)
      checkFilter(filter, propTys, graphSchema, ctx, s"node $label")
      Table(Record(Map(nodeVar -> Record(propTys + ("id" -> keyTy)))))

    case Projection(columns, prev) =>
      val rowCtx = ctx ++ tableColumns(inferType(prev, graphSchema, ctx), prev)
      Table(Record(columns.map { case (colName, expr) => colName -> inferType(expr, graphSchema, rowCtx) }))

    case GetField(fieldName, x) =>
      inferType(x, graphSchema, ctx) match {
        case Record(fields) => fields.getOrElse(fieldName, throw new Exception(s"field not found $fieldName"))
        case _              => throw new Exception("not record type")
      }

    case GetEdge(label, filter, srcRef, edgeVar, dstVar, prev) =>
      val prevCols = tableColumns(inferType(prev, graphSchema, ctx), prev)
      val (srcLabel, dstLabel, edgePropTy) = graphSchema.edgeTypes(label)
      val srcKeyTy = graphSchema.nodeTypes(srcLabel)._1
      // The evaluator reads row(srcRef)("id"), so the source column must be a node record
      // whose id has the source label's key type.
      prevCols.get(srcRef) match {
        case Some(Record(fields)) if fields.get("id").contains(srcKeyTy) => ()
        case Some(other) => throw new Exception(s"$srcRef is not a $srcLabel node: $other")
        case None        => throw new Exception(s"variable not found $srcRef")
      }
      checkFilter(filter, edgePropTy, graphSchema, ctx ++ prevCols, s"edge $label")
      val dstKeyTy = graphSchema.nodeTypes(dstLabel)._1
      Table(Record(prevCols + (edgeVar -> Record(edgePropTy)) + (dstVar -> Record(Map("id" -> dstKeyTy)))))

    case GetNodeProp(label, nodeVar, prev) =>
      val prevCols = tableColumns(inferType(prev, graphSchema, ctx), prev)
      val nodePropTy = graphSchema.nodeTypes(label)._2
      val nodeVarType = prevCols.get(nodeVar) match {
        case Some(Record(fields)) => Record(fields ++ nodePropTy)
        case Some(other)          => throw new Exception(s"$nodeVar is not a node: $other")
        case None                 => throw new Exception(s"variable not found $nodeVar")
      }
      Table(Record(prevCols + (nodeVar -> nodeVarType)))

    // Listed explicitly (not `case _`) so new Plan subtypes still trigger exhaustiveness warnings.
    case other @ (_: LetRec | _: LetRec2 | _: Lambda | _: Apply | _: If | _: PrimOp | _: EmptyTable |
        _: ConcatTable) =>
      throw new UnsupportedOperationException(s"inferType does not support $other yet")
  }

  val t1 = inferType(p1, graphSchema, Map())
  val t2 = inferType(p2, graphSchema, Map())
  val t3 = inferType(p3, graphSchema, Map())
  val t4 = inferType(p4, graphSchema, Map())
  val t5 = inferType(p5, graphSchema, Map())
  println(t1)
  println(t2)
  println(t3)
  println(t4)
  println(t5)
  // Expected:
  // t1 == Table(Record(Map("a" -> Record(Map("name" -> TString, "id" -> TInt)))))
  // t5 == Table(Record(Map("twoHopFriendName" -> TString)))

  // --- Experimental plan nodes for recursion; `evaluator` and `inferType` reject them for now. ---
  // NOTE(review): the semantics below are inferred from names and the example, not implemented.
  /** Presumably `let rec v = exp in next`. */
  case class LetRec(v: String, exp: Plan, next: Plan) extends Plan
  /** Presumably a one-argument function `x => v`. */
  case class Lambda(x: String, v: Plan) extends Plan
  /** Presumably application of function `f` to argument `v`. */
  case class Apply(f: Plan, v: Plan) extends Plan
  /** An integer literal. */
  case class VInt(n: Int) extends Plan
  /** A boolean literal. */
  case class VBool(value: Boolean) extends Plan
  case class If(cond: Plan, trueBody: Plan, falseBody: Plan) extends Plan
  /** Primitive operator `op` (built by the extension below as "<", ">", "+") applied to `args`. */
  case class PrimOp(op: String, args: List[Plan]) extends Plan
  case class EmptyTable() extends Plan
  case class ConcatTable(l: Plan, r: Plan) extends Plan

  /** Operator sugar for building plans: `a < b`, `a > b`, `a + b` build `PrimOp`s; `f.apply(x)` builds `Apply`. */
  extension (p: Plan) {
    def <(r: Plan): Plan = PrimOp("<", List(p, r))
    def >(r: Plan): Plan = PrimOp(">", List(p, r))
    def +(r: Plan): Plan = PrimOp("+", List(p, r))
    def apply(r: Plan): Plan = Apply(p, r)
  }

  // Sketch: `func(input)(hop)` concatenates `input` with a recursive call on the rows two Friend
  // hops away, until `hop` reaches 10.
  // NOTE(review): Friend edges in `myGraph` and `graphSchema` carry no `gender` property, so these
  // edge filters would match nothing — confirm whether `gender` was meant as a node property.
  val recursivePlan = LetRec(
    "func",
    Lambda("input", Lambda("hop",
      If(Ref("hop") < VInt(10),
        { // then concat(input, input.getEdge.getEdge)
          val a = Ref("input")
          val b = GetEdge(label = "Friend", filter = Map("gender" -> VBool(true)), srcRef = "a", edgeVar = "e", dstVar = "b", prev = a)
          val c = GetEdge(label = "Friend", filter = Map("gender" -> VBool(false)), srcRef = "b", edgeVar = "f", dstVar = "c", prev = b)
          val next = Projection(columns = Map("a" -> Ref("c")), prev = c)
          ConcatTable(
            Ref("input"),
            Ref("func").apply(next).apply(Ref("hop") + VInt(1))
          )
        }
        ,
        { // else
          EmptyTable()
        }
      )
    )),
    Ref("func").apply(ScanNode("Person", Map("name" -> VString("Alice")), "a")).apply(VInt(0))
  )
  println(recursivePlan)

  /** Like `LetRec`, but carrying an explicit type annotation `ty` for the bound variable. */
  case class LetRec2(v: String, ty: Ty, exp: Plan, next: Plan) extends Plan

  /** Sketch of a recursion-scheme encoding of plans (not used elsewhere yet). */
  object RecursiveScheme {
    /** Fixed point of a type constructor: ties the recursive knot of `F`. */
    case class Fix[F[_]](unfix: F[Fix[F]])
    /** Base functor of plans: every recursive position has type `A`. */
    sealed trait PlanF[A]
    case class ScanNode[A](label: String, filter: Map[String, A], nodeVar: String) extends PlanF[A]
    // `prev` is a recursive position like `columns`, so it must be `A`, not the fixed point `Plan`.
    case class Projection[A](columns: Map[String, A], prev: A) extends PlanF[A]
    type Plan = Fix[PlanF]
    /** One plan layer annotated with its inferred type. */
    case class TyPlanF[A](ty: Ty, plan: PlanF[A])
    type TyPlan = Fix[TyPlanF]
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment