@blakewrege
Last active April 27, 2016 22:12
Parallel Prims Algorithm for Spark
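The data file below, NodeData.txt (the path the code loads), describes the graph in CSV form: the header row 50,70 gives the vertex and edge counts, and every following row is src,dst,weight, so 1,8,10 is an edge between vertices 1 and 8 with weight 10.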
50,70
1,8,10
2,17,10
2,34,17
3,35,16
3,38,1
4,5,6
4,22,8
4,25,3
5,12,13
5,19,2
5,21,20
5,50,8
6,31,2
6,32,8
6,38,16
7,20,1
7,44,11
8,20,3
8,32,2
8,41,6
9,23,9
9,35,10
10,23,13
10,27,12
11,24,5
11,26,14
11,30,14
11,37,15
11,38,14
13,34,5
13,41,18
14,23,15
14,26,7
14,44,9
15,27,19
15,33,20
16,34,11
16,46,13
17,39,16
17,47,9
18,29,19
18,32,14
20,43,10
21,34,1
22,40,5
23,28,20
23,32,13
23,46,6
23,48,10
24,36,9
25,30,16
25,32,17
25,36,8
25,42,2
26,36,16
29,49,15
32,34,4
32,35,11
32,47,10
33,35,17
33,45,18
34,49,16
35,36,3
37,46,6
38,45,15
38,49,10
39,50,5
40,45,19
42,43,15
44,50,16
import org.apache.log4j.Level
import org.apache.log4j.Logger
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.graphx.Edge
import org.apache.spark.graphx.EdgeTriplet
import org.apache.spark.graphx.Graph
import org.apache.spark.graphx.Graph.graphToGraphOps
import org.apache.spark.graphx.VertexId
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.RDD.rddToPairRDDFunctions
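// Parallel Prim's on GraphX: grow a vertex set Vt from a random start
// vertex, and on each pass add the lightest edge that crosses the cut
// between Vt and the rest of the graph, until Vt spans every vertex.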
object ParallelPrims {
  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("Parallel Prims")
    val sc = new SparkContext(conf)
    val logFile = "/root/cluster-computing/demos/graph-generator/NodeData.txt"
    val logData = sc.textFile(logFile, 2).cache()
    // Split off the header row
    val headerAndRows = logData.map(line => line.split(",").map(_.trim))
    val header = headerAndRows.first
    val data = headerAndRows.filter(_(0) != header(0))
    // Parse the number of nodes and edges from the header
    val numNodes = header(0).toInt
    val numEdges = header(1).toInt
    val vertexArray = new Array[(Long, String)](numNodes)
    val edgeArray = new Array[Edge[Int]](numEdges)
    // Build the vertex array: vertex i is labeled "v<i>"
    for (count <- 0 to numNodes - 1) {
      vertexArray(count) = (count.toLong + 1, "v" + (count + 1))
    }
    // Materialize the edge rows on the driver
    val rrdarr = data.take(data.count.toInt)
    // Build the edge array; each row is (src, dst, weight)
    for (count <- 0 to numEdges - 1) {
      val cols = rrdarr(count)
      edgeArray(count) = Edge(cols(0).toLong, cols(1).toLong, cols(2).toInt)
    }
    // Build the GraphX graph
    val vertexRDD: RDD[(Long, String)] = sc.parallelize(vertexArray)
    val edgeRDD: RDD[Edge[Int]] = sc.parallelize(edgeArray)
    val graph: Graph[String, Int] = Graph(vertexRDD, edgeRDD)
    // graph.triplets.take(6).foreach(println)
    // Empty RDD that will accumulate the MST edges
    var MST = sc.parallelize(Array[EdgeTriplet[String, Int]]())
    // Pick a random starting vertex for Vt, the set of spanned vertices
    var Vt: RDD[VertexId] = sc.parallelize(Array(graph.pickRandomVertex))
    // Loop until every vertex is in Vt
    val vcount = graph.vertices.count
    while (Vt.count < vcount) {
      // Keyed copy of Vt for the joins below
      val hVt = Vt.map(x => (x, x))
      // Key every triplet by its source vertex
      val bySrc = graph.triplets.map(triplet => (triplet.srcId, triplet))
      // Key every triplet by its destination vertex
      val byDst = graph.triplets.map(triplet => (triplet.dstId, triplet))
      // All triplets whose source vertex is in Vt
      val bySrcJoined = bySrc.join(hVt).map(_._2._1)
      // All triplets whose destination vertex is in Vt
      val byDstJoined = byDst.join(hVt).map(_._2._1)
      // Union the two RDDs and subtract the triplets with both endpoints
      // in Vt, leaving only the edges that cross the cut
      val candidates = bySrcJoined.union(byDstJoined).subtract(byDstJoined.intersection(bySrcJoined))
      // Take the crossing edge with the least weight
      val triplet = candidates.sortBy(triplet => triplet.attr).first
      // Add it to the MST
      MST = MST.union(sc.parallelize(Array(triplet)))
      // Add whichever endpoint is not yet in Vt
      if (!Vt.filter(x => x == triplet.srcId).isEmpty) {
        Vt = Vt.union(sc.parallelize(Array(triplet.dstId)))
      } else {
        Vt = Vt.union(sc.parallelize(Array(triplet.srcId)))
      }
    }
    // Print the final minimum spanning tree
    MST.collect.foreach { p =>
      println(p.srcId + "<--->" + p.dstId + " " + p.attr)
    }
    // Total weight of the MST
    val total = MST.map(_.attr.toDouble).collect
    println(total.reduceLeft(_ + _))
  }
}
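The candidates RDD inside the loop is the cut of the graph: the triplets with exactly one endpoint in Vt. As a minimal sketch of that set algebra on plain Scala collections (toy data and names of my own, not part of the gist):

object CutExample extends App {
  // Edges are (src, dst, weight) tuples; Vt holds the spanned vertices
  val Vt = Set(1L, 5L)
  val edges = Set((1L, 8L, 10), (5L, 19L, 2), (1L, 5L, 7), (8L, 20L, 3))
  val bySrc = edges.filter { case (s, _, _) => Vt(s) }
  val byDst = edges.filter { case (_, d, _) => Vt(d) }
  // union minus intersection = edges with exactly one endpoint in Vt
  val cut = (bySrc union byDst) diff (bySrc intersect byDst)
  // (1,5,7) drops out because both of its endpoints are already spanned;
  // the lightest remaining crossing edge becomes the next MST edge
  println(cut.minBy(_._3)) // prints (5,19,2)
}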
@blakewrege (Author) commented:
Currently having issues with the cluster doing this:
[Screenshot: Spark UI details page for "parallel prims" job 277]
