Skip to content

Instantly share code, notes, and snippets.

@greeness
Last active May 28, 2017 04:02
Show Gist options
  • Save greeness/3985998 to your computer and use it in GitHub Desktop.
Save greeness/3985998 to your computer and use it in GitHub Desktop.
Q-learning example
from random import random, choice, randint
# Grid world used by this example: a HEIGHT x WIDTH board of cells (i, j),
# with (0, 0) in the top-left corner. Entering (0, 0) pays a reward of 100
# (see get_reward); entering any other cell pays 0. Moves that would leave
# the board are never offered (see is_action_doable).
#
#   +-----+-----+-----+-- ... --+
#   | 100 |  0  |  0  |   ...   |
#   +-----+-----+-----+-- ... --+
#   |  0  |  0  |  0  |   ...   |
#   +-----+-----+-----+-- ... --+
#   |            ...            |
WIDTH = 10
HEIGHT = 10
def init():
    """Build fresh Q-value and visit-count tables for the whole grid.

    Every cell (row, col) with 0 <= row < HEIGHT and 0 <= col < WIDTH gets an
    entry in both tables. Each entry is keyed only by the actions that keep
    the agent on the board. Q-values start at 0.0. Visit counts start at 1,
    because update() uses 1 / visit as its learning rate.

    Returns:
        (Q, visit): two dicts mapping state -> {action: value}.
    """
    ACTIONS = ['DOWN', 'UP', 'RIGHT', 'LEFT']
    Q = {}
    visit = {}
    for row in range(HEIGHT):
        for col in range(WIDTH):
            cell = (row, col)
            # Keep ACTIONS order so both inner dicts iterate the same way.
            allowed = [act for act in ACTIONS if is_action_doable(cell, act)]
            Q[cell] = {act: 0. for act in allowed}
            visit[cell] = {act: 1 for act in allowed}
    return Q, visit
def get_next_state(state, a):
    """Return the cell one step away from *state* in direction *a*.

    'DOWN' adds 1 to the row index, 'UP' subtracts 1 from it, and 'LEFT'
    subtracts 1 from the column index. Any other value of *a* is treated as
    'RIGHT', which adds 1 to the column index. No bounds checking is done
    here.
    """
    row, col = state
    if a == 'DOWN':
        return (row + 1, col)
    if a == 'UP':
        return (row - 1, col)
    if a == 'LEFT':
        return (row, col - 1)
    # 'RIGHT', and the fallback for anything unrecognised
    return (row, col + 1)
def is_action_doable(state, a):
    """Return True when moving from *state* by action *a* stays on the board.

    The target cell must satisfy 0 <= row < HEIGHT and 0 <= col < WIDTH.
    """
    row, col = get_next_state(state, a)
    return 0 <= row < HEIGHT and 0 <= col < WIDTH
def draw_current_best(Q, visit):
    """Render the greedy policy and visit counts for every grid cell.

    For each cell, the action with the highest Q-value is chosen. Ties go to
    the first action in Q[cell]'s iteration order. Each output row shows that
    grid row's arrows, then ' | ', then the total visit count of each cell in
    the row.

    Args:
        Q: mapping state -> {action: q_value}.
        visit: mapping state -> {action: visit_count}.

    Returns:
        (policy, lines): policy is the list of greedy actions in row-major
        order. lines is the rendered text, starting with a line of 79 dashes.

    Raises:
        ValueError: if some cell has no actions in Q.
    """
    DRAWS = {'DOWN': ' v ', 'UP': ' ^ ', 'RIGHT': '-->', 'LEFT': '<--'}
    parts = ['-' * 79 + '\n']
    policy = []
    for i in range(HEIGHT):
        for j in range(WIDTH):
            q_row = Q[(i, j)]
            # max() returns the first maximal key, like a strict '>' scan.
            # Unlike a fixed starting score of -1, it also works when every
            # Q-value is <= -1 (the old code then looked up DRAWS[None]).
            max_act = max(q_row, key=q_row.get)
            parts.append(DRAWS[max_act] + ' ')
            policy.append(max_act)
        parts.append(' | ')
        for j in range(WIDTH):
            parts.append("%6d " % (sum(visit[(i, j)].values())))
        parts.append('\n')
    return policy, ''.join(parts)
def get_reward(state):
    """Return the immediate reward for entering *state*.

    The goal cell (0, 0) pays 100 and every other cell pays 0. The state is
    unpacked explicitly, because tuple parameters such as
    ``def f((x, y))`` are a SyntaxError on Python 3 (PEP 3113).
    """
    x, y = state
    if x == 0 and y == 0:
        return 100
    return 0
def get_candidates(state, Q):
    """Rank every action recorded in Q[state], best first.

    Each candidate is the tuple
    (best Q-value of the successor, reward for entering the successor,
     successor state, action).
    The tuples are sorted in descending order, so they are ranked by the
    successor's best Q-value, then by reward, then by successor state and
    action name as tie-breakers.

    NOTE(review): ranking uses max(Q[successor]) rather than
    Q[state][action]. This is deliberate in the original code, but it is not
    the textbook greedy choice.
    """
    def _describe(act):
        successor = get_next_state(state, act)
        return (max(Q[successor].values()), get_reward(successor), successor, act)

    ranked = [_describe(act) for act in Q[state]]
    ranked.sort(reverse=True)
    return ranked
def choose_random_action(actions, epsilon=0.1):
    """Make an epsilon-greedy pick from a best-first list of candidates.

    One uniform draw is taken. If it exceeds *epsilon*, the top-ranked entry
    actions[0] is returned. Otherwise an entry is chosen uniformly from the
    whole list, which may also be the first one.
    """
    exploit = random() > epsilon
    return actions[0] if exploit else choice(actions)
def update(state, action, next_state, r, Q, visit, gamma=1.0, alpha=None):
    """Apply one tabular Q-learning update to Q[state][action] in place.

        Q(S, A) <- Q(S, A) + alpha * [R + gamma * max_a' Q(S', a') - Q(S, A)]

    Args:
        state: the state of the state-action pair being updated.
        action: the action of the state-action pair being updated.
        next_state: the state reached by taking *action*. Its best Q-value
            is the bootstrap target.
        r: reward received for the transition.
        Q: mapping state -> {action: q_value}. Modified in place.
        visit: mapping state -> {action: count}. Only read, to compute the
            default learning rate.
        gamma: discount factor. The default 1.0 reproduces the original
            undiscounted update. With gamma == 1 and a goal that can be
            re-entered, Q-values can keep growing; gamma < 1 is the usual
            choice.
        alpha: fixed learning rate. When None (the default), it is
            1 / visit[state][action], so the step size shrinks as that
            count grows.

    Returns:
        Q, the same dict that was passed in.
    """
    maxq = max(Q[next_state].values())
    if alpha is None:
        alpha = 1. / visit[state][action]
    Q[state][action] = Q[state][action] + alpha * (r + gamma * maxq - Q[state][action])
    return Q
def main():
    """Train a tabular Q-learning agent on the grid until its policy settles.

    Each step does the following:
    - picks a uniformly random starting cell;
    - chooses an action epsilon-greedily (epsilon=0.5) from the ranked
      candidates;
    - applies one Q-learning update and bumps that pair's visit count;
    - re-renders the greedy policy.
    Training stops once the policy has been unchanged for 10000 consecutive
    steps. A progress report is printed every 10000 steps.
    """
    report = "no change %d times in a row, total steps played: %d\n%s"
    steps = 0
    policy = []
    longest_non_change = 0
    non_change = 0
    Q, visit = init()
    # update() only writes Q[state][action], so the set of states is fixed.
    # Build an indexable list once. On Python 3, dict.keys() is a view that
    # random.choice() cannot index.
    states = list(Q)
    while True:
        state = choice(states)  # random starting point
        steps += 1
        actions = get_candidates(state, Q)
        _, r, next_state, action = choose_random_action(actions, epsilon=.5)
        Q = update(state, action, next_state, r, Q, visit)
        visit[state][action] += 1
        new_policy, lines = draw_current_best(Q, visit)
        if new_policy == policy:
            non_change += 1
            if non_change >= 10000:
                # Single parenthesised argument: same output on Python 2 and 3.
                print(report % (non_change, steps, lines))
                break
        else:
            longest_non_change = max(non_change, longest_non_change)
            non_change = 0
            policy = new_policy
        if steps % 10000 == 0:
            print(report % (longest_non_change, steps, lines))


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment