TicTacToeWorld: an environment class for the tic-tac-toe game
import random  # used by auto_play for Player 2's random moves

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
class DPAgent:
    def __init__(self, env, gamma=0.9):
        self.env = env
        self.gamma = gamma  # discount factor
        self.reset()

    # Initialize the value function and the policy
    def reset(self):
        self.V = {}   # estimated value function V(s)
        self.pi = {}  # policy pi(a|s)
        # Start pi as a uniform distribution over all actions
        actions = list(self.env.all_actions())
        prob = 1.0 / len(actions)
        for state in self.env.all_states():
            self.V[state] = 0.0
            self.pi[state] = {}
            for action in actions:
                self.pi[state][action] = prob
    # One sweep of policy evaluation: apply the Bellman expectation
    # backup V_{k+1}(s) = sum_a pi(a|s) * (r(s') + gamma * V_k(s'))
    # to every state, and return the largest absolute change.
    def eval_step(self):
        delta = 0
        count = 0
        for state in self.env.all_states():
            if self.env.rewards[state] is None:  # skip states with no defined reward
                self.V[state] = None
                continue
            # Action selection probabilities for this state
            pi_actions = self.pi[state]
            V_next_s = 0.0  # V_{k+1}(s)
            count += 1
            if count % 10000 == 0:
                print(count)  # progress indicator over the large state space
            # Iterate over each action and its selection probability
            for action, pi_action in pi_actions.items():
                s_dash = self.env.next_state(state, action)  # s_dash = s' = next state
                if s_dash is None:
                    continue
                r = self.env.rewards[s_dash]
                if r is not None:  # skip transitions into undefined states
                    V_next_s += pi_action * (r + self.gamma * self.V[s_dash])
            delta = max(delta, abs(self.V[state] - V_next_s))
            self.V[state] = V_next_s
        return delta
    # Repeat evaluation sweeps until V converges below a threshold
    def eval_policy(self, threshold=0.00001):
        count = 0
        while True:
            # delta is the largest value change in the last sweep
            delta = self.eval_step()
            count += 1
            if count % 100 == 0:
                print("{0}: delta = {1}".format(count, delta))
            if delta < threshold:
                print("{0}: delta = {1}".format(count, delta))
                break
        return
    # Visualize the current value function as a heatmap
    def render_V(self):
        self.env.render()
        # # Convert to a 2D numpy array and draw a heatmap
        # # (commented out: assumes a grid-style env with n x m positions)
        # plt.figure()  # open a new figure so we don't draw over the previous one
        # data = np.zeros(shape=(self.env.n, self.env.m))
        # for pos in self.env.all_states():
        #     data[pos[1], pos[0]] = self.V[pos]
        # sns.heatmap(data, annot=True, fmt="1.2f")
        # plt.tick_params(labeltop=True, labelbottom=False)
    #### Policy improvement ####

    # Return the key with the maximum value in a dictionary
    def argmax(self, d):
        max_value = float('-inf')
        max_index = 0
        for index, value in d.items():
            if value > max_value:
                max_value = value
                max_index = index
        return max_index
    # Build a greedy policy with respect to the current value function
    def select_greedy_policy(self):
        pi_next = {}
        # For each state, evaluate every reachable successor state,
        # pick the action with the largest one-step value, set that
        # action's probability to 1.0, and all other actions to 0.0
        for state in self.env.all_states():
            q_s = {}  # q_s[a] = r(s') + gamma * V(s') for each legal action a
            for action in self.env.all_actions():
                next_state = self.env.next_state(state, action)
                if next_state is None:
                    continue
                r = self.env.rewards[next_state]
                if r is not None:
                    q_s[action] = r + self.gamma * self.V[next_state]
            max_action = self.argmax(q_s)
            pi_next[state] = {}
            # Probability 1.0 for the argmax action, 0.0 for the rest
            for action in self.env.all_actions():
                pi_next[state][action] = 1.0 if action == max_action else 0.0
        return pi_next
    # Policy iteration: alternate evaluation and greedy improvement
    def iter_policy(self):
        while True:
            # Policy evaluation
            self.eval_policy()
            # New greedy policy
            pi_next = self.select_greedy_policy()
            self.render_V()
            print("pi: ", self.pi)
            if self.pi == pi_next:  # stop once the policy is stable
                break
            self.pi = pi_next
        return
    # Self-play demo
    def auto_play(self, max_steps=100):
        # Assumes the policy has already been evaluated.
        # Plays out the game while interacting with the environment:
        # Player 1 follows the greedy policy, Player 2 moves at random.
        step = 0
        done = False
        actions = list(self.env.all_actions())
        self.env.render()
        while step <= max_steps and not done:
            # Player 1: greedy action from the learned policy
            action = self.argmax(self.pi[self.env.agent_state])
            state_next = None
            while state_next is None and not done:  # the greedy action is assumed legal
                state_next, reward, done, _ = self.env.step(action)
            print("Step {0}: Player 1 action = {1}, reward = {2}".format(step, action, reward))
            self.env.render()
            if done:
                print("Player 1: won")
                break
            # Player 2: resample random actions until a legal move is accepted
            # (sampling inside the loop avoids spinning forever on one illegal move)
            state_next = None
            while state_next is None and not done:
                action = random.choice(actions)
                state_next, reward, done, _ = self.env.step(action)
            print("Step {0}: Player 2 action = {1}, reward = {2}".format(step, action, reward))
            self.env.render()
            if done:
                print("Player 2: won")
            step += 1
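
A minimal usage sketch, assuming the TicTacToeWorld class named in this gist's description provides the interface DPAgent relies on (all_states(), all_actions(), rewards, next_state(), step(), render(), agent_state); the no-argument constructor is hypothetical:

# Hypothetical setup: TicTacToeWorld's constructor and exact API are
# assumptions inferred from how DPAgent uses the env object above.
env = TicTacToeWorld()
agent = DPAgent(env, gamma=0.9)
agent.iter_policy()  # alternate evaluation and greedy improvement until pi is stable
agent.auto_play()    # greedy Player 1 vs. random Player 2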