Created
December 30, 2017 15:16
-
-
Save eTakazawa/4f3cdbf4578db9ecbb05b4e773639d4d to your computer and use it in GitHub Desktop.
強化学習で迷路を解く (Solving a maze with reinforcement learning)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <chrono>
#include <climits>
#include <cmath>
#include <iostream>
#include <limits>
#include <numeric>
#include <random>
#include <stdexcept>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>
// Tabular action-value function Q(state, action) together with per-pair
// update counts that drive UCB1 exploration. States and actions are dense
// integer indices in [0, num_states) and [0, num_actions); every Q-value
// starts at 0.
class QFuction {
 public:
  std::vector<std::vector<double>> qtable_;  // qtable_[state][action] = Q(s, a)
  std::vector<std::vector<int>> try_count_;  // number of Update() calls per (s, a)

  QFuction(int num_states, int num_actions)
      : qtable_(num_states, std::vector<double>(num_actions)),
        try_count_(num_states, std::vector<int>(num_actions)),
        num_states_(num_states),
        num_actions_(num_actions) {}

  // Returns the stored value Q(state, action).
  double Call(int state, int action) const { return qtable_[state][action]; }

  // Returns {max_a Q(state, a), argmax_a Q(state, a)}. Ties resolve to the
  // lowest action index. Any stored value (however negative) can win; the
  // action is -1 only when there are zero actions.
  std::pair<double, int> GetMaxQValueAndAction(int state) const {
    double max_q_value = std::numeric_limits<double>::lowest();
    int max_action = -1;
    for (int action = 0; action < num_actions_; action++) {
      const double q_value = Call(state, action);
      if (max_action == -1 || max_q_value < q_value) {
        max_q_value = q_value;
        max_action = action;
      }
    }
    return std::make_pair(max_q_value, max_action);
  }

  // Overwrites Q(state, action) with `value` and records one more visit of
  // that pair.
  void Update(int state, int action, double value) {
    try_count_[state][action]++;
    qtable_[state][action] = value;
  }

  // UCB1 score Q(s, a) + C * sqrt(2 ln n / n_a), where n = 1 + total updates
  // made from `state` and n_a = 1 + updates of (state, action). The +1 terms
  // keep the expression finite for pairs that were never tried.
  double GetUCB1(int state, int action) const {
    const double kExploration = 1.0;  // the constant C
    const std::vector<int>& counts = try_count_[state];
    const int n = 1 + std::accumulate(counts.begin(), counts.end(), 0);
    const int nj = counts[action] + 1;
    return qtable_[state][action] +
           kExploration * std::sqrt(2.0 * std::log(static_cast<double>(n)) / nj);
  }

  // Returns the action with the highest UCB1 score (lowest index on ties;
  // -1 only when there are zero actions).
  int GetMaxUCB1Action(int state) const {
    double max_ucb1 = std::numeric_limits<double>::lowest();
    int max_action = -1;
    for (int action = 0; action < num_actions_; action++) {
      const double ucb1 = GetUCB1(state, action);
      if (max_action == -1 || max_ucb1 < ucb1) {
        max_ucb1 = ucb1;
        max_action = action;
      }
    }
    return max_action;
  }

  int num_states_, num_actions_;
};
// Grid-maze environment for the Q-learning agent.
//
// Input format: "height width" followed by height*width integer cell codes in
// row-major order (see Cell). The agent starts on the START cell and moves one
// cell per Move(). Entering a WALL ends the episode with reward -1, entering
// the GOAL ends it with reward +1, and any other move yields reward 0.
class MazeEnviroment {
 public:
  enum Action { LEFT, DOWN, RIGHT, UP, ACTION_SIZE };
  enum Cell { EMPTY, WALL, START, GOAL, CELL_SIZE };

  // Reads the maze from `in` (standard input by default).
  // Throws std::runtime_error if the size header is missing or not positive,
  // or if fewer than height*width cell codes can be read. If the maze has no
  // START cell the agent starts at (0, 0).
  explicit MazeEnviroment(std::istream& in = std::cin) {
    if (!(in >> height_ >> width_) || height_ <= 0 || width_ <= 0) {
      throw std::runtime_error("MazeEnviroment: invalid maze size");
    }
    orgmaze_.assign(height_, std::vector<int>(width_));
    for (int h = 0; h < height_; h++) {
      for (int w = 0; w < width_; w++) {
        int cell = EMPTY;
        if (!(in >> cell)) {
          throw std::runtime_error("MazeEnviroment: truncated maze data");
        }
        orgmaze_[h][w] = cell;
        if (cell == START) {
          x_ = start_x_ = w;
          y_ = start_y_ = h;
        } else if (cell == GOAL) {
          goal_x_ = w;
          goal_y_ = h;
        }
      }
    }
    maze_ = orgmaze_;
  }

  // Puts the agent back on the start cell and clears the episode flags.
  void Reset() {
    x_ = start_x_;
    y_ = start_y_;
    maze_ = orgmaze_;
    isdone_ = isdead_ = isgoal_ = false;
  }

  // Applies `action`. Stepping onto a WALL ends the episode as dead (the agent
  // is then on the wall cell). Stepping off the grid, or passing an action
  // outside [0, ACTION_SIZE), also ends the episode as dead, but the agent
  // stays on its current cell so GetState() remains a valid index.
  void Move(int action) {
    if (action < 0 || action >= ACTION_SIZE) {
      Dead();
      return;
    }
    const int next_x = x_ + dx[action];
    const int next_y = y_ + dy[action];
    if (next_x < 0 || next_x >= width_ || next_y < 0 || next_y >= height_) {
      Dead();
      return;
    }
    x_ = next_x;
    y_ = next_y;
    const int next_cell = maze_[y_][x_];
    if (next_cell == WALL) Dead();
    else if (next_cell == GOAL) Goal();
    else Empty();
  }

  // Non-terminal step: reward 0.
  void Empty() {
    reward_ = 0.0;
  }

  // Terminal failure: reward -1.
  void Dead() {
    isdone_ = isdead_ = true;
    reward_ = -1.0;
  }

  // Terminal success: reward +1.
  void Goal() {
    isdone_ = isgoal_ = true;
    reward_ = 1.0;
  }

  // Row-major index of the agent's cell, in [0, GetNumStates()).
  int GetState() {
    return y_ * width_ + x_;
  }

  int GetNumStates() {
    return width_ * height_;
  }

  int GetNumActions() {
    return ACTION_SIZE;
  }

  // Prints the maze to stdout, showing the agent's cell as "x".
  void Show() {
    for (int h = 0; h < height_; h++) {
      for (int w = 0; w < width_; w++) {
        if (h == y_ && w == x_) std::cout << "x ";
        else std::cout << maze_[h][w] << " ";
      }
      std::cout << '\n';
    }
    std::cout << std::endl;  // flush so the frame is visible immediately
  }

  int height_ = 0, width_ = 0;
  std::vector<std::vector<int>> orgmaze_, maze_;  // layout as read / working copy
  const int dx[4] = {-1, 0, 1, 0};  // column offset per Action
  const int dy[4] = {0, 1, 0, -1};  // row offset per Action (row 0 is printed first)
  int x_ = 0, y_ = 0, start_x_ = 0, start_y_ = 0, goal_x_ = 0, goal_y_ = 0;
  bool isdone_ = false, isdead_ = false, isgoal_ = false;
  double reward_ = 0.0;
};
int main(void) { | |
MazeEnviroment env; | |
QFuction qfunc(env.GetNumStates(), env.GetNumActions()); | |
std::mt19937 rnd_engine(0); | |
int num_episodes = 10000; | |
double epsilon = 0.1; // ε-greedy | |
double alpha = 1.0; // 学習率 | |
double gamma = 0.95; // 報酬の割引率 | |
for (int m = 0; m < num_episodes; m++) { | |
env.Reset(); | |
int t = 0; | |
std::cerr << "episode : " << m << std::endl; | |
while (!env.isdone_) { | |
std::cerr << "\tturn : " << t << std::endl; | |
/* ---------- 現在の状態の記録 ---------- */ | |
int state = env.GetState(); | |
std::cerr << "\t\tstate : " << env.x_ << " " << env.y_ << " / " << env.GetState() << std::endl; | |
/* ---------- 政策の生成 ε-greedy ---------- */ | |
// int max_action; | |
// max_action = qfunc.GetMaxQValueAndAction(env.GetState()).second; | |
// double prob_epsilon = epsilon / env.GetNumActions(); | |
// std::vector<double> policy(env.GetNumActions(), prob_epsilon); | |
// policy[max_action] = 1 - epsilon + prob_epsilon; | |
// std::discrete_distribution<> policy_prob(policy.begin(), policy.end()); // 確率に基づき選択 | |
// int action = policy_prob(rnd_engine); //https://cpplover.blogspot.jp/2011/03/blog-post_07.html | |
/* ---------- 政策の生成 UCB1 ---------- */ | |
int action = qfunc.GetMaxUCB1Action(state); | |
/* ---------- 政策を基にして行動 ---------- */ | |
env.Move(action); // 政策に基づき行動 | |
int next_state = env.GetState(); | |
std::cerr << "\t\taction : " << action << std::endl; | |
/* ---------- Qテーブルの更新 ---------- */ | |
if (t > 0) { | |
// 1ステップ前の(状態,行動)のQ値を更新 | |
double gradient = env.reward_ - qfunc.Call(state, action) | |
+ gamma * qfunc.GetMaxQValueAndAction(next_state).first; | |
double q_value = qfunc.Call(state, action) + alpha * gradient; | |
qfunc.Update(state, action, q_value); | |
std::cerr << "\t\tstate,action : " << state << "," << action << std::endl; | |
std::cerr << "\t\tqvalue : " << q_value << std::endl; | |
} else { | |
if (env.isdone_) { // 1ステップで終了した場合 | |
qfunc.Update( state, action, env.reward_); | |
} | |
} | |
t++; | |
} | |
} | |
std::cerr << "Test" << std::endl; | |
env.Reset(); | |
epsilon = 0.0; | |
while (!env.isdone_) { | |
int state = env.GetState(); | |
std::cerr << env.x_ << " " << env.y_ << std::endl; | |
/* ---------- 政策の生成 ε-greedy ---------- */ | |
/* ε=0なので,決定的 */ | |
double max_q_value; | |
int max_action; | |
std::tie(max_q_value, max_action) = qfunc.GetMaxQValueAndAction(state); | |
double prob_epsilon = epsilon / env.GetNumActions(); | |
std::vector<double> policy(env.GetNumActions(), prob_epsilon); | |
policy[max_action] = 1 - epsilon + prob_epsilon; | |
std::discrete_distribution<> policy_prob(policy.begin(), policy.end()); // 確率に基づき選択 | |
int action = policy_prob(rnd_engine); //https://cpplover.blogspot.jp/2011/03/blog-post_07.html | |
/* ---------- 現在の状態の表示 ---------- */ | |
std::cerr << "Q-Value" << std::endl; | |
for (int a = 0; a < env.GetNumActions(); a++) { | |
std::cerr << qfunc.Call(state, a) << std::endl;; | |
} | |
std::cerr << "action : " << action << std::endl; | |
env.Show(); | |
/* ---------- 政策を基にして行動 ---------- */ | |
env.Move(action); // 政策に基づき行動 | |
std::this_thread::sleep_for(std::chrono::milliseconds(300)); | |
} | |
return 0; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
9 13
1 1 1 1 1 1 1 1 1 1 1 1 1
1 2 1 0 0 0 0 0 0 0 0 0 1
1 0 1 0 1 1 1 1 1 1 1 0 1
1 0 0 0 0 0 0 0 0 0 1 0 1
1 1 1 1 0 1 0 1 1 0 1 0 1
1 0 0 0 0 1 0 0 1 0 1 1 1
1 0 1 1 1 1 0 1 1 0 1 3 1
1 0 0 0 0 1 0 0 1 0 0 0 1
1 1 1 1 1 1 1 1 1 1 1 1 1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment