Skip to content

Instantly share code, notes, and snippets.

@cloud-fan
Last active August 29, 2015 14:11
Show Gist options
  • Save cloud-fan/b835d2becf93086c6a7f to your computer and use it in GitHub Desktop.
Save cloud-fan/b835d2becf93086c6a7f to your computer and use it in GitHub Desktop.
scala dijkstra
import scala.collection.mutable
import scala.collection.immutable.HashMap
/**
 * Weighted directed graph built from `(source, target, weight)` triples.
 *
 * Every label that occurs in `data` -- as a source or as a target -- becomes a
 * vertex, so vertices without outgoing edges can still be looked up and reached.
 *
 * @param data edges as `(source label, target label, weight)`; used as given,
 *             no de-duplication or validation happens here
 */
class Graph(data: Seq[(String, String, Double)]) {

  /**
   * A vertex of this graph.
   *
   * @param label      label identifying the vertex
   * @param connection outgoing edges, in the order they appear in `data`
   */
  case class Vertex(label: String, connection: Seq[Edge])

  /**
   * An outgoing edge.
   *
   * @param targetLabel label of the vertex this edge points to
   * @param weight      weight of the edge
   */
  class Edge(targetLabel: String, val weight: Double) {
    // Resolved lazily: edges are created while the adjacency list is still
    // being built, so the target vertex may not exist yet at construction time.
    lazy val target: Vertex = adjacencyList(targetLabel)
  }

  private val adjacencyList: HashMap[String, Vertex] = {
    val outgoing = data.groupBy(_._1)
    // Collect targets as well as sources: a vertex that is only ever a target
    // has no group in `outgoing` but must still be a valid lookup key.
    val labels = data.flatMap { case (from, to, _) => Seq(from, to) }.distinct
    val vertices = labels.map { label =>
      val edges = outgoing
        .getOrElse(label, Seq.empty[(String, String, Double)])
        .map { case (_, to, weight) => new Edge(to, weight) }
      label -> Vertex(label, edges)
    }
    HashMap(vertices: _*)
  }

  /**
   * Looks up a vertex by label.
   *
   * @throws NoSuchElementException if `label` does not occur in the edge data
   */
  def getVertex(label: String): Vertex = adjacencyList(label)

  /** Number of distinct vertices (sources and targets) in the graph. */
  val vertexesCount: Int = adjacencyList.size
}
/** Factory methods that sanitize raw edge lists before building a [[Graph]]. */
object Graph {

  /**
   * Builds a directed graph.
   *
   * @param edges `(from, to, weight)` triples; self-loops are dropped and, of
   *              several edges with the same `(from, to)`, only the first is kept
   */
  def apply(edges: Seq[(String, String, Double)]): Graph = new Graph(clean(edges))

  /**
   * Builds an undirected graph by storing every edge in both directions.
   *
   * @param edges `(from, to, weight)` triples; self-loops are dropped and the
   *              first occurrence of a connection wins, whichever direction it
   *              was written in
   */
  def undirected(edges: Seq[(String, String, Double)]): Graph = new Graph(
    // Mirror first, then clean: `(A, B)` and `(B, A)` describe the same
    // undirected edge, and cleaning the mirrored list is what catches that.
    clean(edges.flatMap { case (from, to, weight) =>
      Seq((from, to, weight), (to, from, weight))
    })
  )

  /** Drops self-loops and all but the first edge for each `(from, to)` pair. */
  private def clean(edges: Seq[(String, String, Double)]): Seq[(String, String, Double)] = {
    val acceptedEdges = mutable.HashSet.empty[(String, String)]
    // Strict filter: the predicate mutates `acceptedEdges`, so it must run
    // exactly once per edge. `add` returns true only for a pair not seen before.
    edges.filter { case (from, to, _) =>
      from != to && acceptedEdges.add(from -> to)
    }
  }
}
/**
 * Single-source shortest-path demo: Dijkstra's algorithm and a queue-based
 * Bellman-Ford variant (SPFA), both run on a small sample graph.
 *
 * Created by cloud0fan on 12/15/14.
 */
object ShortestPath extends App {

  /**
   * Rebuilds a path by following predecessor links backwards from `endNode`.
   *
   * @param map     maps a vertex label to the label of its predecessor; the walk
   *                stops at the first label that has no entry
   * @param endNode label of the last vertex of the path
   * @return the labels from the start of the chain up to and including `endNode`
   */
  def buildPath(map: mutable.HashMap[String, String], endNode: String): List[String] = {
    @scala.annotation.tailrec
    def walk(label: String, path: List[String]): List[String] = map.get(label) match {
      case Some(previous) => walk(previous, previous :: path)
      case None           => path
    }
    walk(endNode, endNode :: Nil)
  }

  /**
   * Dijkstra's algorithm from `sourceLabel`.
   *
   * Edge weights are not checked; the algorithm is only correct for
   * non-negative weights.
   *
   * @param graph       graph to search
   * @param sourceLabel label of the start vertex
   * @return one `(distance, path)` entry per vertex reachable from the source
   *         (the source itself included), in no particular order
   * @throws NoSuchElementException if `sourceLabel` is not a vertex of `graph`
   */
  def dijkstra(graph: Graph, sourceLabel: String): Seq[(Double, List[String])] = {
    import graph._
    val source = getVertex(sourceLabel)
    // PriorityQueue dequeues its greatest element, so the ordering is reversed
    // to pop the candidate with the smallest tentative distance first.
    val queue = mutable.PriorityQueue.empty(
      Ordering.fromLessThan[(Vertex, Option[String], Double)](_._3 > _._3)
    )
    val result = mutable.HashMap.empty[String, Double]
    val mapToPrevious = mutable.HashMap.empty[String, String]

    // A vertex is still open until its final distance has been recorded.
    def isToDo(vertex: Vertex) = !result.contains(vertex.label)

    // Finalizes `node` at distance `length` and queues its still-open neighbours.
    def augment(node: Vertex, pre: Option[String], length: Double): Unit = {
      result += node.label -> length
      pre.foreach(previous => mapToPrevious += node.label -> previous)
      for (edge <- node.connection if isToDo(edge.target)) {
        queue.enqueue(Tuple3(edge.target, Some(node.label), length + edge.weight))
      }
    }

    queue += Tuple3(source, None, 0.0)
    while (queue.nonEmpty) {
      val (node, pre, length) = queue.dequeue()
      // Stale queue entries for already-finalized vertices are skipped.
      if (isToDo(node)) augment(node, pre, length)
    }
    // Go through an iterator: mapping the HashMap itself to pairs would build
    // another map keyed by distance and drop vertices that share a distance.
    result.iterator.map { case (node, length) =>
      length -> buildPath(mapToPrevious, node)
    }.toList
  }

  /**
   * Queue-based Bellman-Ford (SPFA) from `sourceLabel`, using the
   * small-label-first rule when enqueuing.
   *
   * @param graph       graph to search
   * @param sourceLabel label of the start vertex
   * @return one `(distance, path)` entry per vertex reachable from the source,
   *         in no particular order, or `Nil` when a negative cycle is detected
   * @throws NoSuchElementException if `sourceLabel` is not a vertex of `graph`
   */
  def spfa(graph: Graph, sourceLabel: String): Seq[(Double, List[String])] = {
    import graph._
    val result = mutable.HashMap.empty[String, Info]
    val queue = mutable.Queue.empty[Vertex]
    val mapToPrevious = mutable.HashMap.empty[String, String]
    var negativeCycleDetected = false

    // Small label first: a vertex closer than the current queue front jumps
    // the queue, everything else is appended.
    def enqueueSLF(node: Vertex, length: Double): Unit = {
      if (queue.nonEmpty && length < result(queue.front.label).length)
        node +=: queue
      else
        queue += node
    }

    // Per-vertex bookkeeping: queue membership, number of times the vertex has
    // been enqueued, and its best known distance.
    class Info(var isInQueue: Boolean, var enqueueCount: Int, var length: Double) {
      def update(node: Vertex, newLength: Double): Unit = {
        length = newLength
        if (!isInQueue && enqueueCount < vertexesCount) {
          isInQueue = true
          enqueueCount += 1
          enqueueSLF(node, length)
        }
        // Still not queued here means the enqueue budget (`vertexesCount`) is
        // used up, which is treated as evidence of a negative cycle.
        if (!isInQueue) negativeCycleDetected = true
      }
    }

    // Registers a vertex seen for the first time and queues it.
    def addNewResult(node: Vertex, length: Double): Info = {
      val info = new Info(true, 1, length)
      result += node.label -> info
      enqueueSLF(node, length)
      info
    }

    // Relaxes every outgoing edge of `node`, whose current distance is `length`.
    def relax(node: Vertex, length: Double): Unit = {
      for (edge <- node.connection if !negativeCycleDetected) {
        val neighbour = edge.target
        val newLength = length + edge.weight
        val info = result.getOrElse(neighbour.label, {
          mapToPrevious += neighbour.label -> node.label
          addNewResult(neighbour, newLength)
        })
        if (newLength < info.length) {
          mapToPrevious += neighbour.label -> node.label
          info.update(neighbour, newLength)
        }
      }
    }

    addNewResult(getVertex(sourceLabel), 0.0)
    while (!negativeCycleDetected && queue.nonEmpty) {
      val node = queue.dequeue()
      val nodeInfo = result(node.label)
      nodeInfo.isInQueue = false
      relax(node, nodeInfo.length)
    }

    if (negativeCycleDetected) Nil
    else
      // Materialized strictly so the paths are computed before returning.
      result.iterator.map { case (node, info) =>
        info.length -> buildPath(mapToPrevious, node)
      }.toList
  }

  // Prints each entry as "<distance>: <v1> -> <v2> -> ...".
  private def printPaths(paths: Seq[(Double, List[String])]): Unit =
    paths.foreach { case (length, path) =>
      val rendered = path.mkString(" -> ")
      println(s"$length: $rendered")
    }

  val graph = Graph.undirected(Seq(
    ("A", "B", 6.0),
    ("A", "C", 3.0),
    ("B", "C", 2.0),
    ("B", "D", 5.0),
    ("C", "D", 3.0),
    ("C", "E", 4.0),
    ("D", "E", 2.0),
    ("E", "F", 5.0),
    ("D", "F", 3.0)
  ))

  printPaths(dijkstra(graph, "A"))
  println("--------")
  printPaths(spfa(graph, "A"))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment