Skip to content

Instantly share code, notes, and snippets.

@mahmoudhanafy
Created August 28, 2017 14:58
Show Gist options
  • Save mahmoudhanafy/ba0b7cea1382df77470ba6a10b8d4b2a to your computer and use it in GitHub Desktop.
Save mahmoudhanafy/ba0b7cea1382df77470ba6a10b8d4b2a to your computer and use it in GitHub Desktop.
package test
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import scala.collection.mutable
/**
 * Union-find (disjoint-set) structure over `Int` node ids.
 *
 * Every known node is a key of `parentMap`; a node that maps to itself is the
 * representative (root) of its set. Nodes are registered lazily the first time
 * they are passed to [[findParent]], [[isSameSet]] or [[union]].
 *
 * The class is `Serializable` and holds only a mutable map of `Int`s.
 */
class DisjointSet() extends Serializable {

  // node id -> parent id; a root is its own parent.
  private val parentMap = mutable.Map[Int, Int]()

  /**
   * Returns the representative of the set containing `id`.
   *
   * An unseen `id` is registered as a new singleton set and returned as its
   * own representative. Every node visited on the way to the root is
   * re-pointed directly at the root (path compression).
   *
   * Implemented as two tail-recursive passes so that arbitrarily long parent
   * chains (e.g. built from the edges (1,2), (2,3), (3,4), ...) cannot
   * overflow the call stack.
   *
   * @param id node to look up
   * @return the root of `id`'s set
   */
  def findParent(id: Int): Int = {
    val root = rootOf(id)
    compressPath(id, root)
    root
  }

  /** Follows parent links from `id` up to the root, registering unseen nodes as roots. */
  @scala.annotation.tailrec
  private def rootOf(id: Int): Int = {
    val parent = parentMap.getOrElseUpdate(id, id)
    if (parent == id) id else rootOf(parent)
  }

  /** Re-points `id` and each of its ancestors below `root` directly at `root`. */
  @scala.annotation.tailrec
  private def compressPath(id: Int, root: Int): Unit =
    if (id != root) {
      // Read the old parent before overwriting it, so the walk can continue.
      val parent = parentMap.getOrElse(id, root)
      parentMap(id) = root
      compressPath(parent, root)
    }

  /**
   * Tells whether `a` and `b` currently belong to the same set.
   * Unseen ids are registered as singletons as a side effect of the lookup.
   */
  def isSameSet(a: Int, b: Int): Boolean =
    findParent(a) == findParent(b)

  /**
   * Merges the set containing `a` into the set containing `b`: the root of
   * `a`'s set becomes a child of the root of `b`'s set. A no-op (beyond
   * registering unseen ids) when both are already in the same set.
   */
  def union(a: Int, b: Int): Unit = {
    val aParent = findParent(a)
    val bParent = findParent(b)
    if (aParent != bParent) parentMap(aParent) = bParent
  }

  /**
   * Lists every set as `root -> members` (the root is included among its own
   * members). The order of members within a set is unspecified.
   *
   * @return a strict, immutable map from each root to the nodes of its set
   */
  def listSets(): Map[Int, Seq[Int]] = {
    // Snapshot the keys first: findParent rewrites parentMap entries (path
    // compression), and mutating a map while iterating over it is fragile.
    val nodes = parentMap.keys.toList
    nodes
      .map(node => node -> findParent(node))
      .groupBy { case (_, root) => root }
      // Strict `map` rather than `mapValues`, whose result is a lazy view.
      .map { case (root, pairs) => root -> pairs.map { case (node, _) => node } }
  }

  /** True when no node has been registered yet. */
  def isEmpty(): Boolean =
    parentMap.isEmpty

  /**
   * Combines two disjoint-set structures into one that contains every node of
   * both, with nodes connected whenever they were connected in either operand.
   *
   * When one operand is empty the other is returned as-is (no copy is made),
   * so the result may be the very same instance as an input. Otherwise a new
   * instance is built and neither operand's set membership changes.
   *
   * @param other the structure to merge with this one
   * @return the merged structure
   */
  def +(other: DisjointSet): DisjointSet = {
    if (isEmpty()) {
      other
    } else if (other.isEmpty()) {
      this
    } else {
      val combinedSet = new DisjointSet

      // Replays every (root, member) pair of `set` as a union in combinedSet.
      def appendToCombinedSet(set: DisjointSet): Unit =
        for {
          (setParent, setElements) <- set.listSets()
          node <- setElements
        } combinedSet.union(setParent, node)

      appendToCombinedSet(this)
      appendToCombinedSet(other)
      combinedSet
    }
  }
}
/**
 * Demo driver: computes the connected groups of a small list of `(Int, Int)`
 * relations with Spark by building one [[DisjointSet]] per partition and
 * merging them with `+`.
 */
object Main {

  /**
   * Runs the demo on a local Spark master and prints one `root -> members`
   * entry per line.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession.builder.master("local").appName("mostafa").getOrCreate()

    try {
      val inputSeq = Seq(
        (1, 2), (2, 3), (3, 4),
        (10, 20), (20, 50), (10, 30), (10, 40),
        (300, 400), (300, 100),
        (1000, 2000)
      )

      // Ten pairs over 100 partitions: most partitions receive no pairs and
      // therefore yield an empty DisjointSet.
      val inputDS = sparkSession.sparkContext.parallelize(inputSeq).repartition(100)

      val disjointSets: RDD[DisjointSet] =
        inputDS.mapPartitions { relations =>
          val disjointSet = new DisjointSet()
          relations.foreach { case (a, b) =>
            if (!disjointSet.isSameSet(a, b)) {
              disjointSet.union(a, b)
            }
          }
          // mapPartitions must return an iterator; emit the single
          // per-partition result.
          Iterator(disjointSet)
        }

      // Merge the per-partition structures into one.
      val finalSet = disjointSets.reduce(_ + _)
      val allSets = finalSet.listSets()
      println(allSets.mkString("\n"))
    } finally {
      // Stop the session even if the job fails; the original never stopped it.
      sparkSession.stop()
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment