Skip to content

Instantly share code, notes, and snippets.

@YOwatari
Last active August 29, 2015 14:19
Show Gist options
  • Save YOwatari/594b4803181fa635a5cf to your computer and use it in GitHub Desktop.
Save YOwatari/594b4803181fa635a5cf to your computer and use it in GitHub Desktop.
q学習による迷路探索を1時間で作ってみる
#!/usr/bin/env python
# coding: utf-8
import random
# learning parameters
ALPHA = 0.2 # learning rate
GAMMA = 0.9 # discount ratio
SIGMA = 0.2 # greedy ratio: probability of exploiting the learned policy
LEARNING_COUNT = 1000
class Field(object):
    """Grid maze for Q-learning.

    Each cell holds a reward: 0 = free, negative = penalty, positive = goal.
    ``now_point`` is the agent's current (row, col) position.
    """

    def __init__(self, start=(0, 0), field=None):
        """
        :param start: (row, col) starting position of the agent
        :param field: reward grid (list of list of int); defaults to a
            fresh 5x5 grid of zeros
        """
        # Build the default grid here, not in the signature, to avoid the
        # shared mutable-default-argument pitfall.
        if field is None:
            field = [[0 for _ in range(5)] for _ in range(5)]
        self.now_point = start
        self.field = field

    def __str__(self):
        """Render the grid, marking the agent's position with '@'.

        :return: str (one line per row, width-4 right-justified cells,
            trailing newline after the last row)
        """
        rows = []
        for i, row in enumerate(self.field):
            cells = []
            for j, value in enumerate(row):
                if (i, j) == self.now_point:
                    cells.append("@".rjust(4))
                else:
                    cells.append("{0:4d}".format(int(value)))
            rows.append("".join(cells))
        return "\n".join(rows) + "\n"

    def get_movable_points(self, point):
        """List the orthogonal neighbours of *point* that lie inside the grid.

        Order is fixed: left, right, up, down. Returns a list (not a lazy
        ``filter``) so callers can pass it to ``random.choice``.

        :param point: (row, col)
        :return: list of (row, col) tuples
        """
        row, col = point
        candidates = (
            (row, col - 1),
            (row, col + 1),
            (row - 1, col),
            (row + 1, col),
        )
        # Bound columns by the row's own length so non-square grids work
        # (the original compared both axes against the row count).
        return [(r, c) for r, c in candidates
                if 0 <= r < len(self.field) and 0 <= c < len(self.field[r])]
class QLeaning(object):
    """Tabular Q-learning agent over a Field maze.

    States and actions are both (row, col) tuples: an "action" is the cell
    moved into. The class name keeps the original spelling ("QLeaning")
    because callers reference it by that name.
    """

    def __init__(self, field):
        """
        :param field: Field instance providing get_movable_points / now_point
        """
        self.q_value = {}  # Q-table: {state: {action: float}}
        self.field = field

    def learn(self, greedy=False):
        """Run one episode until a cell with positive reward is reached.

        :param greedy: when True, always exploit the learned policy and
            print every step; otherwise follow the exploration policy.
        """
        while True:
            if greedy:
                action = self.chose_point_greedy(self.field.now_point)
                print(self.field)
                print("%s --> %s\n" % (self.field.now_point, action))
            else:
                action = self.chose_point(self.field.now_point)
            if self.update_Qvalue(self.field.now_point, action):
                break  # positive reward reached: the episode is over
            self.field.now_point = action

    def chose_point_greedy(self, point):
        """Pick uniformly among the actions with the maximal Q-value.

        Fix over the original: when a new maximum is found, previously
        collected (now inferior) actions are discarded instead of kept, so
        only true ties with the maximum survive.

        :param point: (row, col) current state
        :return: (row, col) chosen next state
        """
        best_points = []
        max_Qvalue = None
        for action in self.field.get_movable_points(point):
            q = self.get_Qvalue(point, action)
            if max_Qvalue is None or q > max_Qvalue:
                best_points = [action]  # reset on a strictly better value
                max_Qvalue = q
            elif q == max_Qvalue:
                best_points.append(action)
        return random.choice(best_points)

    def chose_point(self, point):
        """Exploration policy: exploit with probability SIGMA, otherwise
        pick a uniformly random reachable neighbour.

        :param point: (row, col) current state
        :return: (row, col) chosen next state
        """
        if SIGMA < random.random():
            # explore (probability 1 - SIGMA)
            return random.choice(list(self.field.get_movable_points(point)))
        else:
            # exploit (probability SIGMA)
            return self.chose_point_greedy(point)

    def update_Qvalue(self, point, action):
        """Apply one Q-learning update for the transition point -> action:

            Q(s,a) += ALPHA * (r + GAMMA * max_a' Q(s',a') - Q(s,a))

        :param point: (row, col) current state s
        :param action: (row, col) next state s'
        :return: bool -- True when the current cell's reward is positive,
            which the caller treats as "episode over"
        """
        Qsa = self.get_Qvalue(point, action)
        mQsa = max(self.get_Qvalue(action, next_action)
                   for next_action in self.field.get_movable_points(action))
        # NOTE(review): the reward is read from the *current* cell, not the
        # destination -- kept as in the original; confirm this is intended.
        rsa = self.field.field[point[0]][point[1]]
        self.q_value.setdefault(point, {})
        self.q_value[point][action] = Qsa + ALPHA * (rsa + GAMMA * mQsa - Qsa)
        return rsa > 0

    def get_Qvalue(self, point, action):
        """Look up Q(point, action), defaulting to 0.0 for unseen pairs.

        :param point: (row, col) state
        :param action: (row, col) next state
        :return: float
        """
        try:
            return self.q_value[point][action]
        except KeyError:
            return 0.0

    def dump(self):
        """Print every learned Q(s, a) entry, separating states with a rule."""
        states = list(self.q_value)
        for i, s in enumerate(states):
            for a in self.q_value[s]:
                print("Q(s, a): Q(%s, %s): %s" % (str(s), str(a), str(self.q_value[s][a])))
            if i != len(states) - 1:
                print('------------------------------------------')
if __name__ == '__main__':
    # Build the 5x5 maze: -10 cells are penalties, 100 is the goal.
    raw_field = [[0 for _ in range(5)] for _ in range(5)]
    for row, col in ((0, 3), (1, 1), (2, 1), (2, 3), (3, 3), (4, 1)):
        raw_field[row][col] = -10
    raw_field[4][4] = 100

    f = Field(field=raw_field)
    ql = QLeaning(f)
    print(ql.field)

    # Train with the exploring policy. Fix over the original: now_point is
    # reset to the start before each episode -- otherwise the agent stays
    # on the goal after episode 1 and the remaining episodes are no-ops.
    for _ in range(LEARNING_COUNT):
        f.now_point = (0, 0)
        ql.learn()
    ql.dump()

    # Show one purely greedy walk through the learned maze.
    f.now_point = (0, 0)
    ql.learn(True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment