Last active
July 8, 2018 07:30
-
-
Save limitpointinf0/4a1b4585f6d540e9f36a8728bbaf4f18 to your computer and use it in GitHub Desktop.
Maze solving with Q-learning
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import itertools | |
import math | |
import time | |
import os | |
def make_sqworld(num_loc=None):
    """Build a square grid world whose cells are numbered in row-major order.

    Args:
        num_loc: Total number of cells. Must be a perfect square, e.g. 100
            for a 10x10 world.

    Returns:
        A 2-D numpy array of shape (dim, dim) holding 0 .. num_loc - 1,
        where dim is the square root of num_loc.

    Raises:
        TypeError: If num_loc is left as None.
        ValueError: If num_loc is negative or not a perfect square.
    """
    if num_loc is None:
        raise TypeError("make_sqworld() requires num_loc (a perfect square)")
    if num_loc < 0:
        raise ValueError("num_loc must be non-negative, got %r" % (num_loc,))
    # round() guards against sqrt landing a hair below the exact integer root.
    dim = int(round(math.sqrt(num_loc)))
    if dim * dim != num_loc:
        raise ValueError("num_loc must be a perfect square, got %r" % (num_loc,))
    return np.arange(dim * dim).reshape((dim, dim))
def get_steps(mat, exc=None):
    """List every legal one-square move between cells of a grid world.

    A move is legal when both the source and the destination cell lie inside
    the grid and neither cell's value appears in ``exc``.

    Args:
        mat: 2-D array whose entries are the cell identifiers.
        exc: Iterable of cell identifiers that cannot be entered or left
            (walls). Defaults to no excluded cells.

    Returns:
        A list of ``(source, destination)`` tuples of cell identifiers.
        Cells are visited in row-major order and, for each cell, neighbours
        are emitted in the order left, up, right, down.
    """
    blocked = set(exc) if exc is not None else set()
    n_rows, n_cols = mat.shape[0], mat.shape[1]
    # Neighbour offsets, in the order in which moves are emitted.
    offsets = ((0, -1), (-1, 0), (0, 1), (1, 0))

    steps = []
    for i in range(n_rows):
        for j in range(n_cols):
            here = mat[i][j]
            if here in blocked:
                continue
            for d_row, d_col in offsets:
                n_i, n_j = i + d_row, j + d_col
                # Explicit bounds check: negative indices would otherwise
                # wrap around to the opposite edge of the array.
                if not (0 <= n_i < n_rows and 0 <= n_j < n_cols):
                    continue
                there = mat[n_i][n_j]
                if there not in blocked:
                    steps.append((here, there))
    return steps
#steps = [(0,1),(1,0),(1,2),(1,4),(2,1),(2,5),(4,1),(4,5),(4,7),(5,4),(5,2),(5,8),(7,8),(7,4),(7,6),(8,7),(8,5),(6,7),(6,6)] | |
def initQ(dim):
    """Return a dim x dim array of zeros to use as the starting Q table."""
    shape = (dim, dim)
    q_table = np.zeros(shape)
    return q_table
def initR(steps, goal):
    """Build the reward matrix for a set of legal moves.

    Args:
        steps: Non-empty sequence of ``(source, destination)`` state-index
            pairs describing the legal moves.
        goal: Index of the goal state.

    Returns:
        A square float array ``rmat`` where ``rmat[s, d]`` is 100 for a legal
        move into ``goal``, 0 for any other legal move, and -1 for every pair
        that is not a legal move. Its side is one more than the largest state
        index appearing in ``steps``.

    Raises:
        ValueError: If ``steps`` is empty.
    """
    if len(steps) == 0:
        raise ValueError("steps must contain at least one (source, destination) pair")
    # Size from both endpoints: a state that only ever appears as a source
    # must still be a valid row index.
    n_states = max(max(step[0], step[1]) for step in steps) + 1
    rmat = np.full((n_states, n_states), -1.0)
    for step in steps:
        rmat[step[0], step[1]] = 100 if step[1] == goal else 0
    return rmat
def updQ(qmat, rmat, steps, gamma=0.8):
    """Run one sweep of the Q-value update over every legal move.

    For each legal move ``(s, d)`` the new value is
    ``rmat[s, d] + gamma * max(Q[d, n] for each legal move (d, n))``.
    Moves are processed in row-major order on a single working copy, so
    values written earlier in the sweep are already visible to later moves.
    Entries that are not legal moves are reset to 0.

    Args:
        qmat: Current Q table (2-D array); it is not modified.
        rmat: Reward matrix, indexed as ``rmat[source, destination]``.
        steps: Iterable of ``(source, destination)`` state-index pairs.
            Pairs falling outside ``qmat`` are ignored.
        gamma: Discount factor applied to the best follow-up value.

    Returns:
        A new array with the same shape as ``qmat`` holding the updated
        Q values.
    """
    n_rows, n_cols = qmat.shape[0], qmat.shape[1]

    # Deduplicated legal moves that actually address a cell of the Q table.
    transitions = set()
    for step in steps:
        src, dst = step[0], step[1]
        if 0 <= src < n_rows and 0 <= dst < n_cols:
            transitions.add((src, dst))

    # Outgoing moves per state, so follow-ups are a lookup rather than a
    # scan of the whole step list for every cell.
    successors = {}
    for src, dst in transitions:
        successors.setdefault(src, []).append(dst)

    # Start from the old values, but zero everything that is not a legal move.
    new_q = qmat.copy()
    is_step = np.zeros(new_q.shape, dtype=bool)
    for src, dst in transitions:
        is_step[src, dst] = True
    new_q[~is_step] = 0

    # sorted() visits the moves in row-major order.
    for src, dst in sorted(transitions):
        future = [new_q[dst, nxt] for nxt in successors.get(dst, ())]
        # A state with no outgoing moves contributes no future value.
        best_next = max(future) if future else 0.0
        new_q[src, dst] = rmat[src, dst] + gamma * best_next
    return new_q
def render_world(world, exc=None, player_loc=0, goal_loc=24):
    """Print an ASCII picture of the maze to standard output.

    Every cell is drawn three characters wide: ``' o '`` for a wall,
    ``' x '`` for the player, ``' $ '`` for the goal and three spaces for an
    open square. The grid is framed by a border of ``' o '`` cells. A cell
    listed in ``exc`` is drawn as a wall even if it is also the player or
    goal location, and the player takes precedence over the goal.

    Args:
        world: 2-D array of cell identifiers.
        exc: Iterable of cell identifiers drawn as walls. Defaults to none.
        player_loc: Identifier of the cell the player occupies.
        goal_loc: Identifier of the goal cell.

    Returns:
        None. The picture is written with print.
    """
    blocked = set(exc) if exc is not None else set()
    n_rows, n_cols = world.shape[0], world.shape[1]

    border = " o " * (n_cols + 2)
    rows = [border]
    for i in range(n_rows):
        cells = [' o ']  # left border
        for j in range(n_cols):
            cell = world[i][j]
            if cell in blocked:
                cells.append(' o ')
            elif cell == player_loc:
                cells.append(' x ')
            elif cell == goal_loc:
                cells.append(' $ ')
            else:
                # Same width as the other glyphs so the columns stay aligned.
                cells.append('   ')
        cells.append(' o ')  # right border
        rows.append(''.join(cells))
    rows.append(border)
    # Parenthesised single argument prints identically on Python 2 and 3.
    print('\n'.join(rows))
# Demo: build a random 10x10 maze, train a Q table on it, then animate the
# player greedily following the highest Q value from the start to the goal.
player_at = 0
GOAL = 99
num_squares = 100
# Up to 25 random wall squares (duplicate draws collapse in the set); the
# start and goal squares are never walled off themselves.
barriers = [x for x in set(np.random.randint(num_squares, size=25))
            if x not in (player_at, GOAL)]

print('creating world')
sq_maze = make_sqworld(num_loc=num_squares)
steps = get_steps(sq_maze, exc=barriers)

# train
print('training...')
R = initR(steps, goal=GOAL)
Q = initQ(num_squares)
for _ in range(50):
    Q = updQ(Q, R, steps)
print('finished training.')

os.system('clear')
render_world(sq_maze, exc=barriers, player_loc=player_at, goal_loc=GOAL)
time.sleep(2)

# Random walls can cut the goal off entirely. In that case the Q row for the
# current square is all zeros and argmax would send the player to square 0
# forever, so stop when nothing has been learned or the move budget runs out.
moves_left = num_squares
while player_at != GOAL:
    if moves_left == 0 or Q[player_at].max() <= 0:
        print('no learned route from square %d to the goal; stopping.' % player_at)
        break
    player_at = int(np.argmax(Q[player_at]))
    moves_left -= 1
    os.system('clear')
    render_world(sq_maze, exc=barriers, player_loc=player_at, goal_loc=GOAL)
    time.sleep(1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment