q-learning example: tabular Q-learning on a 10 x 10 grid world, where the agent learns to reach the goal square (0, 0) for a reward of 100.
from random import random, choice
""" | |
+++++++++++++++++++++ | |
+ 10 + 1 + 1 + 1 + | |
+++++++++++++++++++++ | |
+ 1 + 1 + 1 + 1 + | |
+++++++++++++++++++++ | |
+ 1 + 1 + 1 + 1 + | |
+++++++++++++++++++++ | |
+ 1 + 1 + 1 + 1 + | |
+++++++++++++++++++++ | |
""" | |
WIDTH = 10
HEIGHT = 10
def init():
    """Build the Q table and a visit counter, one entry per legal (state, action)."""
    Q = {}
    visit = {}
    ACTIONS = ['DOWN', 'UP', 'RIGHT', 'LEFT']
    for i in range(HEIGHT):
        for j in range(WIDTH):
            state = (i, j)
            Q[state] = {}
            visit[state] = {}
            for action in ACTIONS:
                if not is_action_doable(state, action):
                    continue
                Q[state][action] = 0.0
                visit[state][action] = 1  # start at 1 so alpha = 1/visit is always defined
    return Q, visit
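# Example of the resulting structure: each square keeps only its legal moves,
# so the top-left corner comes back as
#     Q[(0, 0)] == {'DOWN': 0.0, 'RIGHT': 0.0}
# since UP and LEFT would step off the grid.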
def get_next_state(state, a):
    # x is the row index, y is the column index
    x, y = state
    if a == 'DOWN':
        new_state = (x + 1, y)
    elif a == 'UP':
        new_state = (x - 1, y)
    elif a == 'LEFT':
        new_state = (x, y - 1)
    else:  # a == 'RIGHT'
        new_state = (x, y + 1)
    return new_state
def is_action_doable(state, a):
    # an action is legal iff it keeps the agent on the grid
    x_new, y_new = get_next_state(state, a)
    return 0 <= x_new < HEIGHT and 0 <= y_new < WIDTH
def draw_current_best(Q, visit):
    """Render the greedy policy (one arrow per square) alongside the visit counts."""
    DRAWS = {'DOWN': ' v ', 'UP': ' ^ ', 'RIGHT': '-->', 'LEFT': '<--'}
    lines = '-' * 79 + '\n'
    policy = []
    for i in range(HEIGHT):
        for j in range(WIDTH):
            state = (i, j)
            max_act = None
            max_score = -1  # Q values are never negative here, so -1 is a safe floor
            for a in Q[state]:
                if Q[state][a] > max_score:
                    max_score = Q[state][a]
                    max_act = a
            lines += DRAWS[max_act] + ' '
            policy.append(max_act)
        lines += ' | '
        for j in range(WIDTH):
            state = (i, j)
            lines += "%6d " % sum(visit[state].values())
        lines += '\n'
    return policy, lines
def get_reward(state):
    # only the goal square (0, 0) pays out
    x, y = state
    return 100 if x == 0 and y == 0 else 0
def get_candidates(state, Q):
    """Score every legal action from `state`, best first."""
    actions = []
    for action in Q[state]:
        new_state = get_next_state(state, action)
        q_value = max(Q[new_state].values())  # value of the best follow-up action
        actions.append((q_value, get_reward(new_state), new_state, action))
    return sorted(actions, reverse=True)
def choose_random_action(actions, epsilon=0.1):
    if random() > epsilon:
        return actions[0]
    return choice(actions)
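# choose_random_action is epsilon-greedy exploration: with probability
# 1 - epsilon it exploits the best-ranked candidate, otherwise it picks any
# candidate uniformly (which can still land on the best one, so the effective
# exploration rate is slightly below epsilon). main() below calls it with
# epsilon=0.5, overriding the 0.1 default.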
def update(state, action, next_state, r, Q, visit):
    """
    One Q-learning backup, with the discount factor fixed at gamma = 1:

        Q(S, A) <- Q(S, A) + alpha * [R + max_a' Q(S', a') - Q(S, A)]
    """
    maxq = max(Q[next_state].values())
    alpha = 1.0 / visit[state][action]  # 1/N(s, a): step size decays per visit
    # alpha = 0.01  # alternative: a small constant step size
    Q[state][action] = Q[state][action] + alpha * (r + maxq - Q[state][action])
    return Q
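# A worked example of one backup, starting from the freshly initialized table:
# the agent sits at (0, 1) and moves LEFT into the goal (0, 0). Then
#     r = 100, max_a' Q((0, 0), a') = 0, alpha = 1 / visit[(0, 1)]['LEFT'] = 1,
# so Q[(0, 1)]['LEFT'] becomes 0 + 1 * (100 + 0 - 0) = 100.0. Later visits
# shrink alpha, so the estimate settles instead of oscillating.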
def main():
    steps = 0
    policy = []
    longest_non_change = 0
    non_change = 0
    Q, visit = init()
    while True:
        state = choice(list(Q.keys()))  # random starting point for each one-step update
        steps += 1
        actions = get_candidates(state, Q)
        q_value, r, next_state, action = choose_random_action(actions, epsilon=.5)
        Q = update(state, action, next_state, r, Q, visit)
        visit[state][action] += 1
        new_policy, lines = draw_current_best(Q, visit)
        if new_policy == policy:
            non_change += 1
            if non_change >= 10000:
                print("no change %d times in a row, total steps played: %d\n%s"
                      % (non_change, steps, lines))
                break
        else:
            longest_non_change = max(non_change, longest_non_change)
            non_change = 0
            policy = new_policy
        if steps % 10000 == 0:
            print("no change %d times in a row, total steps played: %d\n%s"
                  % (longest_non_change, steps, lines))

if __name__ == '__main__':
    main()
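For a quick sanity check, the pieces can also be exercised one step at a time. This is a minimal sketch assuming the gist is saved as q_learning.py (the filename is not given above):

from q_learning import init, get_candidates, update

Q, visit = init()
candidates = get_candidates((0, 1), Q)           # legal moves from the square beside the goal
q_value, r, next_state, action = candidates[0]   # best-ranked: LEFT into (0, 0), r == 100
Q = update((0, 1), action, next_state, r, Q, visit)
print(Q[(0, 1)][action])                         # 100.0 after the single backup (alpha == 1)

Running the full script instead prints the arrow map and visit counts every 10000 steps, and stops once the greedy policy has gone unchanged for 10000 consecutive updates.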