@3t14
Created January 23, 2022 22:48
TicTacToeDPAgent: a dynamic-programming (DP) agent for Tic-Tac-Toe

import random

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

class DPAgent:
    def __init__(self, env, gamma=0.9):
        self.env = env
        self.gamma = gamma
        self.reset()

    # Initialize the value function and the policy
    def reset(self):
        self.V = {}   # estimated value function
        self.pi = {}  # policy
        # Every action probability in pi starts out uniform
        actions_len = 0
        for _ in self.env.all_actions():
            actions_len += 1
        prob = 1.0 / actions_len
        for state in self.env.all_states():
            self.V[state] = 0.0
            self.pi[state] = {}
            for action in self.env.all_actions():
                self.pi[state][action] = prob

    # One evaluation sweep: update the value function once over all states
    def eval_step(self):
        delta = 0
        count = 0
        for state in self.env.all_states():
            if self.env.rewards[state] is None:  # skip states with no reward defined
                self.V[state] = None
                continue
            # Action-selection probabilities for this state
            pi_actions = self.pi[state]
            V_next_s = 0.0  # V_{k+1}(s)
            count += 1
            if count % 10000 == 0:
                print(count)
            # Iterate over each action and its selection probability
            for action, pi_action in pi_actions.items():
                # print(state, action)
                s_dash = self.env.next_state(state, action)  # s_dash = s' = next state
                if s_dash is None:
                    continue
                r = self.env.rewards[s_dash]
                if r is not None:  # skip successors with no reward defined
                    V_next_s += pi_action * (r + self.gamma * self.V[s_dash])
            delta = max(delta, abs(self.V[state] - V_next_s))
            self.V[state] = V_next_s
        return delta
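
    # Note: the loop above is the in-place Bellman expectation backup
    #   V(s) <- sum_a pi(a|s) * (R(s') + gamma * V(s')),  s' = next_state(s, a),
    # and delta records the largest per-state change, which eval_policy below
    # uses as its convergence test.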

    # Repeat eval_step until V converges below the given threshold
    def eval_policy(self, threshold=0.00001):
        count = 0
        while True:
            # delta is the largest per-state change from one eval_step
            delta = self.eval_step()
            count += 1
            if count % 100 == 0:
                print("{0}: delta = {1}".format(count, delta))
            if delta < threshold:
                print("{0}: delta = {1}".format(count, delta))
                break
        return

    # Visualize the current value function as a heatmap
    def render_V(self):
        self.env.render()
        # # Convert to a 2D numpy array and draw it
        # plt.figure()  # create a new figure so we do not draw over the previous one
        # data = np.zeros(shape=(self.env.n, self.env.m))
        # for pos in self.env.all_states():
        #     data[pos[1], pos[0]] = self.V[pos]
        # sns.heatmap(data, annot=True, fmt="1.2f")
        # plt.tick_params(labeltop=True, labelbottom=False)

    #### Code for policy improvement ####
    # Return the key whose value is largest in the given dict
    def argmax(self, values):
        max_value = float('-inf')
        max_index = 0
        for index, value in values.items():
            if value > max_value:
                max_value = value
                max_index = index
        return max_index

    # Select the greedy policy for each state
    def select_greedy_policy(self):
        pi_next = {}
        # For every state, find the action whose successor state has the
        # highest value, set that action's probability to 1.0 and all
        # others to 0.0
        for state in self.env.all_states():
            max_action = (0, 0)  # action index that attains the maximum
            max_value = float('-inf')  # running maximum, initialized to -inf
            q_s = {}
            for action in self.env.all_actions():
                next_state = self.env.next_state(state, action)
                if next_state is None:
                    continue
                r = self.env.rewards[next_state]
                if r is not None:
                    q_s[action] = r + self.gamma * self.V[next_state]
            max_action = self.argmax(q_s)
            pi_next[state] = {}
            # Only the argmax action gets probability 1.0, all others get 0.0
            for index in self.env.all_actions():
                pi_next[state][index] = 1.0 if index == max_action else 0.0
        return pi_next
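
    # Note: this is the greedy policy-improvement step of policy iteration,
    #   pi'(s) = argmax_a [ R(s') + gamma * V(s') ],  s' = next_state(s, a),
    # encoded as a deterministic distribution over the available actions.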

    # Repeatedly evaluate and improve the policy (policy iteration)
    def iter_policy(self):
        while True:
            # Policy evaluation
            self.eval_policy()
            # New (improved) policy
            pi_next = self.select_greedy_policy()
            self.render_V()
            print("pi: ", self.pi)
            if self.pi == pi_next:  # stop once the policy no longer changes
                break
            self.pi = pi_next
        return

    # Play the game automatically
    def auto_play(self, max_steps=100):
        # Assumes the policy has already been evaluated.
        # Interacts with the environment while the agent plays out the game.
        step = 0
        # Player 2 = plays random moves
        # Player 1 = plays the best move under the learned policy
        done = False
        actions = []
        for action in self.env.all_actions():
            actions.append(action)
        self.env.render()
        while step <= max_steps and not done:
            # Player 1
            action = self.argmax(self.pi[self.env.agent_state])
            state_next = None
            while state_next is None and not done:
                state_next, reward, done, _ = self.env.step(action)
            print("Step {0}: Player 1 action = {1}, reward = {2}".format(step, action, reward))
            self.env.render()
            if done:
                print("Player1: won")
                break
            # Player 2
            state_next = None
            while state_next is None and not done:
                # Redraw a random action on each try so an invalid move is not repeated forever
                action = actions[random.randint(0, len(actions) - 1)]
                state_next, reward, done, _ = self.env.step(action)
            print("Step {0}: Player 2 action = {1}, reward = {2}".format(step, action, reward))
            self.env.render()
            if done:
                print("Player2: won")
            step += 1
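
# --- Usage sketch ---------------------------------------------------------
# A minimal example of driving DPAgent, assuming an environment object that
# exposes the interface the agent relies on: all_states(), all_actions(),
# rewards (a dict keyed by state), next_state(state, action), step(action),
# render(), and agent_state. "TicTacToeEnv" is a hypothetical name for such
# an environment; it is not defined in this gist.
if __name__ == "__main__":
    env = TicTacToeEnv()          # hypothetical environment providing the interface above
    agent = DPAgent(env, gamma=0.9)
    agent.iter_policy()           # alternate policy evaluation and greedy improvement
    agent.auto_play(max_steps=9)  # Player 1 follows the learned policy, Player 2 plays at random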