# Source: GitHub gist prophile/5c72af00cd4a392c621fe365ae6d3965
# (created by @prophile on January 7, 2017 20:20).
import heapq

try:
    from collections.abc import Iterable, Iterator
except ImportError:
    from collections import Iterable, Iterator
class AStar(Iterator):
    """Lazily run an A* (or greedy best-first) search over an implicit graph.

    Iterating yields ``(node, path)`` pairs in the order nodes are expanded,
    where ``path`` is the list of transition labels leading from ``start``
    to ``node``.  A node is yielded at most once: the first time it is
    taken for expansion.

    Args:
        transition: Callable mapping a node to an iterable of
            ``(next_node, transition_label, move_cost)`` triples.
        heuristic: Callable mapping a node to a score that is added to the
            node's path length when ranking it (the usual A* estimate).
        start: Node the search begins from.  Any hashable value, including
            ``None``, is accepted.
        zero_score: Path length assigned to ``start``; move costs are added
            onto it.
        pure_heuristic: If true, rank nodes by ``heuristic(node)`` alone
            (greedy best-first search) instead of path length + heuristic.

    Nodes must be hashable, because expanded nodes are kept in a set.
    """

    def __init__(self, transition, heuristic, start, zero_score=0,
                 pure_heuristic=False):
        self.transition = transition
        self.heuristic = heuristic
        self.pure_heuristic = pure_heuristic

        # The "candidate" is the next node to expand.  It is held outside
        # the heap so the fast paths in next() can skip a push/pop pair.
        self.candidate = start
        self.candidate_path = []
        self.candidate_path_length = zero_score
        self.candidate_score = self._score(zero_score, start)

        # Min-heap of (score, path_length, tie_break_index, path, node).
        # The unique index means tuple comparison never reaches path/node.
        self.open_set = []
        # Nodes that have already been expanded (and yielded).
        self.closed_set = set()
        # Counter feeding the heap tie-break index.
        self.index = 0
        # Set once the open set runs dry.  Kept separate from `candidate`
        # so that None can be an ordinary graph node.
        self._exhausted = False

    def __iter__(self):
        return self

    def _get_candidate(self):
        """Pop the best open-set entry into the candidate slot.

        When the open set is empty, marks the search exhausted (and sets
        ``candidate`` to ``None`` for callers that inspect it).
        """
        try:
            (score, plength, _, path, node) = heapq.heappop(self.open_set)
        except IndexError:
            self.candidate = None
            self._exhausted = True
        else:
            self.candidate = node
            self.candidate_path = path
            self.candidate_path_length = plength
            self.candidate_score = score

    def _next_index(self):
        """Return a fresh, strictly increasing heap tie-break index."""
        ix = self.index
        self.index = ix + 1
        return ix

    def _score(self, path_length, node):
        """Return the priority of ``node`` reached via ``path_length``.

        Lower is better (the open set is a min-heap).  In pure-heuristic
        mode the path length is ignored.
        """
        if self.pure_heuristic:
            return self.heuristic(node)
        return path_length + self.heuristic(node)

    def next(self):
        """Expand the current candidate and return ``(node, path)``.

        Raises:
            StopIteration: Once the open set is exhausted.
        """
        # The open set may hold several entries for one node; skip any
        # whose node was already expanded via another route.
        while not self._exhausted and self.candidate in self.closed_set:
            self._get_candidate()
        if self._exhausted:
            raise StopIteration()

        self.closed_set.add(self.candidate)
        transitions = [
            (node, transition, move_cost)
            for node, transition, move_cost in self.transition(self.candidate)
            if node not in self.closed_set
        ]

        prev_node = self.candidate
        # Copy: the fast-descent branch below extends candidate_path in place.
        prev_path = list(self.candidate_path)

        if not transitions:
            # Dead end: just pull in the next candidate from the open set.
            self._get_candidate()
            return prev_node, prev_path

        # Fast descent: the candidate never scores worse than the open set's
        # best entry, so a sole successor scoring no worse than the
        # candidate can be explored next without touching the heap.
        if len(transitions) == 1:
            node, transition, move_cost = transitions[0]
            path_length = self.candidate_path_length + move_cost
            score = self._score(path_length, node)
            if score <= self.candidate_score:
                self.candidate = node
                self.candidate_score = score
                self.candidate_path.append(transition)
                self.candidate_path_length = path_length
                return prev_node, prev_path

        last = len(transitions) - 1
        for ix, (node, transition, move_cost) in enumerate(transitions):
            path_length = self.candidate_path_length + move_cost
            score = self._score(path_length, node)
            entry = (
                score,
                path_length,
                self._next_index(),
                self.candidate_path + [transition],
                node,
            )
            if ix == last:
                # Last node: pushpop pushes it and pops the new best entry
                # in one heap operation, which becomes the next candidate.
                (
                    self.candidate_score,
                    self.candidate_path_length,
                    _,
                    self.candidate_path,
                    self.candidate,
                ) = heapq.heappushpop(self.open_set, entry)
            else:
                heapq.heappush(self.open_set, entry)
        return prev_node, prev_path

    __next__ = next
def astar(
    transition,
    heuristic,
    start,
    is_final,
    open_set_limits=None,
    **kwargs
):
    """Search from ``start`` and return the first expanded final node.

    Args:
        transition: Callable mapping a node to an iterable of
            ``(next_node, transition_label, move_cost)`` triples.
        heuristic: Callable mapping a node to a score added to its path
            length when ranking nodes.
        start: Node the search begins from.
        is_final: Predicate called on each node as it is expanded; the
            search stops at the first node for which it is true.
        open_set_limits: Optional ``(max_size, truncate_to)`` pair.  After
            each non-final expansion, if the open set holds more than
            ``max_size`` entries it is cut down to its ``truncate_to``
            lowest-scored entries.  Discarded entries are never revisited,
            so pruning can make the search miss paths.
        **kwargs: Forwarded to ``AStar`` (``zero_score``,
            ``pure_heuristic``).

    Returns:
        The ``(node, path)`` pair for the first final node, where ``path``
        is the list of transition labels from ``start``.

    Raises:
        ValueError: If the search runs out of nodes without reaching a
            final one.
    """
    search = AStar(transition, heuristic, start, **kwargs)
    for node, path in search:
        if is_final(node):
            return node, path
        if open_set_limits is None:
            continue
        # Open-set policy.  A fully sorted list is still a valid heap, so
        # sorting and then chopping the tail keeps the best entries and
        # leaves the search's heap usable.
        max_size, truncate_to = open_set_limits
        frontier = search.open_set
        if len(frontier) > max_size:
            frontier.sort()
            del frontier[truncate_to:]
    raise ValueError("No path found")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment