Skip to content

Instantly share code, notes, and snippets.

@mumbleskates
Created March 8, 2016 22:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mumbleskates/b7b3bbd3924b48805087 to your computer and use it in GitHub Desktop.
Save mumbleskates/b7b3bbd3924b48805087 to your computer and use it in GitHub Desktop.
Pythonic Dijkstra path finding
# coding=utf-8
from collections import OrderedDict, namedtuple
from functools import total_ordering
from heapq import heappush, heappop
from itertools import count, zip_longest
# Unique sentinel node. An edgefinder decorated with with_initial() returns the
# decorator's `initial` iterable when it is asked for the neighbors of this node,
# so passing INITIAL_START as the search start seeds the search from those pairs.
INITIAL_START = object()
class SumTuple(tuple):
    """
    A handy class for storing priority-costs. Acts just like a regular tuple, but addition
    adds together corresponding elements rather than appending.

    Tuples of different lengths are padded with 0, so ``SumTuple((1, 2)) + (3,)`` is
    ``(4, 2)``. Adding the integer 0 returns ``self`` unchanged, which lets ``sum()``
    (whose start value is 0) total a sequence of SumTuples.
    """

    def __add__(self, other):
        # 0 acts as the additive identity so that sum() works without a start value.
        if other == 0:
            return self
        if not isinstance(other, tuple):
            raise TypeError("Cannot add '{0}' to '{1}'".format(self, other))
        return SumTuple(x + y for x, y in zip_longest(self, other, fillvalue=0))

    def __radd__(self, other):
        # Element-wise addition is commutative, so the reflected form is identical.
        # This also makes ``plain_tuple + SumTuple`` add element-wise instead of concatenating.
        return self + other
class MaxFirstSumTuple(tuple):
    """
    Like SumTuple, but the first element in the tuple reduces with max instead of sum. This allows an undesirable
    edge to equally taint any route that passes through it.

    All remaining elements are summed element-wise, padding the shorter tuple with 0. An empty
    tuple (and the integer 0) acts as the identity, so ``sum()`` works on these values.
    """

    def __add__(self, other):
        # 0 acts as the additive identity so that sum() works without a start value.
        if other == 0:
            return self
        if not isinstance(other, tuple):
            raise TypeError("Cannot add '{0}' to '{1}'".format(self, other))
        return MaxFirstSumTuple(self._adder(other))

    def __radd__(self, other):
        # Both max and element-wise sum are commutative, so the reflected form is identical.
        return self + other

    def _adder(self, other):
        """Yield the combined elements of ``self`` and ``other`` (max for the first, sum for the rest)."""
        if not self or not other:
            # An empty operand is the identity. Without this guard, next() on an empty
            # zip raised RuntimeError (PEP 479), and a 0 fill value let max() clamp a
            # negative first element up to 0.
            yield from (self or other)
            return
        yield max(self[0], other[0])
        yield from (x + y for x, y in zip_longest(self[1:], other[1:], fillvalue=0))
@total_ordering
class Worker(object):
    """
    Worker is a class for helper objects that can transform the costs of traversing a graph on an instance basis.
    Work costs will be added together directly, so recommended return types include int, float, and SumTuple.

    Workers can also be used for the task of computing paths from multiple starting points, where the
    point you begin will affect the cost of your traversal overall (different workers beginning at different locations).

    Workers essentially conditionally transform the edge-cost into a summable value. When using workers, edgefinder
    should produce a cost that DESCRIBES the work to be performed to traverse the edge, which is passed into the
    perform_work function as its sole parameter. The return value of this function must then be the COST of doing the
    work thus described; for instance, edgefinder should describe the distance between the node and the neighbor,
    and the worker will accept that distance and return the amount of time to travel that distance.

    Workers compare, order and hash by ``name`` only.
    """

    def __init__(self, name, perform_work):
        """
        :type name: str
        :type perform_work: (Any) -> Any
        """
        self.name = name
        self.perform_work = perform_work

    def __add__(self, other):
        # 0 is the additive identity, so 0 + worker (the search's starting cost) is the worker itself.
        if other == 0:
            return self
        # A bare worker plus a work description starts the running total. The description
        # must be converted to a cost here, exactly as WorkPerformed.__add__ does for every
        # later edge; otherwise the first edge's raw description would be summed as a cost.
        return WorkPerformed(self.perform_work(other), self)

    def __radd__(self, other):
        return self + other

    def __eq__(self, other):
        if isinstance(other, Worker):
            return self.name == other.name
        # Let Python fall back (identity for ==) instead of raising on e.g. `worker == 0`.
        return NotImplemented

    def __lt__(self, other):
        if isinstance(other, Worker):
            return self.name < other.name
        return NotImplemented

    def __hash__(self):
        # Consistent with __eq__; defining __eq__ alone would make Workers unhashable.
        return hash(self.name)

    def __str__(self):
        return "Worker({})".format(self.name)

    __repr__ = __str__
class WorkPerformed(namedtuple("WorkPerformed", ("cost", "worker"))):
    """
    Running total of cost accumulated by a specific worker.

    Adding a work description to this value converts it to a cost with
    ``worker.perform_work`` and adds it to ``cost``; adding 0 is a no-op.
    """

    def __add__(self, other):
        if other == 0:
            return self
        converted = self.worker.perform_work(other)
        return WorkPerformed(self.cost + converted, self.worker)

    def __radd__(self, other):
        # Reflected addition behaves exactly like forward addition.
        return self + other
def with_initial(initial):
    """
    :param initial: iterable of (start node, worker) tuples
    :return: Decorate an edgefinder to start the given initial costs at the given locations.
        If these initial costs are Workers, The edgefinder being decorated should normally
        return edge costs that are compatible work descriptors. To use this decorator to
        populate the map traversal with workers, send the constant INITIAL_START as the
        starting node.

    When the decorated edgefinder is asked for the neighbors of INITIAL_START it returns
    ``initial``; for every other node it delegates to the wrapped edgefinder.
    """
    from functools import wraps

    if iter(initial) is initial:
        # A one-shot iterator (e.g. a generator) would be exhausted by the first search,
        # leaving every later search with no starts. Cache its contents. Re-iterable
        # containers are kept as-is so later mutations by the caller remain visible.
        initial = tuple(initial)

    def dec(edgefinder):
        @wraps(edgefinder)
        def new_edgefinder(node):
            if node is INITIAL_START:
                return initial
            return edgefinder(node)
        return new_edgefinder
    return dec
def dijkstra(start, destination, edgefinder=lambda node: ((x, 1) for x in node)):
    """
    Find the cheapest route from a single start node to a single destination node.

    :param start: The start node
    :param destination: The destination node; a node matches when it compares equal (==) to it
    :param edgefinder: A function that returns an iterable of tuples
        of (neighbor, distance) from the node it is passed
    :return: ``(total cost, path)`` as produced by dijkstra_first, where path is the nested
        ``(node, (previous, (...)))`` tuple (see convert_path). When no route exists this
        is ``(None, ())``.
    """
    def reached(node):
        return node == destination

    return dijkstra_first([start], reached, edgefinder)
def dijkstra_first(starts, valid_destination, edgefinder=lambda node: ((x, 1) for x in node)):
    """
    Search outward from every start node and stop at the first acceptable destination.

    :param starts: iterable of any type, only used as keys.
    :param valid_destination: a predicate function returning true for any node that is a suitable destination
    :param edgefinder: A function that returns an iterable of tuples
        of (neighbor, distance) from the node it is passed
    :return: ``(total cost, path)`` for the cheapest route from any start to any valid
        destination. ``path`` is a nested tuple ``(dest, (previous, (... (start, ()))))``;
        flatten it with convert_path. Returns ``(None, ())`` if no destination is reachable.
    """
    seen = set()
    tiebreak = count()  # keeps heap entries with equal cost from comparing nodes
    frontier = []

    def entries():
        # Start nodes are expanded first at cost 0; then the heap is drained
        # (it keeps growing while this generator is being consumed).
        for seed in starts:
            yield 0, None, seed, ()
        while frontier:
            yield heappop(frontier)

    for cost, _, node, trail in entries():
        if node in seen:
            continue  # stale heap entry; node already settled at a lower cost
        trail = (node, trail)
        if valid_destination(node):
            return cost, trail
        seen.add(node)
        for neighbor, step in edgefinder(node):
            if neighbor not in seen:
                heappush(frontier, (cost + step, next(tiebreak), neighbor, trail))
    return None, ()  # no path exists
def dijkstra_multiple(starts, valid_destination, num_to_find, edgefinder=lambda node: ((x, 1) for x in node)):
    """
    :param starts: iterable of any type, only used as keys.
    :param valid_destination: a predicate function returning true for any node that is a suitable destination
    :param num_to_find: the maximum number of destinations to collect; the search stops as soon as
        this many have been found. A value <= 0 returns an empty result without searching.
    :param edgefinder: A function that returns an iterable of tuples
        of (neighbor, distance) from the node it is passed
    :return: the shortest 'num_to_find' paths from any starting node to any valid destination. Keys are the endpoint,
        values are (total cost, path) tuples, and the whole result is an ordered dictionary from least to greatest
        total cost. Each path is a nested ``(dest, (previous, (...)))`` tuple (see convert_path).
        Fewer than num_to_find entries are returned if fewer valid destinations are reachable.
    """
    results = OrderedDict()
    if num_to_find <= 0:
        return results

    visited = set()
    index = count()  # tie-breaker so equal-cost heap entries never compare nodes
    heap = []

    def process():
        # Start nodes first at cost 0, then drain the (growing) heap.
        yield from ((0, None, seed, ()) for seed in starts)
        while heap:
            yield heappop(heap)

    # Heap values are: distance value, a unique counter for sorting, the next place to go, and the (path, (so, (far,)))
    for dist, _, node, path in process():
        if node in visited:
            continue  # stale entry; already settled at a lower cost
        path = (node, path)
        if valid_destination(node):
            results[node] = (dist, path)
            # Stop immediately once enough destinations are settled. The old loop
            # condition (len <= num_to_find) overshot by one and, when exactly
            # num_to_find destinations existed, explored the whole graph.
            if len(results) >= num_to_find:
                break
        visited.add(node)
        for neighbor, dist_to_neighbor in edgefinder(node):
            if neighbor not in visited:
                heappush(heap, (dist + dist_to_neighbor, next(index), neighbor, path))
    return results
def dijkstra_set(starts, destinations, edgefinder=lambda node: ((x, 1) for x in node)):
    """
    :param starts: an iterable of starting nodes
    :param destinations: the destination nodes. A set/frozenset is used as-is; any other iterable
        of hashable nodes is converted to a frozenset first.
    :param edgefinder: A function that returns an iterable of tuples
        of (neighbor, distance) from the node it is passed
    :return: the ordered dictionary returned by dijkstra_multiple, mapping each reachable destination
        to its (total cost, path) from the nearest start. Unreachable destinations are absent.
    """
    if not isinstance(destinations, (set, frozenset)):
        # Normalize so membership tests are O(1), duplicate entries are not counted twice
        # toward the number of destinations to find, and one-shot iterators (which have
        # no len() and would be consumed by the first membership test) are supported.
        destinations = frozenset(destinations)
    return dijkstra_multiple(starts, (lambda x: x in destinations), len(destinations), edgefinder)
def dijkstra_full(starts, edgefinder=lambda node: ((x, 1) for x in node)):
    """
    Compute the cheapest route from the nearest start to every reachable node.

    :param starts: iterable of any type, only used as keys.
    :param edgefinder: A function that returns an iterable of tuples
        of (neighbor, distance) from the node it is passed
    :rtype: OrderedDict[object, (Any, tuple)]
    :return: an ordered dictionary with one entry for every node reachable from any start
        (the starts included). Keys are nodes, and values are ``(total cost, path)`` tuples
        ordered from least to greatest total cost. ``path`` is a nested
        ``(node, (previous, (...)))`` tuple; flatten it with convert_path.
    """
    visited = set()
    index = count()  # tie-breaker so equal-cost heap entries never compare nodes
    heap = []
    results = OrderedDict()

    def process():
        # Start nodes first at cost 0, then drain the (growing) heap.
        yield from ((0, None, seed, ()) for seed in starts)
        while heap:
            yield heappop(heap)

    # Heap values are: distance value, a unique counter for sorting, the next place to go, and the (path, (so, (far,)))
    for dist, _, node, path in process():
        if node in visited:
            continue  # stale entry; already settled at a lower cost
        path = (node, path)
        results[node] = (dist, path)
        visited.add(node)
        for neighbor, dist_to_neighbor in edgefinder(node):
            if neighbor not in visited:
                heappush(heap, (dist + dist_to_neighbor, next(index), neighbor, path))
    return results
def convert_path(path):
    """
    Flatten a nested search path into a list ordered from start to end.

    The search functions build paths as ``(last, (previous, (... (first, ()))))``;
    this walks that chain and returns ``[first, ..., previous, last]``. An empty
    (falsy) path gives an empty list.
    """
    backwards = []
    link = path
    while link:
        backwards.append(link[0])
        link = link[1]
    return backwards[::-1]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment