Last active
March 23, 2018 06:03
-
-
Save sasaco/945b21f66e2633ef613eda0126333b35 to your computer and use it in GitHub Desktop.
機械学習の理論を理解しようとしてから オセロ AI を作ってみた 〜何これ Alpha Zero 編〜 ref: https://qiita.com/sasaco/items/d249ee3493b5b85c6eb5
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from reversi_zero.env.reversi_env import ReversiEnv, Player | |
def start(config: Config):
    """Entry point of the optimize stage.

    Builds an ``OptimizeWorker`` from *config* and hands control to its
    ``start()`` method, returning whatever that returns.
    """
    worker = OptimizeWorker(config)
    return worker.start()
class OptimizeWorker:
    """Trains the model on stored self-play records (the "optimize" stage).

    After loading and compiling a model, the worker loops forever:
    - It sets the learning rate from the cumulative step count.
    - It fits the model for ``config.trainer.epoch_to_checkpoint`` epochs.
    - It periodically saves the model.
    - It periodically re-scans the play-data files.

    NOTE(review): ``save_current_model``, ``load_data_from_file``,
    ``unload_data_of_file`` and ``collect_all_loaded_data`` are called below
    but are not defined in this excerpt; they must be supplied by the full
    class definition.
    """

    # Learning-rate schedule as (exclusive upper bound on total_steps, lr).
    # The first row whose bound exceeds total_steps wins.
    # The DeepMind paper uses ~400k: 1e-2, 400k~600k: 1e-3, 600k~: 1e-4.
    # The bounds here deliberately differ from the paper.
    _LR_SCHEDULE = ((100000, 1e-2), (500000, 1e-3), (900000, 1e-4))
    # Rate once total_steps >= 900000: 1e-4 / 4. The original author's
    # rationale: the paper's batch size is 2048, ours is 512.
    _FINAL_LR = 2.5e-5

    def __init__(self, config):
        # The module-level start() constructs the worker with just a config.
        self.config = config
        self.model = None      # ReversiModel wrapper, set in start()
        self.optimizer = None  # Keras optimizer, set in compile_model()
        # (state_ary, policy_ary, z_ary) once play data has been loaded.
        self.dataset = None
        # Play-data files whose contents are currently loaded; lets
        # load_play_data() detect files that were added or removed.
        self.loaded_filenames = set()

    def start(self):
        """Load the model to train, then run the endless training loop."""
        self.model = self.load_model()
        self.training()

    def training(self):
        """Compile the model and train it forever on the loaded play data.

        The model is saved once more than ``config.trainer.save_model_steps``
        steps have passed since the last save. Play data is re-read once more
        than ``config.trainer.load_data_steps`` steps have passed since the
        last load.
        """
        tc = self.config.trainer
        self.compile_model()
        self.load_play_data()

        # Step counters for the LR schedule and the save/reload cadence.
        # NOTE(review): counting starts at 0, so a restarted worker also
        # restarts the learning-rate schedule — confirm this is intended.
        total_steps = 0
        last_save_step = 0
        last_load_data_step = 0

        while True:
            self.update_learning_rate(total_steps)
            # Train for a fixed number of epochs; returns mini-batch steps taken.
            total_steps += self.train_epoch(tc.epoch_to_checkpoint)

            if last_save_step + tc.save_model_steps < total_steps:
                self.save_current_model()
                last_save_step = total_steps

            if last_load_data_step + tc.load_data_steps < total_steps:
                self.load_play_data()
                last_load_data_step = total_steps

    def load_model(self):
        """Return the model to train.

        Uses the last entry of the next-generation model directories when any
        exist; otherwise loads the best model's weights.

        Raises:
            RuntimeError: if there is no next-generation model and the best
                model weights cannot be loaded.
        """
        import os

        from reversi_zero.agent.model import ReversiModel

        model = ReversiModel(self.config)
        rc = self.config.resource
        dirs = get_next_generation_model_dirs(rc)
        if not dirs:
            logger.debug("loading best model")
            if not load_best_model_weight(model):
                raise RuntimeError("Best model can not loaded!")
        else:
            # Assumes dirs is ordered oldest -> newest, so the last entry is
            # the latest generation — TODO confirm the helper's ordering.
            latest_dir = dirs[-1]
            logger.debug("loading latest model")
            config_path = os.path.join(latest_dir, rc.next_generation_model_config_filename)
            weight_path = os.path.join(latest_dir, rc.next_generation_model_weight_filename)
            model.load(config_path, weight_path)
        return model

    def compile_model(self):
        """Compile the Keras model with SGD and the policy/value losses."""
        # Initial rate only; update_learning_rate() overwrites it every pass.
        self.optimizer = SGD(lr=1e-2, momentum=0.9)
        # One loss per model output, in the same order as the targets given
        # to fit() in train_epoch(): policy first, then value.
        losses = [objective_function_for_policy, objective_function_for_value]
        self.model.model.compile(optimizer=self.optimizer, loss=losses)

    def load_play_data(self):
        """Sync the in-memory dataset with the play-data files on disk.

        Loads files not seen before and unloads files that are no longer
        listed. ``self.dataset`` is rebuilt only when something changed.
        """
        filenames = get_game_data_filenames(self.config.resource)
        listed = set(filenames)
        updated = False

        for filename in filenames:
            if filename not in self.loaded_filenames:
                self.load_data_from_file(filename)
                updated = True

        # The difference is a new set, so the loop stays safe even if
        # unload_data_of_file() removes entries from self.loaded_filenames.
        for filename in self.loaded_filenames - listed:
            self.unload_data_of_file(filename)
            updated = True

        if updated:
            logger.debug("updating training dataset")
            self.dataset = self.collect_all_loaded_data()

    @classmethod
    def _learning_rate_for_step(cls, total_steps):
        """Return the scheduled learning rate for *total_steps* training steps."""
        for upper_bound, lr in cls._LR_SCHEDULE:
            if total_steps < upper_bound:
                return lr
        return cls._FINAL_LR

    def update_learning_rate(self, total_steps):
        """Set the optimizer's learning rate according to *total_steps*."""
        K.set_value(self.optimizer.lr, self._learning_rate_for_step(total_steps))

    def train_epoch(self, epochs):
        """Fit the model on the whole loaded dataset for *epochs* epochs.

        Returns:
            int: ``(n_samples // batch_size) * epochs``. A trailing partial
            batch is not counted.

        Raises:
            RuntimeError: if no play data has been loaded yet.
        """
        if self.dataset is None:
            raise RuntimeError("no play data loaded; cannot train")
        tc = self.config.trainer
        state_ary, policy_ary, z_ary = self.dataset
        # Targets follow the loss order set in compile_model(): policy, value.
        self.model.model.fit(state_ary, [policy_ary, z_ary],
                             batch_size=tc.batch_size,
                             epochs=epochs)
        return (state_ary.shape[0] // tc.batch_size) * epochs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from reversi_zero.env.reversi_env import ReversiEnv, Player | |
def start(config: Config):
    """Entry point of the self-play stage.

    Creates a fresh ``ReversiEnv``, wraps it with *config* in a
    ``SelfPlayWorker`` and returns whatever the worker's ``start()`` returns.
    """
    env = ReversiEnv()
    worker = SelfPlayWorker(config, env=env)
    return worker.start()
class SelfPlayWorker:
    """Generates training data by letting one model play against itself.

    Both players' moves from every game are appended to an in-memory buffer.
    The buffer is written to a play-data file and cleared every
    ``config.play_data.nb_game_in_file`` games.
    """

    def __init__(self, config, env=None, model=None):
        self.config = config
        # Environment the games are played in; the module-level start()
        # passes a ReversiEnv.
        self.env = env
        # Model shared by both players; loaded in start() when still None.
        self.model = model
        self.black = None  # ReversiPlayer for black, recreated each game
        self.white = None  # ReversiPlayer for white, recreated each game
        # Moves accumulated since the last write to disk.
        self.buffer = []

    def start(self):
        """Play self-play games forever, saving their records periodically."""
        if self.model is None:
            # If no best model exists, a freshly built model becomes the best.
            self.model = self.load_model()

        self.buffer = []
        # 1-based so the buffer is flushed after every nb_game_in_file-th
        # game, not immediately after the first one.
        idx = 1
        while True:
            self.start_game(idx)
            idx += 1

    def start_game(self, idx, enable_resign=True):
        """Play one complete game and record its moves.

        Args:
            idx: 1-based game number. When it is a multiple of
                ``config.play_data.nb_game_in_file`` the buffered play data
                is written to disk.
            enable_resign: forwarded to both players' ``enable_resign``.
                NOTE(review): the excerpt referenced an undefined
                ``enable_resign``; how the caller should choose it is not
                visible here — confirm.

        Returns:
            The environment after the game has finished.
        """
        self.env.reset()
        self.black = ReversiPlayer(self.config, self.model, enable_resign=enable_resign)
        self.white = ReversiPlayer(self.config, self.model, enable_resign=enable_resign)

        # NOTE(review): assumes ReversiEnv exposes the current board as
        # `observation` (the same kind of object step() returns) — confirm.
        observation = self.env.observation
        while not self.env.done:
            # Pieces are passed mover-first: (mover's pieces, opponent's pieces).
            if self.env.next_player == Player.black:
                action = self.black.action(observation.black, observation.white)
            else:
                action = self.white.action(observation.white, observation.black)
            observation, _info = self.env.step(action)

        self.save_play_data(write=idx % self.config.play_data.nb_game_in_file == 0)
        return self.env

    def save_play_data(self, write=True):
        """Buffer this game's moves; if *write* is true, flush to a new file.

        Args:
            write: when False the moves are only buffered. When True the whole
                buffer is written to a fresh play-data file and then cleared.
        """
        self.buffer += self.black.moves + self.white.moves
        if not write:
            # Keep accumulating until the caller asks for a flush.
            return
        write_game_data_to_file(self._new_play_data_path(), self.buffer)
        self.buffer = []

    def _new_play_data_path(self):
        """Return a new, timestamp-named path for a play-data file.

        NOTE(review): ``play_data_dir`` and ``play_data_filename_tmpl`` (a
        ``%``-style template) are assumed attributes of ``config.resource``;
        the excerpt used an undefined ``path`` — confirm against the config.
        """
        import os
        from datetime import datetime

        rc = self.config.resource
        game_id = datetime.now().strftime("%Y%m%d-%H%M%S.%f")
        return os.path.join(rc.play_data_dir, rc.play_data_filename_tmpl % game_id)

    def load_model(self):
        """Return the best model, building and saving a fresh one if needed.

        A new model is built and saved as the best model in two cases: when
        ``config.opts.new`` is set, or when the best weights cannot be loaded.
        """
        from reversi_zero.agent.model import ReversiModel

        model = ReversiModel(self.config)
        if self.config.opts.new or not load_best_model_weight(model):
            model.build()
            save_as_best_model(model)
        return model
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment