Last active
August 29, 2015 14:11
-
-
Save cloud-fan/b835d2becf93086c6a7f to your computer and use it in GitHub Desktop.
Scala Dijkstra
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.collection.mutable | |
import scala.collection.immutable.HashMap | |
/**
 * Weighted directed graph keyed by vertex label.
 *
 * @param data edges as `(from, to, weight)` triples. Every label that appears as a source
 *             OR a target becomes a vertex; labels that only appear as targets become
 *             vertices without outgoing edges.
 */
class Graph(data: Seq[(String, String, Double)]) {

  /** A vertex and its outgoing edges, kept in the order they appear in `data`. */
  case class Vertex(label: String, connection: Seq[Edge])

  /** An outgoing edge; the target vertex is resolved on first access. */
  class Edge(targetLabel: String, val weight: Double) {
    // Lazy because edges are created while `adjacencyList` itself is still being built.
    lazy val target: Vertex = adjacencyList(targetLabel)
  }

  private val adjacencyList: HashMap[String, Vertex] = {
    val outgoing = data.groupBy(_._1)
    // Sinks (labels seen only as targets) must be vertices too, otherwise resolving
    // `Edge.target` for an edge into one would throw.
    val labels = (data.map(_._1) ++ data.map(_._2)).distinct
    val vertices = labels.map { label =>
      val edges = outgoing.getOrElse(label, Nil).map { case (_, to, weight) =>
        new Edge(to, weight)
      }
      label -> Vertex(label, edges)
    }
    HashMap(vertices: _*)
  }

  /** Returns the vertex labelled `label`; throws `NoSuchElementException` if there is none. */
  def getVertex(label: String): Vertex = adjacencyList(label)

  /** Number of distinct vertex labels (sources and targets). */
  val vertexesCount: Int = adjacencyList.size
}
/** Factory methods that normalize an edge list before building a [[Graph]]. */
object Graph {

  /**
   * Builds a directed graph. Self-loops are dropped, and for repeated `(from, to)` pairs
   * only the first edge (and therefore its weight) is kept.
   */
  def apply(edges: Seq[(String, String, Double)]): Graph =
    new Graph(clean(edges, (from, to) => (from, to)))

  /**
   * Builds an undirected graph by adding every edge in both directions. Self-loops are
   * dropped and, because `(a, b)` and `(b, a)` name the same undirected edge, only the
   * first edge listed between any two vertices is kept.
   */
  def undirected(edges: Seq[(String, String, Double)]): Graph =
    new Graph(
      clean(edges, (from, to) => if (from.compareTo(to) <= 0) (from, to) else (to, from))
        .flatMap { case (from, to, weight) => Seq((from, to, weight), (to, from, weight)) }
    )

  /**
   * Returns `edges` in their original order, without self-loops and without edges whose
   * `key(from, to)` was already produced by an earlier edge. Evaluated eagerly in a single
   * pass so the stateful filter runs exactly once per edge.
   */
  private def clean(
      edges: Seq[(String, String, Double)],
      key: (String, String) => (String, String)
  ): Vector[(String, String, Double)] = {
    val accepted = mutable.HashSet.empty[(String, String)]
    // `add` returns false when the key is already present, i.e. a duplicate edge.
    edges.iterator.filter { case (from, to, _) =>
      from != to && accepted.add(key(from, to))
    }.toVector
  }
}
/**
 * Single-source shortest-path demos over [[Graph]]: Dijkstra (non-negative weights only)
 * and SPFA (negative weights allowed, negative cycles detected).
 *
 * Created by cloud0fan on 12/15/14.
 */
object ShortestPath extends App {

  /**
   * Reconstructs the path ending at `endNode` by following predecessor links.
   *
   * @param map     maps a vertex label to the label of its predecessor; the path starts at
   *                the first label that has no entry in `map`
   * @param endNode label of the last vertex on the path
   * @return the labels from the path start to `endNode`, inclusive
   */
  def buildPath(map: mutable.HashMap[String, String], endNode: String): List[String] = {
    val path = mutable.ListBuffer(endNode)
    // Prepend predecessors until reaching a vertex without one. Would not terminate if
    // `map` held a cycle reachable from `endNode`; callers only pass acyclic maps.
    while (map.contains(path.head)) {
      map(path.head) +=: path
    }
    path.toList
  }

  /**
   * Turns per-vertex distances into `(distance, path)` rows ordered by distance, then by
   * destination label. Builds a list instead of mapping a `Map`, so vertices that share
   * the same distance are not collapsed into one entry.
   */
  private def collectPaths(
      distances: Iterator[(String, Double)],
      previous: mutable.HashMap[String, String]
  ): List[(Double, List[String])] =
    distances
      .map { case (label, length) => length -> buildPath(previous, label) }
      .toList
      .sortWith { (a, b) =>
        a._1 < b._1 || (a._1 == b._1 && a._2.last.compareTo(b._2.last) < 0)
      }

  /**
   * Dijkstra's algorithm with a binary heap and lazy deletion of stale heap entries.
   *
   * @param graph       graph to search; edge weights reachable from the source must be >= 0
   * @param sourceLabel label of the start vertex (must exist in `graph`)
   * @return one `(distance, path)` row per vertex reachable from the source, ordered by
   *         distance then destination label
   * @throws IllegalArgumentException if a negative (or NaN) edge weight is encountered,
   *         because Dijkstra's result would silently be wrong
   */
  def dijkstra(graph: Graph, sourceLabel: String): Seq[(Double, List[String])] = {
    import graph._
    type Entry = (Vertex, Option[String], Double)

    // PriorityQueue is a max-heap; the inverted comparison makes it pop the smallest distance.
    val queue = mutable.PriorityQueue.empty(Ordering.fromLessThan[Entry](_._3 > _._3))
    val result = mutable.HashMap.empty[String, Double]
    val mapToPrevious = mutable.HashMap.empty[String, String]

    def isToDo(vertex: Vertex): Boolean = !result.contains(vertex.label)

    // Finalizes `node` at `length` and queues tentative distances to unfinished neighbours.
    def augment(node: Vertex, pre: Option[String], length: Double): Unit = {
      result += node.label -> length
      pre.foreach(p => mapToPrevious += node.label -> p)
      for (edge <- node.connection) {
        require(edge.weight >= 0, s"dijkstra requires non-negative edge weights, got ${edge.weight}")
        if (isToDo(edge.target)) queue.enqueue((edge.target, Some(node.label), length + edge.weight))
      }
    }

    queue.enqueue((getVertex(sourceLabel), None, 0.0))
    while (queue.nonEmpty) {
      val (node, pre, length) = queue.dequeue()
      // Entries for vertices finalized earlier are stale; skip them.
      if (isToDo(node)) augment(node, pre, length)
    }
    collectPaths(result.iterator, mapToPrevious)
  }

  /**
   * SPFA (queue-based Bellman-Ford) with the Small Label First heuristic; supports
   * negative edge weights.
   *
   * @param graph       graph to search
   * @param sourceLabel label of the start vertex (must exist in `graph`)
   * @return one `(distance, path)` row per vertex reachable from the source, ordered by
   *         distance then destination label, or `Nil` when a negative cycle is detected
   *         (a vertex would need to be enqueued again after `vertexesCount` enqueues)
   */
  def spfa(graph: Graph, sourceLabel: String): Seq[(Double, List[String])] = {
    import graph._

    // Per-vertex bookkeeping: queue membership, number of enqueues, best distance so far.
    class Info(var isInQueue: Boolean, var enqueueCount: Int, var length: Double)

    val result = mutable.HashMap.empty[String, Info]
    val queue = mutable.Queue.empty[Vertex]
    val mapToPrevious = mutable.HashMap.empty[String, String]
    var negativeCycleDetected = false

    // Small Label First: a vertex whose distance beats the queue head's jumps to the front.
    def enqueueSLF(node: Vertex, length: Double): Unit =
      if (queue.nonEmpty && length < result(queue.front.label).length) node +=: queue
      else queue += node

    // Records the first distance found for `node` and queues it.
    def addNewResult(node: Vertex, length: Double): Info = {
      val info = new Info(true, 1, length)
      result += node.label -> info
      enqueueSLF(node, length)
      info
    }

    // Stores an improved distance and re-queues the vertex unless it is already queued.
    // Needing to queue it again after `vertexesCount` enqueues signals a negative cycle.
    def improve(node: Vertex, info: Info, newLength: Double): Unit = {
      info.length = newLength
      if (!info.isInQueue) {
        if (info.enqueueCount < vertexesCount) {
          info.isInQueue = true
          info.enqueueCount += 1
          enqueueSLF(node, newLength)
        } else negativeCycleDetected = true
      }
    }

    // Relaxes every outgoing edge of `node`; stops early once a negative cycle is found.
    def relax(node: Vertex, length: Double): Unit =
      for (edge <- node.connection if !negativeCycleDetected) {
        val neighbour = edge.target
        val newLength = length + edge.weight
        result.get(neighbour.label) match {
          case None =>
            mapToPrevious += neighbour.label -> node.label
            addNewResult(neighbour, newLength)
          case Some(info) if newLength < info.length =>
            mapToPrevious += neighbour.label -> node.label
            improve(neighbour, info, newLength)
          case Some(_) =>
        }
      }

    addNewResult(getVertex(sourceLabel), 0)
    while (!negativeCycleDetected && queue.nonEmpty) {
      val node = queue.dequeue()
      val nodeInfo = result(node.label)
      nodeInfo.isInQueue = false
      relax(node, nodeInfo.length)
    }

    if (negativeCycleDetected) Nil
    else collectPaths(result.iterator.map { case (label, info) => label -> info.length }, mapToPrevious)
  }

  val graph = Graph.undirected(Seq(
    ("A", "B", 6.0),
    ("A", "C", 3.0),
    ("B", "C", 2.0),
    ("B", "D", 5.0),
    ("C", "D", 3.0),
    ("C", "E", 4.0),
    ("D", "E", 2.0),
    ("E", "F", 5.0),
    ("D", "F", 3.0)
  ))

  dijkstra(graph, "A").foreach { case (length, path) => println(s"$length: ${path.mkString(" -> ")}") }
  println("--------")
  spfa(graph, "A").foreach { case (length, path) => println(s"$length: ${path.mkString(" -> ")}") }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment