Last active
July 21, 2020 07:18
-
-
Save RasmusFonseca/3c9c500f3a8d47cb5cc4bb3abab4a316 to your computer and use it in GitHub Desktop.
A-star search with examples
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import heapq
from collections import defaultdict
from itertools import count
from math import sqrt, pow
from typing import List, Tuple

import numpy as np
# Sentinel standing in for "infinity": it marks missing edges in adjacency
# matrices and the initial (unknown) g/f scores in `astar`, which also returns
# it for an unreachable target. It must exceed every real path cost.
I = 1000
class Problem:
    """Interface for a shortest-path search problem solved by `astar`.

    Nodes are identified by integers. Subclasses provide the start node, the
    goal node, weighted adjacency and, optionally, an A* heuristic `h`.
    """

    def __init__(self, n: int):
        # Number of nodes in the graph.
        self.n = n

    def source(self) -> int:
        """Return the index of the start node."""
        raise NotImplementedError

    def target(self) -> int:
        """Return the index of the goal node."""
        raise NotImplementedError

    def neighbors(self, a: int) -> List[Tuple[int, float]]:
        """Return indices of and distances to all neighbors of `a`."""
        raise NotImplementedError

    def h(self, a: int) -> float:
        """Return an estimate of the remaining cost from `a` to the target.

        The default 0 never over-estimates, so it is always admissible and
        reduces A* to Dijkstra's algorithm; subclasses should override it with
        a tighter admissible estimate.
        """
        return 0.0
class MiniProblem(Problem):
    """
    Small undirected example graph with 5 nodes; source 0, target 4.

        edge   weight
        0-1    3
        0-3    4
        1-2    2
        2-3    1
        2-4    1
        3-4    8

    The shortest 0 -> 4 path costs 6 (0-1-2-4 or 0-3-2-4).
    """

    def __init__(self):
        super().__init__(5)
        # Symmetric adjacency matrix; I marks "no edge" (including the diagonal).
        self.adj = np.matrix([
            [I, 3, I, 4, I],
            [3, I, 2, I, I],
            [I, 2, I, 1, 1],
            [4, I, 1, I, 8],
            [I, I, 1, 8, I]])

    def source(self) -> int:
        return 0

    def target(self) -> int:
        return 4

    def neighbors(self, a: int) -> List[Tuple[int, float]]:
        """Return (index, distance) for every node joined to `a` by an edge."""
        row = np.asarray(self.adj[a]).ravel()
        # float(): hand callers plain Python floats, not numpy 0-d arrays.
        return [(b, float(d)) for b, d in enumerate(row) if d < I]

    def h(self, a: int) -> float:
        """Admissible heuristic: the cheapest edge leaving `a`, or 0 at the target.

        Any path from a non-target node to the target uses at least one edge
        leaving that node, so it costs at least the lightest such edge. At the
        target itself the remaining cost is 0, so the row minimum must not be
        used there (it would over-estimate).
        """
        if a == self.target():
            return 0.0
        return float(np.min(self.adj[a, :]))
class GridProblem(Problem):
    """Shortest path on a 4-connected grid, from the top-left to the bottom-right cell.

    Cells are numbered row-major (index = row * cols + col). A cell is walkable
    when its entry in `accessible` is 1. Each up/down/left/right step costs 1.
    `visited` marks every cell whose neighbors were requested, i.e. every cell
    `astar` expanded.
    """

    def __init__(self, accessible=None):
        """Build the grid problem.

        Args:
            accessible: optional 2-D grid of 0/1 values (1 = walkable). Defaults
                to an obstacle-free 20x20 grid. To search a maze instead, pass
                one, e.g. ``GridProblem(DiagGridProblem().accessible)``.
        """
        if accessible is None:
            accessible = np.ones((20, 20), dtype=int)
        self.accessible = np.matrix(accessible)
        self.rows, self.cols = self.accessible.shape
        super().__init__(self.rows * self.cols)
        self.visited = np.zeros([self.rows, self.cols])

    def printVisited(self):
        """Print the grid: '#' for walls, '.' for expanded cells, ' ' otherwise."""
        for x in range(self.rows):
            print(" ".join(
                "#" if self.accessible[x, y] == 0 else ("." if self.visited[x, y] else " ")
                for y in range(self.cols)))

    def source(self) -> int:
        return 0

    def target(self) -> int:
        return self.rows * self.cols - 1

    def _idxToCoord(self, a: int) -> Tuple[int, int]:
        return (a // self.cols, a % self.cols)

    def _coordToIdx(self, x: int, y: int) -> int:
        return x * self.cols + y

    def neighbors(self, a: int) -> List[Tuple[int, float]]:
        """Return the walkable orthogonal neighbors of `a`, each at distance 1.

        Side effect: marks `a` as visited (astar calls this when expanding `a`).
        """
        x, y = self._idxToCoord(a)
        self.visited[x, y] = 1
        candidates = [(x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)]
        return [(self._coordToIdx(nx, ny), 1)
                for nx, ny in candidates
                if 0 <= nx < self.rows and 0 <= ny < self.cols
                and self.accessible[nx, ny] == 1]

    def h(self, a: int) -> float:
        """Manhattan distance to the target; admissible for unit orthogonal steps."""
        x, y = self._idxToCoord(a)
        t_x, t_y = self._idxToCoord(self.target())
        return abs(t_x - x) + abs(t_y - y)
class DiagGridProblem(Problem):
    """Shortest path on an 8-connected grid maze, from the top-left to the bottom-right cell.

    Cells are numbered row-major (index = row * cols + col) and are walkable
    when their `accessible` entry is 1. Orthogonal steps cost 1 and diagonal
    steps cost sqrt(2). `visited` marks every cell `astar` expanded.
    """

    # Default 20x20 maze (1 = walkable, 0 = wall); np.matrix copies it.
    _DEFAULT_MAZE = [
        [1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],
        [1,0,1,1,1,1,1,0,0,0,0,1,1,1,1,1,1,1,1,1],
        [1,0,1,0,1,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1],
        [1,1,1,0,1,1,1,0,1,0,0,0,0,0,0,0,1,1,1,1],
        [0,0,0,0,1,1,1,0,1,1,1,1,1,1,1,0,1,1,1,1],
        [1,1,1,1,1,1,1,0,1,1,1,1,1,1,1,0,0,0,0,0],
        [1,1,1,1,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1],
        [1,1,1,0,1,1,1,1,0,1,1,1,1,1,1,1,1,1,1,1],
        [1,1,1,0,1,0,1,1,0,1,1,1,1,0,0,0,1,1,1,1],
        [1,1,1,0,1,1,0,1,0,1,1,1,0,1,1,1,0,0,0,1],
        [1,1,1,1,0,1,1,0,0,0,0,0,1,1,1,1,1,1,0,1],
        [1,1,1,1,0,1,1,1,0,1,1,1,1,1,1,1,1,1,1,1],
        [1,1,1,1,0,1,1,1,0,1,1,1,1,1,1,1,1,1,1,1],
        [1,1,1,1,0,1,1,1,0,0,0,0,1,1,1,1,1,1,1,1],
        [1,0,1,1,0,1,1,1,1,1,1,0,1,1,1,1,1,1,1,1],
        [1,1,0,1,1,1,1,1,1,1,1,1,0,1,1,1,1,1,1,1],
        [1,1,1,0,1,1,1,0,1,1,1,1,1,0,1,1,1,1,1,1],
        [1,1,1,0,1,1,1,0,1,1,1,1,1,1,0,1,1,1,1,1],
        [1,1,1,0,1,1,1,0,1,1,1,1,1,1,1,0,0,0,0,1],
        [1,1,1,0,1,1,1,0,1,1,1,1,1,1,1,1,1,1,0,1]]

    def __init__(self, accessible=None):
        """Build the maze problem.

        Args:
            accessible: optional 2-D grid of 0/1 values (1 = walkable). Defaults
                to the built-in 20x20 maze.
        """
        if accessible is None:
            accessible = self._DEFAULT_MAZE
        self.accessible = np.matrix(accessible)
        self.rows, self.cols = self.accessible.shape
        super().__init__(self.rows * self.cols)
        self.visited = np.zeros([self.rows, self.cols])

    def printVisited(self):
        """Print the grid: '#' for walls, '.' for expanded cells, ' ' otherwise."""
        for x in range(self.rows):
            print(" ".join(
                "#" if self.accessible[x, y] == 0 else ("." if self.visited[x, y] else " ")
                for y in range(self.cols)))

    def source(self) -> int:
        return 0

    def target(self) -> int:
        return self.rows * self.cols - 1

    def _idxToCoord(self, a: int) -> Tuple[int, int]:
        return (a // self.cols, a % self.cols)

    def _coordToIdx(self, x: int, y: int) -> int:
        return x * self.cols + y

    def neighbors(self, a: int) -> List[Tuple[int, float]]:
        """Return the walkable 8-connected neighbors of `a` with their step costs.

        Only the destination cell is checked, so a diagonal step may pass
        between two walls that touch at a corner.
        Side effect: marks `a` as visited (astar calls this when expanding `a`).
        """
        x, y = self._idxToCoord(a)
        self.visited[x, y] = 1
        sqrt2 = sqrt(2)
        steps = [
            ((x - 1, y + 1), sqrt2), ((x, y + 1), 1.0), ((x + 1, y + 1), sqrt2),
            ((x - 1, y), 1.0),                          ((x + 1, y), 1.0),
            ((x - 1, y - 1), sqrt2), ((x, y - 1), 1.0), ((x + 1, y - 1), sqrt2),
        ]
        return [(self._coordToIdx(nx, ny), d)
                for (nx, ny), d in steps
                if 0 <= nx < self.rows and 0 <= ny < self.cols
                and self.accessible[nx, ny] == 1]

    def h(self, a: int) -> float:
        """Octile distance to the target.

        This is the exact path cost on an obstacle-free grid with unit
        orthogonal and sqrt(2) diagonal steps. Walls can only lengthen the real
        path, so it is admissible (and consistent), and it is never smaller
        than the straight-line distance. A* therefore expands no more nodes
        than it would with the straight-line heuristic.
        """
        x, y = self._idxToCoord(a)
        t_x, t_y = self._idxToCoord(self.target())
        dx, dy = abs(t_x - x), abs(t_y - y)
        return max(dx, dy) + (sqrt(2) - 1) * min(dx, dy)
def astar(problem: "Problem") -> float:
    """
    Return the cost of a shortest path from ``problem.source()`` to
    ``problem.target()``, or ``I`` when the target is unreachable.

    Follows the pseudo-code in the A* Wikipedia article, but keeps the open set
    in a binary heap (``heapq``) rather than a list. Choosing the next node is
    therefore O(log n) instead of a linear scan plus list membership tests.
    When a node's g-score improves, the node is pushed again; superseded heap
    entries are skipped when popped. Because improved nodes are re-opened, an
    admissible but inconsistent heuristic still yields an optimal cost.

    ``problem.h`` is used as the heuristic when the problem defines it;
    otherwise 0 is used, which reduces the search to Dijkstra's algorithm.
    Prints how many nodes were expanded once the target is reached.
    """
    source = problem.source()
    target = problem.target()
    heuristic = getattr(problem, "h", lambda _node: 0)

    g_scores = defaultdict(lambda: I)  # best known cost from source
    f_scores = defaultdict(lambda: I)  # g + heuristic: the node's heap priority
    g_scores[source] = 0
    f_scores[source] = heuristic(source)

    # Heap entries are (f, tie, node). `tie` is an insertion counter, so equal
    # f-values are served first-in-first-out and nodes are never compared.
    tie = count()
    open_heap = [(f_scores[source], next(tie), source)]
    nodes_visited = 0
    while open_heap:
        f_cur, _, cur = heapq.heappop(open_heap)
        if f_cur > f_scores[cur]:
            continue  # stale entry: cur was pushed again with a better score
        nodes_visited += 1
        if cur == target:
            print("Done. Visited {} nodes".format(nodes_visited))
            return g_scores[cur]
        for n, d_n in problem.neighbors(cur):
            if d_n >= I:
                continue  # I marks a missing edge
            g_n = g_scores[cur] + d_n
            if g_n < g_scores[n]:
                g_scores[n] = g_n
                f_scores[n] = g_n + heuristic(n)
                heapq.heappush(open_heap, (f_scores[n], next(tie), n))
    return I
if __name__ == "__main__":
    # Example problems. Only problem3 is solved here; pass either of the
    # others to `astar` to try them (the grid problems also offer printVisited).
    problem1 = MiniProblem()
    problem2 = GridProblem()
    problem3 = DiagGridProblem()
    print(astar(problem3))
    problem3.printVisited()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment