Last active
August 29, 2015 14:19
-
-
Save YOwatari/594b4803181fa635a5cf to your computer and use it in GitHub Desktop.
q学習による迷路探索を1時間で作ってみる
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# coding: utf-8 | |
import random | |
# Learning conditions (hyper-parameters of the Q-learning run)
ALPHA = 0.2 # learning rate: step size of each Q-value update in update_Qvalue
GAMMA = 0.9 # discount ratio applied to the best Q value of the next state
SIGMA = 0.2 # greedy ratio: chose_point moves randomly when random.random() > SIGMA, greedily otherwise
LEARNING_COUNT = 1000 | |
class Field(object):
    """Grid world that stores per-cell rewards and the agent's position.

    ``field`` is a list of rows, each row a list of numeric cell values.
    ``now_point`` is the agent's current cell as a ``(row, column)`` tuple.
    """

    def __init__(self, start=(0, 0), field=None):
        """
        :param start: tuple of int, initial (row, column) of the agent
        :param field: list of list of int, the grid of cell values; a fresh
            5x5 grid of zeros is created when omitted
        """
        if field is None:
            # Build the default per instance: a list used as a default
            # argument would be created once and shared by every Field.
            field = [[0] * 5 for _ in range(5)]
        self.now_point = start
        self.field = field

    def __str__(self):
        """
        Render the grid one row per line, each cell right-aligned in four
        columns, with ``@`` drawn at the agent's position.

        :return: str
        """
        lines = []
        for i, row in enumerate(self.field):
            cells = []
            for j, value in enumerate(row):
                if (i, j) == self.now_point:
                    cells.append("@".rjust(4))
                else:
                    cells.append("{0:4d}".format(int(value)))
            lines.append("".join(cells) + "\n")
        return "".join(lines)

    def _contains(self, row, col):
        """Return True when (row, col) addresses an existing cell."""
        # The column bound comes from the addressed row itself so that
        # non-square grids are handled correctly.
        return 0 <= row < len(self.field) and 0 <= col < len(self.field[row])

    def get_movable_points(self, point):
        """
        List the cells reachable from ``point`` in one step.

        Candidates are examined in the order left, right, up, down and the
        ones lying outside the grid are dropped.

        :param point: tuple of int, (row, column)
        :return: tuple of (row, column) tuples
        """
        row, col = point[0], point[1]
        candidates = (
            (row, col - 1),
            (row, col + 1),
            (row - 1, col),
            (row + 1, col),
        )
        # Return a real tuple (not a lazy filter object) so callers can
        # index it or hand it to random.choice on Python 2 and 3 alike.
        return tuple(p for p in candidates if self._contains(p[0], p[1]))
class QLeaning(object):
    """Tabular Q-learning agent that walks over a field object.

    Q values live in a nested dict ``q_value[state][action]``.  Both the
    state and the action are (row, column) tuples; the action is the
    neighbouring cell the agent moves to.
    """

    def __init__(self, field):
        """
        :param field: object exposing ``now_point``, ``field`` (grid of cell
            values) and ``get_movable_points(point)``
        """
        self.q_value = {}
        self.field = field
        # Remember where the agent begins so that every episode can be
        # rewound to the same starting cell.
        self.start_point = field.now_point

    def learn(self, greedy=False):
        """
        Run one episode: keep moving and updating Q values until
        ``update_Qvalue`` reports that the episode is over, then put the
        agent back on the start point.

        :param greedy: bool, when True always take the greedy action and
            print the field and each move; otherwise use ``chose_point``
        """
        while True:
            point = self.field.now_point
            if greedy:
                # greedy
                action = self.chose_point_greedy(point)
                print(self.field)
                print("%s --> %s\n" % (point, action))
            else:
                # e-greedy
                action = self.chose_point(point)
            if self.update_Qvalue(point, action):
                break
            self.field.now_point = action
        # Without this rewind the agent would stay on the terminal cell and
        # every later episode would stop after a single update.
        self.field.now_point = self.start_point

    def chose_point_greedy(self, point):
        """
        Pick the action with the highest Q value from ``point``; ties are
        broken uniformly at random.

        :param point: tuple of int
        :return: tuple of int
        """
        scored = [(self.get_Qvalue(point, action), action)
                  for action in self.field.get_movable_points(point)]
        best_q = max(q for q, _ in scored)
        # Only actions that actually reach the maximum are candidates.
        best_points = [action for q, action in scored if q == best_q]
        return random.choice(best_points)

    def chose_point(self, point):
        """
        Pick the next action: a uniformly random move when
        ``random.random()`` exceeds SIGMA, the greedy move otherwise.

        :param point: tuple of int
        :return: tuple of int
        """
        if SIGMA < random.random():
            # list() keeps random.choice working even if the field hands
            # back a lazy iterable instead of a sequence.
            return random.choice(list(self.field.get_movable_points(point)))
        return self.chose_point_greedy(point)

    def update_Qvalue(self, point, action):
        """
        Apply the Q-learning update to Q(point, action).

        The reward is the value of the cell being left (``point``), not of
        the destination cell.

        :param point: tuple of int, current state
        :param action: tuple of int, cell to move to
        :return: bool, True when the reward is positive (episode finished)
        """
        Qsa = self.get_Qvalue(point, action)
        mQsa = max(self.get_Qvalue(action, n_action)
                   for n_action in self.field.get_movable_points(action))
        rsa = self.field.field[point[0]][point[1]]
        self.q_value.setdefault(point, {})[action] = (
            Qsa + ALPHA * (rsa + GAMMA * mQsa - Qsa)
        )
        return rsa > 0

    def get_Qvalue(self, point, action):
        """
        Look up Q(point, action); pairs never updated count as 0.0.

        :param point: tuple of int
        :param action: tuple of int
        :return: float
        """
        return self.q_value.get(point, {}).get(action, 0.0)

    def dump(self):
        """
        Print every stored Q value, with a separator line between states.
        """
        last = len(self.q_value) - 1
        for i, state in enumerate(self.q_value):
            for action, q in self.q_value[state].items():
                print("Q(s, a): Q(%s, %s): %s" % (str(state), str(action), str(q)))
            if i != last:
                print('------------------------------------------')
if __name__ == '__main__':
    # 5x5 grid of zeros with a handful of -10 penalty cells and a single
    # +100 cell in the bottom-right corner.
    raw_field = [[0] * 5 for _ in range(5)]
    for row, col in ((0, 3), (1, 1), (2, 1), (2, 3), (3, 3), (4, 1)):
        raw_field[row][col] = -10
    raw_field[4][4] = 100

    f = Field(field=raw_field)
    ql = QLeaning(f)
    print(ql.field)

    # Training runs, then a dump of the learned table and one greedy run
    # that prints every step.
    for _ in range(LEARNING_COUNT):
        ql.learn()
    ql.dump()
    ql.learn(True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment