Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
最小全域木。プリム法とクラスカル法。UnionFind。
import scala.annotation.tailrec
object Main {
def main(args: Array[String]) {
println("prim")
prim(sampleGraph).sortBy(_.weight).foreach(println)
println("kruskal")
kruskal(sampleGraph).sortBy(_.weight).foreach(println)
}
// クラスカル法
// 重みの小さいEdgeから決定していく。
// 閉炉になる場合は除外する。UnionFindを使って閉炉になるか判定。
def kruskal(g: Graph): Seq[Edge] = {
val sortedEdges = g.es.values.toSeq.flatten.sortBy(_.weight)
val uf = UnionFind[Vertex](g.vertices)
sortedEdges.collect { case e if !uf.isSame(e.from, e.to) =>
uf.unify(e.from, e.to)
e
}
}
// プリム法
// スタート地点からゴールまで、重みが小さいものを常に選択し続ける。
def prim(g :Graph): Seq[Edge] = {
val pq = scala.collection.mutable.PriorityQueue[Edge]()
pq.enqueue(Edge(Vertex("dummy"), Vertex("A"), 0))
@tailrec
def f(r: Seq[Edge] = Nil, vd: Seq[Vertex] = Nil): Seq[Edge] = {
if (pq.isEmpty) r
else {
val e = pq.dequeue()
if (vd.contains(e.to)) f(r, vd)
else {
g.es.get(e.to).map { _.foreach { ne =>
if (ne.to != e.from) pq.enqueue(ne)
} }
f(e +: r, e.to +: vd)
}
}
}
f().init
}
case class Graph(es: Map[Vertex, Seq[Edge]]) {
val vertices: Seq[Vertex] = es.toSeq.flatMap { case (v, es) => v +: es.map(_.to) }
}
case class Vertex(name: String)
case class Edge(from: Vertex, to: Vertex, weight: Int) extends Ordered[Edge] {
def compare(e: Edge): Int = e.weight - weight
}
type Path = Map[Vertex, Edge]
def sampleGraph: Graph = Graph(
Map(
Vertex("A") -> Seq(
Edge(Vertex("A"), Vertex("B"), 7),
Edge(Vertex("A"), Vertex("D"), 5)
),
Vertex("B") -> Seq(
Edge(Vertex("B"), Vertex("A"), 7),
Edge(Vertex("B"), Vertex("C"), 8),
Edge(Vertex("B"), Vertex("D"), 9),
Edge(Vertex("B"), Vertex("E"), 7)
),
Vertex("C") -> Seq(
Edge(Vertex("C"), Vertex("B"), 8),
Edge(Vertex("C"), Vertex("E"), 5)
),
Vertex("D") -> Seq(
Edge(Vertex("D"), Vertex("A"), 5),
Edge(Vertex("D"), Vertex("B"), 9),
Edge(Vertex("D"), Vertex("E"), 15),
Edge(Vertex("D"), Vertex("F"), 6)
),
Vertex("E") -> Seq(
Edge(Vertex("E"), Vertex("B"), 7),
Edge(Vertex("E"), Vertex("C"), 5),
Edge(Vertex("E"), Vertex("D"), 15),
Edge(Vertex("E"), Vertex("F"), 8),
Edge(Vertex("E"), Vertex("G"), 9)
),
Vertex("F") -> Seq(
Edge(Vertex("F"), Vertex("D"), 6),
Edge(Vertex("F"), Vertex("E"), 8),
Edge(Vertex("F"), Vertex("G"), 11)
),
Vertex("G") -> Seq(
Edge(Vertex("G"), Vertex("E"), 9),
Edge(Vertex("G"), Vertex("F"), 11)
)
)
)
/*
UnionFind
グループを作ることができ、判定することができる。
ex)
根の状態 A:A B:B C:C D:D E:E (Aの根はAを指す)
unify(A,B) -> A:B B:B C:C D:D E:E -> Aの根がBに。
unify(B,C) -> A:B B:C C:C D:D E:E -> Bの根がCに。
isSame(A,C) -> Aの根はBでBの根はC。Cの根はC。ゆえに同グループ。
A>B>Cのように根が深くなると探索数が増えるので、A>Cのように根を付け替える。
usage)
case class Eee(i: Int)
val edges: Vector[Eee] = Vector[Eee](Eee(0), Eee(1), Eee(2), Eee(3), Eee(4))
val uf = UnionFind[Eee](edges)
uf.unify(Eee(1), Eee(2))
uf.unify(Eee(2), Eee(3))
println(uf.isSame(Eee(1), Eee(2)))
println(uf.isSame(Eee(1), Eee(3)))
println(uf.isSame(Eee(2), Eee(3)))
println(uf.isSame(Eee(1), Eee(4)))
*/
import scala.collection.mutable
case class UnionFind[T](n: Seq[T]) {
private[this] val par = n.foldLeft(mutable.Map[T, T]()) { case (m, i) => m + (i -> i) }
@tailrec
private[this] def find(a: T, r: Seq[T] = Nil): T = {
if (par(a) == a) {
r.foreach { par.update(_, a) }
a
} else {
val pa = par(a)
find(pa, pa +: r)
}
}
def unify(a: T, b: T) { if (find(a) != find(b)) par.update(find(a), find(b)) }
def isSame(a: T, b: T): Boolean = find(a) == find(b)
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment