@ankurdave
Last active February 20, 2017 01:09
Spark code to find distances to reachable source vertices using GraphX
// Spark code to find distances to reachable source vertices using GraphX.
// See http://apache-spark-user-list.1001560.n3.nabble.com/counting-degrees-graphx-td6370.html
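import org.apache.spark.graphx._
import scala.collection.immutable.Map

// A small sample graph with eight vertices and six edges. The vertex and edge attributes
// are not used by the algorithm below. `sc` is assumed to be the SparkContext that
// spark-shell predefines.
val vertexArray = Array(
  (1L, ("101", "x")),
  (2L, ("102", "y")),
  (3L, ("103", "y")),
  (4L, ("104", "y")),
  (5L, ("105", "y")),
  (6L, ("106", "x")),
  (7L, ("107", "x")),
  (8L, ("108", "y")))
val edgeArray = Array(
  Edge(1L, 2L, 1),
  Edge(1L, 3L, 2),
  Edge(3L, 4L, 3),
  Edge(3L, 5L, 4),
  Edge(6L, 5L, 5),
  Edge(7L, 8L, 6))
val g = Graph(sc.parallelize(vertexArray), sc.parallelize(edgeArray))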

type DistanceMap = Map[VertexId, Int]

// For each vertex, initialize a distance map recording the distance from each reachable
// source. Sources (vertices with zero in-degree) are initialized to Map(selfVid -> 0), and all
// other vertices are initialized to an empty map. g.inDegrees has no entry for vertices
// without in-edges, so inDeg is None for sources and getOrElse(0) treats that as zero.
val initDists: Graph[DistanceMap, Int] = g.outerJoinVertices(g.inDegrees) {
  (vid, prevAttr, inDeg) => if (inDeg.getOrElse(0) == 0) Map(vid -> 0) else Map.empty[VertexId, Int]
}

// Propagate distances along each edge from its source vertex a (edge.srcId) to its
// destination vertex b (edge.dstId).
def sendMsg(edge: EdgeTriplet[DistanceMap, Int]): Iterator[(VertexId, DistanceMap)] = {
  edge.dstAttr // ensure that GraphX replicates both attributes to work around SPARK-3936
  // Keep only the sources for which a offers b a shorter path than b currently knows,
  // then add one hop to each of those distances. A strict map is used instead of
  // mapValues, which in Scala 2.10-2.12 returns a lazy view that is not serializable.
  val updatedDists = edge.srcAttr.filter {
    case (source, dist) =>
      val existingDist = edge.dstAttr.getOrElse(source, Int.MaxValue)
      existingDist > dist + 1
  }.map { case (source, dist) => (source, dist + 1) }
  // Send the updated distances to vertex b, if there are any.
  if (updatedDists.nonEmpty) {
    Iterator((edge.dstId, updatedDists))
  } else {
    Iterator.empty
  }
}

// Merge distance maps by taking the shorter distance for each source.
def mergeMsg(a: DistanceMap, b: DistanceMap): DistanceMap = {
  (a.keySet ++ b.keySet).map { source =>
    source -> math.min(a.getOrElse(source, Int.MaxValue), b.getOrElse(source, Int.MaxValue))
  }.toMap
}

// Apply distance maps by overwriting existing values with new values. Overwriting is safe
// because sendMsg only sends distances that are shorter than the receiving vertex's current ones.
def vprog(vid: VertexId, curDists: DistanceMap, newDists: DistanceMap): DistanceMap = {
  curDists ++ newDists
}
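
// For each vertex, find the distance to all reachable sources by running the above steps
// iteratively until no more messages are sent. The initial message is an empty map, so the
// first vprog call leaves each vertex's initial distances unchanged.
val dists = initDists.pregel[DistanceMap](Map())(vprog, sendMsg, mergeMsg)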

// Print the result for each vertex, sorted by vertex ID. The output follows.
dists.vertices.collect.sortBy(_._1).foreach(println(_))
// (1,Map(1 -> 0))
// (2,Map(1 -> 1))
// (3,Map(1 -> 1))
// (4,Map(1 -> 2))
// (5,Map(6 -> 1, 1 -> 2))
// (6,Map(6 -> 0))
// (7,Map(7 -> 0))
// (8,Map(7 -> 1))
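To check the output above without a Spark cluster, the same vertex program can be replayed on one machine. The sketch below is a stand-in for `initDists.pregel`, not GraphX itself: the names `LocalDistances`, `sampleEdges` and `run` are hypothetical, edges are plain `(srcId, dstId)` pairs because the algorithm never reads edge attributes, and every edge is re-evaluated in each superstep for simplicity. That shortcut should not change the result: `sendMsg` only sends strictly shorter distances, so an edge whose endpoints did not change in the previous superstep sends nothing.

```scala
// Hypothetical single-machine stand-in for the GraphX pregel call, using only the
// Scala standard library. Edge attributes are omitted because the algorithm never reads them.
object LocalDistances {
  type VertexId = Long
  type DistanceMap = Map[VertexId, Int]

  // The gist's sample graph as (srcId, dstId) pairs.
  val sampleEdges: Seq[(VertexId, VertexId)] =
    Seq(1L -> 2L, 1L -> 3L, 3L -> 4L, 3L -> 5L, 6L -> 5L, 7L -> 8L)

  // Same rule as sendMsg: for each source where dst would get a shorter distance
  // through src, offer dst that distance plus one hop.
  def sendMsg(srcAttr: DistanceMap, dstAttr: DistanceMap): DistanceMap =
    srcAttr.collect {
      case (source, dist) if dstAttr.getOrElse(source, Int.MaxValue) > dist + 1 =>
        source -> (dist + 1)
    }

  // Same rule as mergeMsg: keep the shorter distance for each source.
  def mergeMsg(a: DistanceMap, b: DistanceMap): DistanceMap =
    (a.keySet ++ b.keySet).map { source =>
      source -> math.min(a.getOrElse(source, Int.MaxValue), b.getOrElse(source, Int.MaxValue))
    }.toMap

  // Runs supersteps until no messages are sent.
  def run(vertices: Seq[VertexId], edges: Seq[(VertexId, VertexId)]): Map[VertexId, DistanceMap] = {
    val hasInEdge = edges.map(_._2).toSet
    // Vertices with no in-edges are sources at distance 0 from themselves.
    var attrs: Map[VertexId, DistanceMap] = vertices.map { vid =>
      vid -> (if (hasInEdge(vid)) Map.empty[VertexId, Int] else Map(vid -> 0))
    }.toMap
    var active = true
    while (active) {
      // Evaluate sendMsg on every edge and merge messages bound for the same vertex.
      val messages = edges.foldLeft(Map.empty[VertexId, DistanceMap]) {
        case (acc, (src, dst)) =>
          val msg = sendMsg(attrs(src), attrs(dst))
          if (msg.isEmpty) acc
          else acc.updated(dst, acc.get(dst).fold(msg)(mergeMsg(_, msg)))
      }
      // Same rule as vprog: incoming distances overwrite the current ones.
      attrs = attrs.map { case (vid, cur) =>
        vid -> (cur ++ messages.getOrElse(vid, Map.empty[VertexId, Int]))
      }
      active = messages.nonEmpty
    }
    attrs
  }
}
```

Calling `LocalDistances.run(1L to 8L, LocalDistances.sampleEdges)` should reproduce the maps printed above; for example, vertex 5 ends with `Map(6 -> 1, 1 -> 2)` after three supersteps, the last of which sends no messages.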
@Chris19920210

Hi Ankurdave,
This is Rihan. Sorry to bother you. I'm not very familiar with the GraphX API. Your code is very useful for my project, but I want to print out the path from all the nodes to all the roots. Could you please give me a hint? Thank you anyway!
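One possible approach, sketched under the assumption that one shortest path per source is enough: store the path itself instead of its length. A vertex's attribute becomes a map from source ID to a path (source first, the vertex itself last), and the three Pregel functions compare path lengths where the gist compares distances. The names `LocalPaths`, `PathMap`, `init`, `distances` and `format` are hypothetical, and the functions below are plain Scala tested outside GraphX. In a GraphX version, `sendMsg` would take `edge.srcAttr`, `edge.dstId` and `edge.dstAttr` from the triplet, and a source `vid` would start with `Map(vid -> List(vid))`.

```scala
// Hypothetical path-tracking variant of the gist's sendMsg/mergeMsg/vprog. Each vertex keeps,
// for every reachable source, one shortest path from that source (source first, itself last).
object LocalPaths {
  type VertexId = Long
  type PathMap = Map[VertexId, List[VertexId]]

  // Initial attribute: a source's only path is itself; other vertices start empty.
  def init(vid: VertexId, isSource: Boolean): PathMap =
    if (isSource) Map(vid -> List(vid)) else Map.empty

  // Extend each of src's paths by dstId, keeping only those shorter than dst's current path.
  def sendMsg(srcAttr: PathMap, dstId: VertexId, dstAttr: PathMap): PathMap =
    srcAttr.collect {
      case (source, path) if dstAttr.get(source).forall(_.length > path.length + 1) =>
        source -> (path :+ dstId)
    }

  // For each source, keep the shorter path; on a tie, keep the path from `a`.
  def mergeMsg(a: PathMap, b: PathMap): PathMap =
    b.foldLeft(a) { case (acc, (source, path)) =>
      if (acc.get(source).forall(_.length > path.length)) acc.updated(source, path) else acc
    }

  // Incoming paths are strictly shorter, so they overwrite the current ones.
  def vprog(cur: PathMap, incoming: PathMap): PathMap = cur ++ incoming

  // The distance to a source is the number of hops, i.e. path length minus one.
  def distances(paths: PathMap): Map[VertexId, Int] =
    paths.map { case (source, path) => source -> (path.length - 1) }

  // Renders each path like "1 -> 3 -> 5", sorted for stable printing.
  def format(paths: PathMap): List[String] =
    paths.values.map(_.mkString(" -> ")).toList.sorted
}
```

With these functions in place of the gist's, vertex 5 in the sample graph should end with the paths `1 -> 3 -> 5` and `6 -> 5`, consistent with its distances `Map(6 -> 1, 1 -> 2)`. Keeping one path per source keeps messages small; listing every shortest path would require storing a set of paths per source instead.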
