@taotao54321
Created November 8, 2016 08:59
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Q-learning
# epsilon-greedy exploration
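# Example invocation (hypothetical parameter values, chosen only for illustration;
# the script name follows the usage() string below). Arguments are the map size,
# learning rate alpha, discount gamma, exploration rate epsilon, and the number
# of learning and test episodes:
#   ./FrozenLake-qlearning.py 4x4 0.5 0.99 0.1 10000 1000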
import sys
import random
import os.path
import pprint
import numpy as np
import gym
DEBUG = False
#DEBUG = True
ENVS = {
    "4x4": "FrozenLake-v0",
    "8x8": "FrozenLake8x8-v0",
}

def argmax_multi(x):
    maxval = max(x)
    return tuple(idx for idx, val in enumerate(x) if val == maxval)

def error(msg):
    sys.exit(msg)

class Agent:
    def __init__(self, env, epsilon):
        self.env = env
        self.epsilon = epsilon
        # Q-table: one row per state, one entry per action (FrozenLake has 4 actions)
        self.q = [[0.0, 0.0, 0.0, 0.0] for _ in range(env.observation_space.n)]

    def learn(self, alpha, gamma):
        """Learn from one episode and return its reward."""
        state = self.env.reset()
        if DEBUG: self.env.render()
        for t in range(self.env.spec.timestep_limit):
            # epsilon-greedy action selection based on the current Q function
            act = self._e_greedy(state)
            state_next, reward, done, info = self.env.step(act)
            # q is initialized to 0, so if state_next is a terminal state,
            # q_next_max is 0
            q_next_max = max(self.q[state_next])
            # Q-learning update: Q(s,a) <- (1-alpha)*Q(s,a) + alpha*(r + gamma*max_a' Q(s',a'))
            self.q[state][act] = (1-alpha) * self.q[state][act] \
                                 + alpha * (reward + gamma*q_next_max)
            if DEBUG:
                self.env.render()
                print(state_next, reward, done, info)
                pprint.pprint(self.q)
            if done:
                return reward
            else:
                state = state_next
        # Step limit exceeded
        return 0.0

    def _e_greedy(self, state):
        if random.random() < self.epsilon:
            return self.env.action_space.sample()
        else:
            # Ties are possible (early in learning all values are 0, so they occur).
            # In that case, pick randomly among the tied actions.
            acts = argmax_multi(self.q[state])
            return random.choice(acts)

    def test(self):
        """Run one episode using the learned Q function and return its reward."""
        state = self.env.reset()
        if DEBUG: self.env.render()
        for t in range(self.env.spec.timestep_limit):
            act = np.argmax(self.q[state])
            state, reward, done, info = self.env.step(act)
            if DEBUG:
                self.env.render()
                print(state, reward, done, info)
            if done:
                return reward
        # Step limit exceeded
        return 0.0

def usage():
    error("Usage: FrozenLake-qlearning <4x4|8x8> <alpha> <gamma> <epsilon> <learn_count> <test_count> [recdir]")

def main():
    if len(sys.argv) < 7: usage()
    env_name = ENVS[sys.argv[1]]
    alpha = float(sys.argv[2])
    gamma = float(sys.argv[3])
    epsilon = float(sys.argv[4])
    learn_count = int(sys.argv[5])
    test_count = int(sys.argv[6])
    rec_dir = sys.argv[7] if len(sys.argv) >= 8 else None

    print("# <{}> alpha={}, gamma={}, epsilon={}, learn_count={} test_count={}".format(
        env_name, alpha, gamma, epsilon, learn_count, test_count))

    env = gym.make(env_name)
    print("# step-max: {}".format(env.spec.timestep_limit))
    if rec_dir:
        subdir = "FrozenLake{}-qlearning2-alpha{}-gamma{}-eps{}-learn{}-test{}".format(
            sys.argv[1], alpha, gamma, epsilon, learn_count, test_count
        )
        env.monitor.start(os.path.join(rec_dir, subdir))

    agent = Agent(env, epsilon)

    print("##### LEARNING #####")
    reward_total = 0.0
    for episode in range(learn_count):
        reward_total += agent.learn(alpha, gamma)
    pprint.pprint(agent.q)
    print("episodes: {}".format(learn_count))
    print("total reward: {}".format(reward_total))
    print("average reward: {:.2f}".format(reward_total / learn_count))

    print("##### TEST #####")
    reward_total = 0.0
    for episode in range(test_count):
        reward_total += agent.test()
    print("episodes: {}".format(test_count))
    print("total reward: {}".format(reward_total))
    print("average reward: {:.2f}".format(reward_total / test_count))

    if rec_dir: env.monitor.close()

if __name__ == "__main__": main()