Last active
January 27, 2016 19:35
-
-
Save polyglotpiglet/8131ca24392f08fae741 to your computer and use it in GitHub Desktop.
Dijkstra in Scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** Dijkstra's shortest-path search over an undirected graph whose edges all
  * have weight 1, plus a small example graph that is searched when the
  * program is run.
  */
object Dijkstra extends App {

  // Imported locally so this object is self-contained.
  import scala.annotation.tailrec
  import scala.collection.mutable

  val graph = buildExampleGraph()
  shortestPath(Node[String]("car"), Node[String]("tat"), graph)
  shortestPath(Node[String]("cat"), Node[String]("tar"), graph)

  /** A graph vertex. Two nodes are equal when their values are equal. */
  case class Node[T](value: T)

  /** A mutable, undirected graph.
    *
    * Nodes and edges are tracked separately: `addEdge` does not register its
    * endpoints as nodes, and only nodes registered through `addNode` /
    * `addNodes` are returned by `getAllNodes`.
    */
  class Graph[T] {
    private val edges = mutable.HashMap.empty[Node[T], Seq[Node[T]]]
    private var nodes = Seq.empty[Node[T]]

    /** Registers a single node (prepended to the node list). */
    def addNode(node: Node[T]): Unit = {
      nodes = node +: nodes
    }

    /** Registers several nodes, keeping their relative order, ahead of the existing ones. */
    def addNodes(newNodes: Seq[Node[T]]): Unit = {
      nodes = newNodes ++ nodes
    }

    /** Adds an undirected edge by recording each endpoint as a neighbour of the other. */
    def addEdge(n1: Node[T], n2: Node[T]): Unit = {
      edges(n1) = n2 +: getNodesWithEdgeFrom(n1)
      edges(n2) = n1 +: getNodesWithEdgeFrom(n2)
    }

    /** Neighbours of `node`; empty when no edge touching `node` has been added. */
    def getNodesWithEdgeFrom(node: Node[T]): Seq[Node[T]] =
      edges.getOrElse(node, Seq.empty[Node[T]])

    /** All registered nodes, most recently added first. */
    def getAllNodes: Seq[Node[T]] = nodes
  }

  /** Builds the sample word graph: cat-cab, cat-car, cab-car, car-bar, bar-tar, tar-tat. */
  def buildExampleGraph(): Graph[String] = {
    val graph = new Graph[String]()
    val cat = Node[String]("cat")
    val cab = Node[String]("cab")
    val car = Node[String]("car")
    val bar = Node[String]("bar")
    val tar = Node[String]("tar")
    val tat = Node[String]("tat")
    graph.addNodes(Seq(cat, cab, car, bar, tar, tat))
    graph.addEdge(cat, cab)
    graph.addEdge(cat, car)
    graph.addEdge(cab, car)
    graph.addEdge(car, bar)
    graph.addEdge(bar, tar)
    graph.addEdge(tar, tat)
    graph
  }

  /** Finds a shortest path from `start` to `end`, counting every edge as length 1.
    *
    * @param start node the path begins at
    * @param end   node the path must reach
    * @param graph graph to search; only nodes returned by `getAllNodes` are considered
    * @return `Some((path, length))`, where `path` runs from `start` to `end`
    *         inclusive and `length` is its number of edges, or `None` when
    *         either endpoint is not registered in the graph or `end` cannot
    *         be reached from `start`
    */
  def findShortestPath[T](start: Node[T],
                          end: Node[T],
                          graph: Graph[T]): Option[(Seq[Node[T]], Int)] = {
    val allNodes = graph.getAllNodes
    if (!allNodes.contains(start) || !allNodes.contains(end)) None
    else {
      // Best known (path, distance) for every node reached so far. A node that
      // is absent has not been reached, so no "infinity" sentinel is needed
      // and no arithmetic is ever performed on one.
      val best = mutable.Map[Node[T], (Seq[Node[T]], Int)]()
      best(start) = (Seq(start), 0)

      @tailrec
      def visit(node: Node[T], unvisited: Seq[Node[T]]): Option[(Seq[Node[T]], Int)] = {
        val (pathToNode, distanceToNode) = best(node)
        // mark the current node as visited
        val remaining = unvisited.filterNot(_ == node)
        // once the target itself is visited its entry is final
        if (node == end) Some((pathToNode, distanceToNode))
        else {
          val candidate = distanceToNode + 1
          graph.getNodesWithEdgeFrom(node)
            .filter(n => remaining.contains(n))
            .foreach { n =>
              // keep an existing entry only when it is strictly shorter
              if (best.get(n).forall(_._2 >= candidate)) best(n) = (pathToNode :+ n, candidate)
            }
          // next examine the closest reached node; ties go to the earliest in node order
          val reachable = remaining.filter(n => best.contains(n))
          if (reachable.isEmpty) None // nothing left to explore: `end` is unreachable
          else visit(reachable.minBy(n => best(n)._2), remaining)
        }
      }

      visit(start, allNodes)
    }
  }

  /** Prints the shortest path from `start` to `end` as a `(path, length)` tuple,
    * or a "No path" message when `findShortestPath` finds none.
    *
    * @param start node the path begins at
    * @param end   node the path must reach
    * @param graph graph to search
    */
  def shortestPath[T](start: Node[T],
                      end: Node[T],
                      graph: Graph[T]): Unit =
    findShortestPath(start, end, graph) match {
      case Some(pathAndDistance) => println(pathAndDistance)
      case None                  => println(s"No path from $start to $end")
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment