Skip to content

Instantly share code, notes, and snippets.

@limitpointinf0
Last active July 8, 2018 07:30
Show Gist options
  • Save limitpointinf0/4a1b4585f6d540e9f36a8728bbaf4f18 to your computer and use it in GitHub Desktop.
Save limitpointinf0/4a1b4585f6d540e9f36a8728bbaf4f18 to your computer and use it in GitHub Desktop.
Maze solving with Q-learning
import numpy as np
import itertools
import math
import time
import os
def make_sqworld(num_loc=None):
    """Build a square grid world whose cells are numbered 0..num_loc-1.

    Cells are numbered row by row, so cell ``k`` sits at row ``k // dim`` and
    column ``k % dim``, where ``dim`` is the side length of the grid.

    Args:
        num_loc: total number of cells; must be a perfect square.

    Returns:
        A ``(dim, dim)`` numpy array holding the cell ids.

    Raises:
        TypeError: if ``num_loc`` is omitted (left as ``None``).
        ValueError: if ``num_loc`` is negative or not a perfect square.
    """
    if num_loc is None:
        raise TypeError("make_sqworld() requires num_loc (a perfect square)")
    if num_loc < 0:
        raise ValueError("num_loc must be non-negative, got %r" % (num_loc,))
    # round() guards against sqrt landing a hair below the exact integer root.
    dim = int(round(math.sqrt(num_loc)))
    if dim * dim != num_loc:
        raise ValueError("num_loc must be a perfect square, got %r" % (num_loc,))
    return np.arange(num_loc).reshape((dim, dim))
def get_steps(mat, exc=None):
    """List every legal single move between orthogonally adjacent cells.

    A move is a ``(source_id, destination_id)`` pair. Moves are generated for
    each cell in row-major order, trying left, up, right, then down. A move is
    kept only when both of its endpoints are inside the grid and neither is
    listed in ``exc``.

    Args:
        mat: 2-D array of cell ids (as produced by ``make_sqworld``).
        exc: iterable of cell ids that cannot be entered or left (barriers).
            ``None`` (the default) means no cell is excluded.

    Returns:
        A list of ``(source_id, destination_id)`` tuples taken from ``mat``.
    """
    blocked = frozenset(exc) if exc is not None else frozenset()
    n_rows = mat.shape[0]
    n_cols = mat.shape[1]
    steps = []
    for i in range(n_rows):
        for j in range(n_cols):
            here = mat[i][j]
            if here in blocked:
                continue
            # Neighbour order (left, up, right, down) fixes the output order.
            for ni, nj in ((i, j - 1), (i - 1, j), (i, j + 1), (i + 1, j)):
                # Explicit bounds check: negative indices must not wrap around.
                if 0 <= ni < n_rows and 0 <= nj < n_cols:
                    there = mat[ni][nj]
                    if there not in blocked:
                        steps.append((here, there))
    return steps
#steps = [(0,1),(1,0),(1,2),(1,4),(2,1),(2,5),(4,1),(4,5),(4,7),(5,4),(5,2),(5,8),(7,8),(7,4),(7,6),(8,7),(8,5),(6,7),(6,6)]
def initQ(dim):
    """Return a fresh ``dim`` x ``dim`` Q-table with every entry set to zero."""
    shape = (dim, dim)
    q_table = np.zeros(shape)
    return q_table
def initR(steps, goal):
    """Build the reward matrix for a set of legal moves.

    Entry ``[src][dst]`` is 100 when ``(src, dst)`` is a legal move that lands
    on ``goal``, 0 for any other legal move, and -1 where no move exists.

    Args:
        steps: non-empty sequence of ``(source_id, destination_id)`` pairs.
        goal: id of the goal cell.

    Returns:
        A square float matrix with one row/column per cell id up to the
        largest id that appears in ``steps``.

    Raises:
        ValueError: if ``steps`` is empty, since the matrix cannot be sized.
    """
    if len(steps) == 0:
        raise ValueError("initR() needs at least one step to size the reward matrix")
    # Size from both endpoints so a source id larger than every destination
    # cannot index past the end of the matrix.
    size = max(max(src, dst) for src, dst in steps) + 1
    rmat = np.full((size, size), -1.0)
    for src, dst in steps:
        rmat[src][dst] = 100 if dst == goal else 0
    return rmat
def updQ(qmat, rmat, steps, gamma=0.8):
    """Run one sweep of Q-value updates over every legal move.

    For each legal move ``(s, d)`` the new value is
    ``rmat[s][d] + gamma * max(Q[d][k])`` taken over the legal moves ``(d, k)``
    leaving ``d``. Moves are processed in ascending ``(s, d)`` order and the
    table is updated in place during the sweep, so later moves already see
    values refreshed earlier in the same sweep. Entries that are not legal
    moves are set to 0. The input ``qmat`` is not modified.

    Args:
        qmat: current square Q-table (numpy array).
        rmat: reward matrix indexed ``[source][destination]``.
        steps: iterable of ``(source_id, destination_id)`` legal moves.
        gamma: discount factor applied to the best follow-up value.

    Returns:
        A new array of the same shape as ``qmat`` holding the updated values.
    """
    n_rows = qmat.shape[0]
    n_cols = qmat.shape[1]
    # De-duplicated, in-range moves in row-major order (the sweep order).
    moves = sorted({(int(s), int(d)) for s, d in steps
                    if 0 <= s < n_rows and 0 <= d < n_cols})

    # successors[s] lists every destination reachable from s in one move.
    successors = {}
    for s, d in moves:
        successors.setdefault(s, []).append(d)

    # Start from zeros so non-move entries end up 0, then carry over the
    # previous values of the move entries, which the sweep reads.
    new_q = np.zeros_like(qmat)
    for s, d in moves:
        new_q[s, d] = qmat[s, d]

    for s, d in moves:
        onward = successors.get(d)
        # A destination with no way out has no future reward to discount.
        best_next = max(new_q[d, k] for k in onward) if onward else 0
        new_q[s, d] = rmat[s][d] + gamma * best_next
    return new_q
def render_world(world, exc=None, player_loc=0, goal_loc=24):
    """Print an ASCII picture of the grid world to standard output.

    The grid is framed by a border of ``o`` characters. Inside the frame each
    cell is three characters wide: ``o`` for a barrier, ``x`` for the player,
    ``$`` for the goal and blank otherwise. When a cell matches more than one
    category, barrier takes precedence over player, and player over goal.

    Args:
        world: 2-D array of cell ids.
        exc: container of barrier cell ids; ``None`` (the default) means none.
        player_loc: id of the cell the player currently occupies.
        goal_loc: id of the goal cell.
    """
    blocked = () if exc is None else exc
    n_rows = world.shape[0]
    n_cols = world.shape[1]
    border = " o " * (n_cols + 2)

    rows = [border]
    for i in range(n_rows):
        cells = []
        for j in range(n_cols):
            cell = world[i][j]
            if cell in blocked:
                cells.append(' o ')
            elif cell == player_loc:
                cells.append(' x ')
            elif cell == goal_loc:
                cells.append(' $ ')
            else:
                # Three blanks keep empty cells as wide as every other cell,
                # so the rows line up with the border.
                cells.append('   ')
        rows.append(' o ' + ''.join(cells) + ' o ')
    rows.append(border)

    # print(...) with a single argument behaves the same on Python 2 and 3.
    print('\n'.join(rows))
# --- Demo script: build a random 10x10 maze, train, then walk to the goal ---
player_at = 0
GOAL = 99
num_squares = 100
# Draw 25 random cell ids; duplicates collapse in the set, and the start and
# goal cells are never turned into barriers.
barriers = [x for x in set(np.random.randint(num_squares, size=25))
            if x not in (player_at, GOAL)]

print('creating world')
sq_maze = make_sqworld(num_loc=num_squares)
steps = get_steps(sq_maze, exc=barriers)

# train
print('training...')
R = initR(steps, goal=GOAL)
Q = initQ(num_squares)
for _ in range(50):
    Q = updQ(Q, R, steps)
print('finished training.')

# 'clear' only exists on POSIX shells; Windows uses 'cls'.
clear_cmd = 'cls' if os.name == 'nt' else 'clear'

os.system(clear_cmd)
render_world(sq_maze, exc=barriers, player_loc=player_at, goal_loc=GOAL)
time.sleep(2)

while player_at != GOAL:
    # An all-zero row means no move from here earns any reward, i.e. the
    # random barriers have cut this cell off from the goal. argmax would then
    # return cell 0 on every pass and the loop would never end.
    if not Q[player_at].any():
        print('no path to the goal from here; giving up.')
        break
    # Greedy policy: take the move with the highest learned value.
    player_at = np.argmax(Q[player_at])
    os.system(clear_cmd)
    render_world(sq_maze, exc=barriers, player_loc=player_at, goal_loc=GOAL)
    time.sleep(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment