cartpole + Q-learning
#!/usr/bin/env python3
#
# http://neuro-educator.com/rl1/
#
import gym  # execution environment for the inverted pendulum (cartpole)
import numpy as np
import time

# Cart position          -2.4 .. 2.4
# Cart velocity          -3.0 .. 3.0
# Pole angle            -41.8 .. 41.8
# Pole angular velocity  -2.0 .. 2.0
#
# Each of the state variables above is split into 6 buckets.
# The state is digitized into 6 buckets ^ (4 variables) entries,
# over which the Q-function (a table) is built.
num_dizitized = 6


# [1] Functions that discretize and define the Q-function ------------
# Digitize an observed (continuous) state into discrete values.
def bins(clip_min, clip_max, num):
    # Given how np.digitize() works, build the bins from 7 fence posts
    # with the leftmost and rightmost ones removed.
    # e.g. for [-2.4, 2.4]:
    #   >>> np.linspace(-2.4, 2.4, 6 + 1)
    #   array([-2.4, -1.6, -0.8,  0. ,  0.8,  1.6,  2.4])
    # The "6 + 1" is easiest to remember as "6 buckets need 7 fence posts".
    #
    # Note that np.digitize()'s documentation says:
    #   > Each index i returned is such that bins[i-1] <= x < bins[i]
    #   > If values in x are beyond the bounds of bins, 0 or len(bins) is
    #   > returned.
    # We want 0 for -2.4..-1.6 and 5 for 1.6..2.4, hence the [1:-1]
    # slice that drops the left and right edges.
    return np.linspace(clip_min, clip_max, num + 1)[1:-1]
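
# A sketch of the mapping this induces (values assume num_dizitized == 6):
#   np.digitize(-2.0, bins=bins(-2.4, 2.4, 6))  # -> 0 (below -1.6)
#   np.digitize( 0.1, bins=bins(-2.4, 2.4, 6))  # -> 3 (between 0.0 and 0.8)
#   np.digitize( 2.0, bins=bins(-2.4, 2.4, 6))  # -> 5 (1.6 and above)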

def digitize_state(observation):
    # Convert each observed value into a discrete one so it fits into
    # the Q-table.
    cart_pos, cart_v, pole_angle, pole_v = observation
    digitized = [
        np.digitize(cart_pos, bins=bins(-2.4, 2.4, num_dizitized)),
        np.digitize(cart_v, bins=bins(-3.0, 3.0, num_dizitized)),
        np.digitize(pole_angle, bins=bins(-0.5, 0.5, num_dizitized)),
        np.digitize(pole_v, bins=bins(-2.0, 2.0, num_dizitized))
    ]
    return sum([x * (num_dizitized**i) for i, x in enumerate(digitized)])
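
# digitize_state() packs the four bucket indices into a single base-6 number,
# giving each of the 6**4 = 1296 states a unique row index. For example
# (hypothetical bucket values), digitized == [1, 2, 3, 4] yields
# 1*6**0 + 2*6**1 + 3*6**2 + 4*6**3 = 985.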

# [2] Function that chooses action a(t) -------------------------------------
def get_action(next_state, episode, q_table):
    # ε-greedy: gradually shift toward taking only the optimal action.
    epsilon = 0.5 * (1 / (episode + 1))
    if epsilon <= np.random.uniform(0, 1):
        next_action = np.argmax(q_table[next_state])
    else:
        next_action = np.random.choice([0, 1])
    return next_action
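
# With this schedule epsilon decays as 0.5, 0.25, 0.167, ... (0.5 / (episode
# + 1)), so random exploration dominates only the first few episodes and is
# down to 0.01 by episode 49.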

# [3] Function that updates the Q-table -------------------------------------
def update_qtable(q_table, state, action, reward, next_state):
    gamma = 0.99
    alpha = 0.5
    next_max_q = max(q_table[next_state][0], q_table[next_state][1])
    q_table[state, action] = ((1 - alpha) * q_table[state, action]
                              + alpha * (reward + gamma * next_max_q))
    return q_table
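
# This is the standard Q-learning update,
#   Q(s, a) <- (1 - alpha) * Q(s, a) + alpha * (r + gamma * max_a' Q(s', a')),
# with learning rate alpha = 0.5 and discount factor gamma = 0.99.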

def main():
    env = gym.make('CartPole-v0')
    max_number_of_steps = 200  # number of steps per trial
    num_consecutive_iterations = 100  # trials averaged to judge completion
    num_episodes = 2000  # total number of trials
    goal_average_reward = 195  # learning ends above this reward
    # Create the Q-table.
    # env.action_space.n is the number of available actions; 2 in CartPole-v0.
    # q_table is a (6**4) x 2 two-dimensional array.
    q_table = np.random.uniform(
        low=-1, high=1, size=(num_dizitized**4, env.action_space.n))
    print(q_table)
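    # The table thus has 6**4 = 1296 rows (one per discretized state) and
    # 2 columns (one per action), initialized uniformly at random in [-1, 1).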
    # Stores the reward of each trial.
    total_reward_vec = np.zeros(num_consecutive_iterations)
    # After learning, stores each trial's x position at t=200.
    final_x = np.zeros((num_episodes, 1))
    islearned = 0  # flag: learning has finished
    render_it = 0  # flag: render the environment
    # [5] Main routine --------------------------------------------------
    for episode in range(num_episodes):  # repeat for every trial
        # Initialize the environment.
        observation = env.reset()
        state = digitize_state(observation)
        action = np.argmax(q_table[state])
        episode_reward = 0
        for t in range(max_number_of_steps):  # loop within one trial
            if render_it:
                env.render()
                time.sleep(0.01)
            # Executing action a_t yields s_{t+1}, r_{t}, and so on.
            observation, reward, done, info = env.step(action)
            # Set and grant the reward.
            if done:
                if t < 195:
                    reward = -200  # penalty for falling over
                else:
                    reward = 1  # no penalty when ending while still upright
            else:
                reward = 1  # reward for every step the pole stays up
            episode_reward += reward  # accumulate the reward
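            # Note: a fall costs -200 while each surviving step earns +1, so
            # episodes that fall early score far below the 195-point average
            # checked after every episode.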
            # Compute the discrete state s_{t+1} and update the Q-function:
            # digitize the observation at t+1.
            next_state = digitize_state(observation)
            q_table = update_qtable(q_table, state, action, reward, next_state)
            # Choose the next action a_{t+1}.
            action = get_action(next_state, episode, q_table)
            state = next_state
            # Handle the end of the trial.
            if done:
                cart_pos, cart_v, pole_angle, pole_v = observation
                print('{} Episode finished after {} time steps'
                      ' (cart_pos: {:.4}, pole_angle: {:.4},'
                      ' total_reward mean: {})'
                      .format(episode, t + 1,
                              cart_pos, pole_angle,
                              total_reward_vec.mean()))
                total_reward_vec = np.hstack((total_reward_vec[1:],
                                              episode_reward))  # record reward
                if islearned == 1:  # after learning, store the final x
                    final_x[episode, 0] = observation[0]
                break
        # Success when the last 100 episodes average at least the target
        # reward ← sounds dubious
        if (total_reward_vec.mean() >= goal_average_reward):
            print('Episode %d train agent successfully!' % episode)
            islearned = 1
            # To save the Q-table:
            # np.savetxt('learned_Q_table.csv', q_table, delimiter=",")
            if render_it == 0:
                # To record a movie:
                # env = gym.wrappers.Monitor(env,
                #                            './movie/cartpole-experiment-1')
                render_it = 1
        # To see how the agent behaves after only 10 episodes,
        # uncomment the following:
        # if episode > 10:
        #     if render_it == 0:
        #         # To record a movie:
        #         env = gym.wrappers.Monitor(env,
        #                                    './movie/cartpole-experiment-1')
        #         render_it = 1
        #     islearned = 1
    if islearned:
        np.savetxt('final_x.csv', final_x, delimiter=",")


if __name__ == '__main__':
    main()
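
# To replay a learned policy afterwards (a sketch, assuming the commented-out
# np.savetxt('learned_Q_table.csv', ...) call above was enabled first):
#
# q_table = np.loadtxt('learned_Q_table.csv', delimiter=",")
# env = gym.make('CartPole-v0')
# state = digitize_state(env.reset())
# for t in range(200):
#     env.render()
#     observation, reward, done, info = env.step(np.argmax(q_table[state]))
#     state = digitize_state(observation)
#     if done:
#         break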