Created
January 13, 2021 11:53
-
-
Save trygvea/6067a744ee67c2f0447c3c7f5b715d62 to your computer and use it in GitHub Desktop.
Dijkstra's shortest path algorithm in Kotlin
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package algorithms.shortestpath | |
/**
 * A vertex in the graph searched by [findShortestPath].
 *
 * Nodes are used as map and set keys, so implementations should provide consistent `equals`/`hashCode`
 * (a `data class` does this automatically).
 */
interface Node

/**
 * A directed, weighted edge from [node1] to [node2].
 *
 * [findShortestPath] only follows an edge in the `node1 -> node2` direction; add a second edge for the
 * reverse direction. Dijkstra's algorithm requires [distance] to be non-negative.
 */
data class Edge(val node1: Node, val node2: Node, val distance: Int)
/**
 * Finds the shortest path from [source] to [target] in a directed, weighted graph using Dijkstra's algorithm.
 *
 * Each [Edge] is only traversed from [Edge.node1] to [Edge.node2]. Dijkstra's algorithm assumes every
 * [Edge.distance] is non-negative; results are not meaningful otherwise.
 *
 * The search stops as soon as [target] is settled (or no reachable nodes remain), so entries for nodes that
 * were not settled by then may hold tentative distances. Nodes that are never reached keep the sentinel
 * distance [Integer.MAX_VALUE] and a `null` predecessor.
 *
 * See https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm
 *
 * @param edges the directed edges making up the graph; the node set is taken from the edge endpoints.
 * @param source the node the search starts from.
 * @param target the node to find a path to; it does not have to appear in [edges].
 * @return a [ShortestPathResult] wrapping the predecessor and distance maps computed by the search.
 */
fun findShortestPath(edges: List<Edge>, source: Node, target: Node): ShortestPathResult {
    // Note: this implementation uses the same variable names as the pseudocode in the Wikipedia article.
    // We found it more important to align with the algorithm than to use possibly more sensible naming.
    val dist = mutableMapOf<Node, Int>()
    val prev = mutableMapOf<Node, Node?>()
    val q = findDistinctNodes(edges)
    q.forEach { v ->
        dist[v] = Integer.MAX_VALUE
        prev[v] = null
    }
    dist[source] = 0

    // Outgoing edges per node, built once instead of scanning every edge on each iteration.
    // groupBy keeps the original edge order, so relaxation order (and tie-breaking) is unchanged.
    val outgoing = edges.groupBy { it.node1 }

    while (q.isNotEmpty()) {
        // Every node in q was given a distance above, so getValue cannot fail.
        val u = q.minByOrNull { dist.getValue(it) } ?: break
        val distU = dist.getValue(u)
        // The closest remaining node is unreachable, hence so are all others. Relaxing from them would
        // compute Integer.MAX_VALUE + distance, which overflows to a negative value and corrupts dist/prev.
        if (distU == Integer.MAX_VALUE) {
            break
        }
        q.remove(u)
        if (u == target) {
            break // Found shortest path to target
        }
        outgoing[u].orEmpty().forEach { edge ->
            val v = edge.node2
            // Add in Long so a large finite distance plus an edge length cannot wrap around.
            val alt = distU.toLong() + edge.distance
            if (alt < dist.getValue(v)) {
                // alt is below an existing Int distance here, so it fits back into an Int.
                dist[v] = alt.toInt()
                prev[v] = u
            }
        }
    }
    return ShortestPathResult(prev, dist, source, target)
}
/**
 * Gathers both endpoints of every edge into a set, keeping the order in which nodes are first seen.
 *
 * @param edges the graph's edges.
 * @return a fresh mutable set of all distinct nodes; the caller may modify it freely.
 */
private fun findDistinctNodes(edges: List<Edge>): MutableSet<Node> =
    edges.flatMapTo(mutableSetOf<Node>()) { edge -> listOf(edge.node1, edge.node2) }
/**
 * Result of [findShortestPath]: the predecessor map [prev] and distance map [dist] produced by the search,
 * together with the [source] and [target] it was run for.
 */
class ShortestPathResult(val prev: Map<Node, Node?>, val dist: Map<Node, Int>, val source: Node, val target: Node) {
    /**
     * Reconstructs a path by following [prev] backwards from [to] until [from] is reached.
     *
     * @param from the node the path should start at; defaults to [source].
     * @param to the node the path should end at; defaults to [target].
     * @param list nodes to prepend to the returned path.
     * @return [list] followed by the nodes from [from] to [to] (both inclusive), or an empty list when the
     *   predecessor chain starting at [to] never reaches [from].
     */
    fun shortestPath(from: Node = source, to: Node = target, list: List<Node> = emptyList()): List<Node> {
        val reversedPath = mutableListOf(to)
        val seen = mutableSetOf(to)
        var current = to
        // Test for `from` before following prev, so a start node in the middle of the chain is honoured
        // instead of being walked past.
        while (current != from) {
            current = prev[current] ?: return emptyList()
            // A repeated node means prev contains a cycle, so there is no well-defined path.
            if (!seen.add(current)) {
                return emptyList()
            }
            reversedPath += current
        }
        return list + reversedPath.asReversed()
    }

    /**
     * @return the distance from [source] to [target], or `null` when [target] is not in [dist] or still
     *   holds the unreachable sentinel [Integer.MAX_VALUE].
     */
    fun shortestDistance(): Int? = dist[target]?.takeUnless { it == Integer.MAX_VALUE }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package algorithms.shortestpath | |
import org.amshove.kluent.shouldBeEqualTo | |
import org.junit.jupiter.api.Test | |
data class StringNode(val s: String) : Node

class GenericShortestPathTest {
    /** Directed sample graph; the shortest a -> d route is a -> c -> b -> d with length 4. */
    private val sampleGraph = listOf(
        Edge(StringNode("a"), StringNode("b"), 4),
        Edge(StringNode("a"), StringNode("c"), 2),
        Edge(StringNode("b"), StringNode("c"), 3),
        Edge(StringNode("c"), StringNode("b"), 1),
        Edge(StringNode("c"), StringNode("d"), 5),
        Edge(StringNode("b"), StringNode("d"), 1),
        Edge(StringNode("a"), StringNode("e"), 1),
        Edge(StringNode("e"), StringNode("d"), 4)
    )

    @Test
    fun `should find shortest path`() {
        val result = findShortestPath(sampleGraph, StringNode("a"), StringNode("d"))
        result.shortestPath() shouldBeEqualTo listOf(StringNode("a"), StringNode("c"), StringNode("b"), StringNode("d"))
        result.shortestDistance() shouldBeEqualTo 4
    }

    @Test
    fun `should behave when shortest path is not reachable`() {
        // Without the a -> e edge nothing leads into e, so e cannot be reached from a.
        val graph = sampleGraph - Edge(StringNode("a"), StringNode("e"), 1)
        val result = findShortestPath(graph, StringNode("a"), StringNode("e"))
        result.shortestPath() shouldBeEqualTo emptyList()
        result.shortestDistance() shouldBeEqualTo null
    }

    @Test
    fun `should behave when to-node doesnt even exist`() {
        val result = findShortestPath(sampleGraph, StringNode("a"), StringNode("f"))
        result.shortestPath() shouldBeEqualTo emptyList()
        result.shortestDistance() shouldBeEqualTo null
    }

    @Test
    fun `should behave when the world is empty`() {
        val graph = emptyList<Edge>()
        val result = findShortestPath(graph, StringNode("a"), StringNode("f"))
        result.shortestPath() shouldBeEqualTo emptyList()
        result.shortestDistance() shouldBeEqualTo null
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment