Skip to content

Instantly share code, notes, and snippets.

@codefever
Created October 4, 2019 09:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save codefever/0fe365c1aa0b28fe41d1388d05ea1f44 to your computer and use it in GitHub Desktop.
Save codefever/0fe365c1aa0b28fe41d1388d05ea1f44 to your computer and use it in GitHub Desktop.
BFS with Q-Learning?
#!/usr/bin/env python
import logging
import math
import random
from collections import deque
from enum import IntEnum
logging.basicConfig(level=logging.INFO)

# Demo maze. A cell value of 0 is open; any positive value is a wall that
# `neighbors` refuses to step onto. The solvers go from the top-left cell
# (0, 0) to the bottom-right cell.
GRID = [
    [0, 0, 0, 0, 0],
    [1, 1, 1, 0, 1],
    [1, 0, 0, 0, 1],
    [0, 0, 1, 1, 1],
    [1, 0, 0, 0, 0],
]
class Direction(IntEnum):
    """The four orthogonal grid moves.

    The integer value of each member is also its column in the Q-table
    built by `qlearning`.
    """

    LEFT = 0
    RIGHT = 1
    TOP = 2  # "up": moves to the previous row
    DOWN = 3


# Direction -> (row delta, column delta). Offsets are listed in member
# order (LEFT, RIGHT, TOP, DOWN); rows grow downwards.
DIRS = dict(zip(Direction, [(0, -1), (0, 1), (-1, 0), (1, 0)]))
def _get_dir(src, dst):
    """Return the Direction that moves from cell ``src`` to cell ``dst``.

    Cells are ``(row, col)`` pairs and ``dst`` must be one of the four
    orthogonal neighbours of ``src``. The direction is looked up in ``DIRS``
    so this function always agrees with the offsets used by ``neighbors``.

    Raises:
        ValueError: if ``dst`` is not orthogonally adjacent to ``src``.
            (Such input used to produce a silently wrong direction, e.g.
            ``src == dst`` mapped to DOWN.)
    """
    delta = (dst[0] - src[0], dst[1] - src[1])
    for direction, offset in DIRS.items():
        if offset == delta:
            return direction
    raise ValueError('{} is not adjacent to {}'.format(dst, src))
def neighbors(grid, y, x):
    """Generate the open cells orthogonally adjacent to ``(y, x)``.

    Offsets come from ``DIRS`` in its iteration order. A candidate is yielded
    as a ``(row, col)`` tuple only when it lies inside the grid (the column
    bound is taken from the first row) and its value is not positive, i.e.
    it is not a wall.
    """
    for dy, dx in DIRS.values():
        row, col = y + dy, x + dx
        if not (0 <= row < len(grid) and 0 <= col < len(grid[0])):
            continue  # off the board
        if grid[row][col] > 0:
            continue  # wall
        yield row, col
def bfs(grid):
    """Return the shortest path length from the top-left to the bottom-right cell.

    Moves are the orthogonal steps produced by ``neighbors`` (cells with a
    positive value are walls). The start cell itself is not checked for
    being a wall, which matches ``qlearning``.

    Returns:
        The number of moves on a shortest path, 0 for a 1x1 grid, or -1 if
        the goal is unreachable or the grid is empty.
    """
    if not grid or not grid[0]:
        return -1
    start = (0, 0)
    end = (len(grid) - 1, len(grid[0]) - 1)
    visited = {start}
    # A deque gives O(1) pops from the front; re-slicing a list on every pop
    # made the search quadratic in the number of cells.
    frontier = deque([(start, 0)])
    while frontier:
        cell, dist = frontier.popleft()
        if cell == end:
            return dist
        for nxt in neighbors(grid, cell[0], cell[1]):
            if nxt not in visited:
                visited.add(nxt)
                frontier.append((nxt, dist + 1))
    return -1
def qlearning(grid, *, num_episodes=20, alpha=0.5, gamma=0.777,
              max_episode_steps=100, rng=None):
    """Estimate the shortest path length through ``grid`` with tabular Q-learning.

    The agent walks from the top-left cell ``(0, 0)`` to the bottom-right cell
    over the open cells yielded by ``neighbors``. Every move is rewarded with
    1.0, except a move onto the goal, which earns ``m * n``. Training is
    epsilon-greedy: in episode ``i`` the agent explores (random move) with
    probability ``exp(-0.1 * i)`` and otherwise takes the move with the highest
    Q-value. After training, the greedy policy is replayed from the start.

    Args:
        grid: list of rows; 0 marks an open cell, a positive value a wall.
        num_episodes: number of training episodes.
        alpha: learning rate of the Q-update.
        gamma: discount applied to the next cell's best Q-value.
        max_episode_steps: maximum number of moves per training episode.
        rng: source of randomness providing ``random()`` and ``choice()``
            (e.g. a seeded ``random.Random``); defaults to the ``random`` module.

    Returns:
        The number of moves on the greedy path. Returns -1 if that path does not
        reach the goal within ``m * n`` moves, if it runs into a cell without an
        open neighbour, or if the grid is empty. A 1x1 grid returns 0.
    """
    if not grid or not grid[0]:
        return -1
    rng = random if rng is None else rng
    m, n = len(grid), len(grid[0])
    start = (0, 0)
    end = (m - 1, n - 1)
    # qtable[row][col][direction] -> estimated value of moving that way.
    qtable = [[[0.0] * len(Direction) for _ in range(n)] for _ in range(m)]

    def reward(node, todir):
        """Reward for moving from ``node`` in direction ``todir``."""
        dy, dx = DIRS[todir]
        if (node[0] + dy, node[1] + dx) == end:
            # Reaching the goal earns a large bonus.
            return m * n
        return 1.0

    def greedy_move(node):
        """Return ``(next_cell, direction)`` with the highest Q-value from
        ``node``, or ``(None, None)`` if ``node`` has no open neighbour.
        Ties keep the earliest neighbour in ``neighbors`` order."""
        best_val = -float('inf')
        best = (None, None)
        for nxt in neighbors(grid, node[0], node[1]):
            d = _get_dir(node, nxt)
            val = qtable[node[0]][node[1]][d]
            if val > best_val:
                best_val, best = val, (nxt, d)
        return best

    def get_path():
        """Replay the greedy policy from ``start``.

        Returns the moves as a string of direction initials (e.g. ``'RRD'``),
        or None if the goal is not reached within ``m * n`` moves or a dead
        end is hit.
        """
        node = start
        path = []
        while node != end:
            node, d = greedy_move(node)
            if node is None:
                return None
            path.append(d.name[0])
            if len(path) >= m * n:
                return None
        return ''.join(path)

    for episode in range(num_episodes):
        # Probability of exploring in this episode; decays over training.
        eprate = math.exp(-episode * 0.1)
        node = start
        for _ in range(max_episode_steps):
            if node == end:
                break
            if rng.random() > eprate:  # exploitation
                nxt, todir = greedy_move(node)
            else:  # exploration
                options = list(neighbors(grid, node[0], node[1]))
                if options:
                    nxt = rng.choice(options)
                    todir = _get_dir(node, nxt)
                else:
                    nxt = todir = None
            if nxt is None:
                break  # dead end: nowhere to move from this cell
            # The goal is terminal, so it contributes no future value.
            if nxt == end:
                future = 0.0
            else:
                future = max(
                    (qtable[nxt[0]][nxt[1]][_get_dir(nxt, c)]
                     for c in neighbors(grid, nxt[0], nxt[1])),
                    default=0.0,
                )
            q = qtable[node[0]][node[1]]
            q[todir] += alpha * (reward(node, todir) + gamma * future - q[todir])
            node = nxt
        # Replaying the greedy path is only worth the cost when debug logging is on.
        if logging.getLogger().isEnabledFor(logging.DEBUG):
            path = get_path()
            logging.debug('episode[{}] exploit=[{}] reach_end=[{}], path=[{}], steps=[{}]'.format(
                episode, eprate, node == end, path,
                len(path) if path is not None else 'INF'))

    path = get_path()
    # An empty path ('' for a 1x1 grid) is a valid 0-move solution.
    return len(path) if path is not None else -1
if __name__ == '__main__':
    # Print the exact BFS answer first, then the Q-learning estimate.
    for solver in (bfs, qlearning):
        print(solver(GRID))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment