Skip to content

Instantly share code, notes, and snippets.

@RasmusFonseca
Last active July 21, 2020 07:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save RasmusFonseca/3c9c500f3a8d47cb5cc4bb3abab4a316 to your computer and use it in GitHub Desktop.
Save RasmusFonseca/3c9c500f3a8d47cb5cc4bb3abab4a316 to your computer and use it in GitHub Desktop.
A-star search with examples
import heapq
from collections import defaultdict
from math import pow, sqrt
from typing import List, Tuple

import numpy as np
# "Infinity" sentinel: an adjacency-matrix entry equal to I means "no edge",
# and astar returns I when the target cannot be reached.
I = 1000
class Problem:
    """Interface for a shortest-path search problem solved by ``astar``.

    Nodes are identified by integer indices (0 .. n-1 in the subclasses in
    this file). Subclasses must implement ``source``, ``target`` and
    ``neighbors``, and should override ``h`` with a heuristic that never
    overestimates the remaining cost to the target.

    Args:
        n: number of nodes in the problem.
    """

    def __init__(self, n: int):
        self.n = n

    def source(self) -> int:
        """Return the index of the start node."""
        raise NotImplementedError

    def target(self) -> int:
        """Return the index of the goal node."""
        raise NotImplementedError

    def neighbors(self, a: int) -> List[Tuple[int, float]]:
        """Return indices of and distances to all neighbors of ``a``."""
        raise NotImplementedError

    def h(self, a: int) -> float:
        """Estimate the remaining cost from ``a`` to the target.

        The default of 0 never overestimates, so it is always admissible;
        with it A* degenerates to Dijkstra's algorithm. Subclasses should
        override it with a tighter estimate.
        """
        return 0.0
class MiniProblem(Problem):
    r"""Five-node weighted undirected graph used as a small A* example.

    Graph (edge weights are written next to the edges)::

               2
          (1) --- (2)
        3/       /   \1
        (0)    1/    (4)
        4\     /     /8
          (3) -------

    Edges: 0-1 (3), 0-3 (4), 1-2 (2), 2-3 (1), 2-4 (1), 3-4 (8).
    The search runs from node 0 to node 4.
    """

    def __init__(self):
        super().__init__(5)
        # Symmetric adjacency matrix; an entry of I marks a missing edge.
        self.adj = np.matrix([
            [I, 3, I, 4, I],
            [3, I, 2, I, I],
            [I, 2, I, 1, 1],
            [4, I, 1, I, 8],
            [I, I, 1, 8, I]])

    def source(self) -> int:
        return 0

    def target(self) -> int:
        return 4

    def neighbors(self, a: int) -> List[Tuple[int, float]]:
        """Return ``(index, distance)`` for every node adjacent to ``a``."""
        # Flatten the 1xN matrix row into plain Python numbers so callers get
        # scalars rather than the 0-d arrays that np.nditer would yield.
        row = np.asarray(self.adj[a]).ravel().tolist()
        return [(b, d) for b, d in enumerate(row) if d < I]

    def h(self, a: int) -> float:
        """Heuristic: the cheapest edge leaving ``a`` (0 at the target).

        Any path from a non-target node has to use at least one edge leaving
        that node, so this never overestimates the remaining cost.
        """
        if a == self.target():
            return 0
        return min(np.asarray(self.adj[a]).ravel().tolist())
class GridProblem(Problem):
    """4-connected grid search problem with unit step costs.

    Cells are numbered row-major (index = row * cols + col). The search runs
    from the top-left cell to the bottom-right cell. Only orthogonal moves
    onto accessible cells are allowed.

    Args:
        accessible: optional 2-D 0/1 matrix where 1 marks an open cell and 0
            a wall (for example ``DiagGridProblem().accessible`` to search
            the maze used there). Defaults to an open 20x20 grid.
    """

    def __init__(self, accessible=None):
        if accessible is None:
            # Open 20x20 grid without walls.
            accessible = np.ones((20, 20), dtype=int)
        self.accessible = np.matrix(accessible)
        self.rows, self.cols = self.accessible.shape
        super().__init__(self.rows * self.cols)
        # Cells expanded by the search (set in neighbors), shown by printVisited.
        self.visited = np.zeros([self.rows, self.cols])

    def printVisited(self):
        """Print the grid: '#' for walls, '.' for expanded cells, ' ' otherwise."""
        for x in range(self.rows):
            print(" ".join(
                "#" if self.accessible[x, y] == 0
                else ("." if self.visited[x, y] else " ")
                for y in range(self.cols)))

    def source(self) -> int:
        return 0

    def target(self) -> int:
        return self.rows * self.cols - 1

    def _idxToCoord(self, a: int) -> Tuple[int, int]:
        return (a // self.cols, a % self.cols)

    def _coordToIdx(self, x: int, y: int) -> int:
        return x * self.cols + y

    def neighbors(self, a: int) -> List[Tuple[int, float]]:
        """Return ``(index, 1)`` for each accessible orthogonal neighbor of ``a``.

        Side effect: marks ``a`` as visited for ``printVisited``.
        """
        x, y = self._idxToCoord(a)
        self.visited[x, y] = 1
        candidates = [(x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)]
        return [(self._coordToIdx(nx, ny), 1)
                for (nx, ny) in candidates
                if 0 <= nx < self.rows and 0 <= ny < self.cols
                and self.accessible[nx, ny] == 1]

    def h(self, a: int) -> float:
        """Manhattan distance to the target; admissible for unit orthogonal steps."""
        x, y = self._idxToCoord(a)
        t_x, t_y = self._idxToCoord(self.target())
        return abs(t_x - x) + abs(t_y - y)
class DiagGridProblem(Problem):
    """8-connected grid maze: orthogonal steps cost 1, diagonal steps sqrt(2).

    Cells are numbered row-major (index = row * cols + col). The search runs
    from the top-left cell to the bottom-right cell.

    NOTE: only the destination cell is checked for accessibility, so a
    diagonal move may pass between two diagonally touching walls
    (corner cutting).

    Args:
        accessible: optional 2-D 0/1 matrix where 1 marks an open cell and 0
            a wall. Defaults to the built-in 20x20 maze below.
    """

    def __init__(self, accessible=None):
        if accessible is None:
            accessible = [
                [1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],
                [1,0,1,1,1,1,1,0,0,0,0,1,1,1,1,1,1,1,1,1],
                [1,0,1,0,1,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1],
                [1,1,1,0,1,1,1,0,1,0,0,0,0,0,0,0,1,1,1,1],
                [0,0,0,0,1,1,1,0,1,1,1,1,1,1,1,0,1,1,1,1],
                [1,1,1,1,1,1,1,0,1,1,1,1,1,1,1,0,0,0,0,0],
                [1,1,1,1,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1],
                [1,1,1,0,1,1,1,1,0,1,1,1,1,1,1,1,1,1,1,1],
                [1,1,1,0,1,0,1,1,0,1,1,1,1,0,0,0,1,1,1,1],
                [1,1,1,0,1,1,0,1,0,1,1,1,0,1,1,1,0,0,0,1],
                [1,1,1,1,0,1,1,0,0,0,0,0,1,1,1,1,1,1,0,1],
                [1,1,1,1,0,1,1,1,0,1,1,1,1,1,1,1,1,1,1,1],
                [1,1,1,1,0,1,1,1,0,1,1,1,1,1,1,1,1,1,1,1],
                [1,1,1,1,0,1,1,1,0,0,0,0,1,1,1,1,1,1,1,1],
                [1,0,1,1,0,1,1,1,1,1,1,0,1,1,1,1,1,1,1,1],
                [1,1,0,1,1,1,1,1,1,1,1,1,0,1,1,1,1,1,1,1],
                [1,1,1,0,1,1,1,0,1,1,1,1,1,0,1,1,1,1,1,1],
                [1,1,1,0,1,1,1,0,1,1,1,1,1,1,0,1,1,1,1,1],
                [1,1,1,0,1,1,1,0,1,1,1,1,1,1,1,0,0,0,0,1],
                [1,1,1,0,1,1,1,0,1,1,1,1,1,1,1,1,1,1,0,1]]
        self.accessible = np.matrix(accessible)
        self.rows, self.cols = self.accessible.shape
        super().__init__(self.rows * self.cols)
        # Cells expanded by the search (set in neighbors), shown by printVisited.
        self.visited = np.zeros([self.rows, self.cols])

    def printVisited(self):
        """Print the grid: '#' for walls, '.' for expanded cells, ' ' otherwise."""
        for x in range(self.rows):
            print(" ".join(
                "#" if self.accessible[x, y] == 0
                else ("." if self.visited[x, y] else " ")
                for y in range(self.cols)))

    def source(self) -> int:
        return 0

    def target(self) -> int:
        return self.rows * self.cols - 1

    def _idxToCoord(self, a: int) -> Tuple[int, int]:
        return (a // self.cols, a % self.cols)

    def _coordToIdx(self, x: int, y: int) -> int:
        return x * self.cols + y

    def neighbors(self, a: int) -> List[Tuple[int, float]]:
        """Return ``(index, step_cost)`` for each accessible 8-neighbor of ``a``.

        Side effect: marks ``a`` as visited for ``printVisited``.
        """
        x, y = self._idxToCoord(a)
        self.visited[x, y] = 1
        sqrt2 = sqrt(2)
        # (row offset, column offset) and the cost of that step.
        steps = [((-1, 1), sqrt2), ((0, 1), 1.0), ((1, 1), sqrt2),
                 ((-1, 0), 1.0), ((1, 0), 1.0),
                 ((-1, -1), sqrt2), ((0, -1), 1.0), ((1, -1), sqrt2)]
        result = []
        for (dx, dy), d in steps:
            nx, ny = x + dx, y + dy
            if (0 <= nx < self.rows and 0 <= ny < self.cols
                    and self.accessible[nx, ny] == 1):
                result.append((self._coordToIdx(nx, ny), d))
        return result

    def h(self, a: int) -> float:
        """Euclidean distance to the target; admissible for these step costs."""
        x, y = self._idxToCoord(a)
        t_x, t_y = self._idxToCoord(self.target())
        return sqrt(pow(t_x - x, 2) + pow(t_y - y, 2))
def astar(problem: Problem):
    """Return the cost of a shortest path from ``problem.source()`` to ``problem.target()``.

    Basic A* following the pseudo-code in its Wikipedia article: nodes are
    expanded in order of f = g + h, where g is the best known cost from the
    source and h is ``problem.h``. A node whose g-score later improves is
    pushed again (re-opened), so the result is optimal whenever
    ``problem.h`` never overestimates the remaining cost.

    Edges whose length is ``>= I`` are treated as absent. When the target is
    reached, the number of expanded nodes is printed.

    Args:
        problem: provides ``source()``, ``target()``, ``neighbors(a)`` and ``h(a)``.

    Returns:
        The shortest-path cost, or ``I`` if the target is unreachable.
    """
    source = problem.source()
    target = problem.target()
    # Best known cost from the source. Unseen nodes are truly unbounded so
    # that paths costing more than the finite sentinel I are still found.
    g_scores = {source: 0}
    # Min-heap of (f, insertion counter, g, node). The counter breaks ties in
    # insertion order and keeps g/node from ever being compared. Improved
    # nodes are pushed again; superseded entries are skipped when popped.
    open_heap = [(problem.h(source), 0, 0, source)]
    pushes = 1
    nodes_visited = 0
    while open_heap:
        _, _, g_cur, cur = heapq.heappop(open_heap)
        if g_cur > g_scores[cur]:
            continue  # stale entry: cur was re-pushed with a better g
        nodes_visited += 1
        if cur == target:
            print("Done. Visited {} nodes".format(nodes_visited))
            return g_scores[cur]
        for n, d_n in problem.neighbors(cur):
            if d_n >= I:
                continue
            g_n = g_cur + d_n
            if g_n < g_scores.get(n, float("inf")):
                g_scores[n] = g_n
                heapq.heappush(open_heap, (g_n + problem.h(n), pushes, g_n, n))
                pushes += 1
    return I
# Build the example problems (in this order); only the diagonal-grid maze is
# solved below, problem1 and problem2 are constructed but not searched.
problem1, problem2, problem3 = MiniProblem(), GridProblem(), DiagGridProblem()

# Print the shortest-path cost through the maze, then the expanded cells.
print(astar(problem3))
problem3.printVisited()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment