Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
# coding:utf-8
import sys
import random
# Map of the grid world, one string per row:
#   '#' wall, '.' open floor, 's' start, 'g' goal, 'j' jump (teleport) cell.
fieldstr = """
##########
#s.....###
##.##...##
##.####..#
#....##.##
#g####j.##
##########
"""
# (row, column) that observe() moves the agent to after it steps onto 'j'.
# NOTE(review): field[5][3] is '#', so the agent is teleported *into* a wall
# cell and can only leave it by moving up to (4, 3). Confirm whether an open
# cell such as (4, 3) was the intended target.
jump_to = (5,3)
# Action table: action i changes (row, column) by (xmove[i], ymove[i]),
# i.e. 0 = right, 1 = left, 2 = down, 3 = up on the printed map.
xmove = [0,0,1,-1]
ymove = [1,-1,0,0]
epsilon = 0.3          # upper bound on the exploration probability
learning_rate = 0.1    # step size (alpha) of the Q-learning update
discount_rate = 0.99   # discount factor (gamma) applied to the next state's value
# field[row][column] -> single-character cell; the blank lines around the map are dropped.
field = [list(line) for line in fieldstr.strip('\n').split('\n')]
# Single-argument print(...) prints the same on Python 2 and Python 3.
print(field)
# One Q-value per action for every cell, all starting at 0. Sized per row so
# Qvalues always has exactly the same shape as field.
Qvalues = [[[0] * len(xmove) for _ in row] for row in field]
def argmax(sequence):
    """Return the index of the largest element in *sequence*.

    Ties resolve to the earliest index, because a later element only takes
    over when it is strictly greater. An empty *sequence* raises IndexError.
    """
    best = (0, sequence[0])  # (index, value) of the best element seen so far
    for candidate in enumerate(sequence):
        if candidate[1] > best[1]:
            best = candidate
    return best[0]
def epsilon_greedy(state, epoch = 10):
    """Choose an action index for *state* using a decaying epsilon-greedy rule.

    The exploration probability is ``min(epsilon, 20 / epoch)``. It stays at
    ``epsilon`` for early epochs and shrinks once ``epoch`` exceeds
    ``20 / epsilon``. With that probability a uniformly random action in
    0..3 (inclusive) is returned. Otherwise the action with the highest
    Q-value for *state* is returned, with ties going to the lowest index.
    """
    explore = random.random() < min(epsilon, 1.0/epoch*20)
    if explore:
        return random.randint(0, 3)
    q_row = Qvalues[state[0]][state[1]]
    return argmax(q_row)
def observe(state, action):
    """Apply *action* in *state* and return ``(next_state, reward)``.

    Type: (tuple<int, int>, int) -> (tuple<int, int>, int)

    The target cell is ``(state[0] + xmove[action], state[1] + ymove[action])``:
      * '#' (wall): the agent stays in *state* and receives -5;
      * 'g' (goal): the agent moves to the target and receives 10;
      * 'j' (jump): the agent is sent to ``jump_to`` with reward 0;
      * anything else: the agent moves to the target with reward 0.
    There is no bounds check; this appears to rely on the map being fully
    enclosed by '#' cells.
    """
    target = (state[0] + xmove[action], state[1] + ymove[action])
    cell = field[target[0]][target[1]]
    if cell == '#':
        return (state, -5)
    elif cell == 'g':
        return (target, 10)
    elif cell == 'j':
        return (jump_to, 0)
    else:
        return (target, 0)
def print_qvalues():
    """Print the field with every '.' cell replaced by its greedy-action arrow.

    For each '.' cell the arrow is ``arrows[argmax(q)]`` for that cell's
    Q-value list. Action 0/1/2/3 moves column +1, column -1, row +1 and row -1
    (per ``xmove``/``ymove``), so they are drawn as → ← ↓ ↑. Cells whose
    Q-values are all equal (e.g. never updated) show '→', because ``argmax``
    resolves ties to index 0. All other cells ('#', 's', 'g', 'j') are
    printed unchanged.
    """
    arrows = u'→←↓↑'
    lines = []
    for x, row in enumerate(Qvalues):
        cells = []
        for y, q in enumerate(row):
            symbol = field[x][y]
            cells.append(arrows[argmax(q)] if symbol == '.' else symbol)
        lines.append(u''.join(cells) + u'\n')
    # join() instead of repeated += avoids quadratic string building.
    # A single-argument print(...) works identically on Python 2 and 3
    # (the original `print printstr` is a SyntaxError on Python 3).
    print(u''.join(lines))
def print_agent(state, grid=None):
    """Print the map with the agent's cell drawn as 'a'.

    Args:
        state: (row, column) tuple of the agent's position. It is compared
            with ``==`` against (row, column) tuples, so pass a tuple.
        grid: rows of single-character cells to draw. Defaults to the
            module-level ``field``, which is the original behavior.
    """
    if grid is None:
        grid = field
    lines = []
    for x, row in enumerate(grid):
        cells = ['a' if (x, y) == state else place for y, place in enumerate(row)]
        lines.append(''.join(cells) + '\n')
    # Single-argument print(...) works identically on Python 2 and 3.
    print(''.join(lines))
def _find_cell(symbol):
    """Return the (row, column) of the first cell in ``field`` equal to *symbol*.

    Raises:
        ValueError: if *symbol* does not occur anywhere in ``field``.
    """
    for x, row in enumerate(field):
        for y, cell in enumerate(row):
            if cell == symbol:
                return (x, y)
    raise ValueError("symbol %r not found in field" % symbol)


def main():
    """Train the tabular Q-learning agent and print the learned greedy policy.

    Runs epochs 1..99 (``range(1, 100)``). Each epoch starts the agent on the
    map's 's' cell and steps it with ``epsilon_greedy``/``observe`` until it
    reaches the 'g' cell. The Q-learning update is applied after every step.
    At the end the greedy policy is printed with ``print_qvalues``.
    """
    # Derived from the map rather than hard-coded as (1, 1) / (5, 1). If the map
    # is edited, a stale hard-coded goal may be unreachable and the while-loop
    # below would never terminate.
    start = _find_cell('s')
    goal = _find_cell('g')
    for epoch in range(1, 100):
        # Same text as Python 2's `print "Epoch:", epoch`, valid on Python 2 and 3.
        print("Epoch: %d" % epoch)
        state = start
        while state != goal:
            action = epsilon_greedy(state, epoch)
            next_state, reward = observe(state, action)
            # Q(s, a) <- (1 - alpha) * Q(s, a) + alpha * (r + gamma * max_a' Q(s', a'))
            # The goal's Q-values are never updated (the loop stops there), so
            # they stay 0 and act as a terminal value.
            q_row = Qvalues[state[0]][state[1]]
            q_row[action] = (1-learning_rate) * q_row[action] \
                + learning_rate * (reward + discount_rate * max(Qvalues[next_state[0]][next_state[1]]))
            state = next_state
    print_qvalues()


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment