Skip to content

Instantly share code, notes, and snippets.

@eTakazawa
Created December 30, 2017 15:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save eTakazawa/4f3cdbf4578db9ecbb05b4e773639d4d to your computer and use it in GitHub Desktop.
Save eTakazawa/4f3cdbf4578db9ecbb05b4e773639d4d to your computer and use it in GitHub Desktop.
強化学習で迷路を解く
#include <chrono>
#include <climits>
#include <cmath>
#include <iostream>
#include <random>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>
/// Tabular action-value function Q(state, action) with per-pair update counts.
///
/// Both tables are indexed as [state][action] and are zero-filled on
/// construction. No bounds checking is performed: callers must pass
/// 0 <= state < num_states and 0 <= action < num_actions.
class QFuction {
 public:
  std::vector<std::vector<double>> qtable_;   ///< Q value for every (state, action).
  std::vector<std::vector<int>> try_count_;   ///< Number of Update() calls per (state, action).

  /// Allocates num_states x num_actions tables, all entries zero.
  QFuction(int num_states, int num_actions)
      : qtable_(num_states, std::vector<double>(num_actions)),
        try_count_(num_states, std::vector<int>(num_actions)),
        num_states_(num_states),
        num_actions_(num_actions) {}

  /// Returns the stored Q value of (state, action).
  double Call(int state, int action) const {
    return qtable_[state][action];
  }

  /// Returns {best Q value, best action} for `state`.
  /// Ties go to the lowest action index because the comparison is strict.
  /// If no value exceeds the -1e18 sentinel (e.g. zero actions), the result
  /// is {-1e18, -1}.
  std::pair<double, int> GetMaxQValueAndAction(int state) const {
    double best_value = -1e18;
    int best_action = -1;
    for (int action = 0; action < num_actions_; ++action) {
      const double value = Call(state, action);
      if (best_value < value) {
        best_value = value;
        best_action = action;
      }
    }
    return std::make_pair(best_value, best_action);
  }

  /// Overwrites Q(state, action) with `value` and bumps its update count.
  void Update(int state, int action, double value) {
    ++try_count_[state][action];
    qtable_[state][action] = value;
  }

  /// UCB1 score of (state, action): Q + C * sqrt(2 * ln(n) / nj).
  /// Both counts are offset by one so the formula is defined before any
  /// update: n = 1 + total updates in `state`, nj = 1 + updates of `action`.
  double GetUCB1(int state, int action) const {
    const double kExplorationWeight = 1.0;
    int total = 1;
    for (int a = 0; a < num_actions_; ++a) {
      total += try_count_[state][a];
    }
    const int tried = try_count_[state][action] + 1;
    const double bonus = std::sqrt(2.0 * std::log(static_cast<double>(total)) / tried);
    return qtable_[state][action] + kExplorationWeight * bonus;
  }

  /// Returns the action with the highest UCB1 score in `state`
  /// (lowest index on ties, -1 if no score exceeds the -1e18 sentinel).
  int GetMaxUCB1Action(int state) const {
    double best_score = -1e18;
    int best_action = -1;
    for (int action = 0; action < num_actions_; ++action) {
      const double score = GetUCB1(state, action);
      if (best_score < score) {
        best_score = score;
        best_action = action;
      }
    }
    return best_action;
  }

  int num_states_, num_actions_;
};
/// Grid-maze environment whose layout is read from standard input.
///
/// Input format: `height width` followed by height*width integer cell codes
/// (see Cell) in row order. The agent position is (x_, y_) with x the column
/// and y the row; a state id is the row-major index y_ * width_ + x_.
class MazeEnviroment {
 public:
  enum Action { LEFT, DOWN, RIGHT, UP, ACTION_SIZE };
  enum Cell { EMPTY, WALL, START, GOAL, CELL_SIZE };

  /// Reads the maze from std::cin and places the agent on the START cell.
  /// If several START/GOAL cells are present the last one read wins; if
  /// none is present the corresponding coordinates stay at 0.
  MazeEnviroment() {
    std::cin >> height_ >> width_;
    orgmaze_ = std::vector<std::vector<int>>(height_, std::vector<int>(width_));
    for (int h = 0; h < height_; ++h) {
      for (int w = 0; w < width_; ++w) {
        // Initialised so that a failed/short read yields EMPTY instead of
        // reading an indeterminate value.
        int cell = EMPTY;
        std::cin >> cell;
        orgmaze_[h][w] = cell;
        if (cell == START) {
          x_ = start_x_ = w;
          y_ = start_y_ = h;
        } else if (cell == GOAL) {
          goal_x_ = w;
          goal_y_ = h;
        }
      }
    }
    maze_ = orgmaze_;
    isdone_ = isdead_ = isgoal_ = false;
  }

  /// Puts the agent back on the start cell, restores the original maze and
  /// clears the done/dead/goal flags. reward_ is left untouched.
  void Reset() {
    x_ = start_x_;
    y_ = start_y_;
    maze_ = orgmaze_;
    isdone_ = isdead_ = isgoal_ = false;
  }

  /// Moves one cell in direction `action` (an Action value in
  /// [0, ACTION_SIZE); not validated) and sets reward_ and the flags.
  ///
  /// Stepping onto WALL ends the episode with reward -1, onto GOAL ends it
  /// with reward +1, anything else gives reward 0. A step that would leave
  /// the grid is treated like hitting a wall, but the agent stays where it
  /// is so that GetState() remains a valid index.
  void Move(int action) {
    const int next_x = x_ + dx[action];
    const int next_y = y_ + dy[action];
    const bool inside = next_y >= 0 && next_y < height_ &&
                        next_x >= 0 && next_x < width_;
    if (!inside) {
      Dead();
      return;
    }
    x_ = next_x;
    y_ = next_y;
    const int next_cell = maze_[y_][x_];
    if (next_cell == WALL) {
      Dead();
    } else if (next_cell == GOAL) {
      Goal();
    } else {
      Empty();
    }
  }

  /// Outcome of stepping onto a non-wall, non-goal cell: reward 0.
  void Empty() {
    reward_ = 0.0;
  }

  /// Outcome of hitting a wall: episode over, reward -1.
  void Dead() {
    isdone_ = isdead_ = true;
    reward_ = -1.0;
  }

  /// Outcome of reaching the goal: episode over, reward +1.
  void Goal() {
    isdone_ = isgoal_ = true;
    reward_ = 1.0;
  }

  /// Row-major id of the agent's current cell.
  int GetState() const {
    return y_ * width_ + x_;
  }

  /// Number of cells in the grid (width * height).
  int GetNumStates() const {
    return width_ * height_;
  }

  /// Number of movement actions (ACTION_SIZE).
  int GetNumActions() const {
    return ACTION_SIZE;
  }

  /// Prints the grid to std::cout: "x" at the agent position, the numeric
  /// cell code elsewhere, followed by a blank line.
  void Show() const {
    for (int h = 0; h < height_; ++h) {
      for (int w = 0; w < width_; ++w) {
        if (h == y_ && w == x_) {
          std::cout << "x ";
        } else {
          std::cout << maze_[h][w] << " ";
        }
      }
      std::cout << std::endl;
    }
    std::cout << std::endl;
  }

  int height_ = 0, width_ = 0;
  std::vector<std::vector<int>> orgmaze_, maze_;
  // Per-action coordinate deltas, indexed by Action (LEFT, DOWN, RIGHT, UP).
  const int dx[4] = {-1, 0, 1, 0};
  const int dy[4] = {0, 1, 0, -1};
  int x_ = 0, y_ = 0, start_x_ = 0, start_y_ = 0, goal_x_ = 0, goal_y_ = 0;
  bool isdone_ = false, isdead_ = false, isgoal_ = false;
  double reward_ = 0.0;
};
/// Trains a tabular Q-learning agent on the maze read from stdin (UCB1 action
/// selection), then replays the greedy policy once while printing the maze.
///
/// Returns 0 on a completed replay, 1 if the maze could not be read or the
/// greedy replay does not terminate within GetNumStates() steps.
int main(void) {
  MazeEnviroment env;
  if (!std::cin || env.GetNumStates() <= 0) {
    // Without a valid grid every table lookup below would be out of range.
    std::cerr << "failed to read maze from stdin" << std::endl;
    return 1;
  }
  QFuction qfunc(env.GetNumStates(), env.GetNumActions());
  std::mt19937 rnd_engine(0);
  const int num_episodes = 10000;
  double epsilon = 0.1;      // epsilon-greedy exploration rate (only used by the replay below)
  const double alpha = 1.0;  // learning rate
  const double gamma = 0.95; // reward discount factor

  for (int m = 0; m < num_episodes; m++) {
    env.Reset();
    int t = 0;
    std::cerr << "episode : " << m << std::endl;
    while (!env.isdone_) {
      std::cerr << "\tturn : " << t << std::endl;
      /* ---------- record the current state ---------- */
      const int state = env.GetState();
      std::cerr << "\t\tstate : " << env.x_ << " " << env.y_ << " / " << env.GetState() << std::endl;
      /* ---------- policy: epsilon-greedy (disabled alternative) ---------- */
      // int max_action;
      // max_action = qfunc.GetMaxQValueAndAction(env.GetState()).second;
      // double prob_epsilon = epsilon / env.GetNumActions();
      // std::vector<double> policy(env.GetNumActions(), prob_epsilon);
      // policy[max_action] = 1 - epsilon + prob_epsilon;
      // std::discrete_distribution<> policy_prob(policy.begin(), policy.end()); // sample by probability
      // int action = policy_prob(rnd_engine); //https://cpplover.blogspot.jp/2011/03/blog-post_07.html
      /* ---------- policy: UCB1 ---------- */
      const int action = qfunc.GetMaxUCB1Action(state);
      /* ---------- act according to the policy ---------- */
      env.Move(action);
      const int next_state = env.GetState();
      std::cerr << "\t\taction : " << action << std::endl;
      /* ---------- update the Q table ---------- */
      // Every transition is learned from, including the first one of an
      // episode; skipping it would leave the start state's non-terminal
      // actions with neither a value nor a visit count.
      // A terminal transition has no successor, so nothing is bootstrapped.
      const double future =
          env.isdone_ ? 0.0 : gamma * qfunc.GetMaxQValueAndAction(next_state).first;
      const double gradient = env.reward_ - qfunc.Call(state, action) + future;
      const double q_value = qfunc.Call(state, action) + alpha * gradient;
      qfunc.Update(state, action, q_value);
      std::cerr << "\t\tstate,action : " << state << "," << action << std::endl;
      std::cerr << "\t\tqvalue : " << q_value << std::endl;
      t++;
    }
  }

  std::cerr << "Test" << std::endl;
  env.Reset();
  epsilon = 0.0;
  // A greedy path that never revisits a cell takes at most one step per
  // cell; exceeding that means the learned policy is cycling.
  const int max_test_steps = env.GetNumStates();
  int test_steps = 0;
  while (!env.isdone_) {
    if (test_steps >= max_test_steps) {
      std::cerr << "greedy policy did not terminate within " << max_test_steps
                << " steps" << std::endl;
      return 1;
    }
    ++test_steps;
    const int state = env.GetState();
    std::cerr << env.x_ << " " << env.y_ << std::endl;
    /* ---------- policy: epsilon-greedy ---------- */
    /* epsilon = 0, so all probability mass is on the greedy action */
    double max_q_value;
    int max_action;
    std::tie(max_q_value, max_action) = qfunc.GetMaxQValueAndAction(state);
    const double prob_epsilon = epsilon / env.GetNumActions();
    std::vector<double> policy(env.GetNumActions(), prob_epsilon);
    policy[max_action] = 1 - epsilon + prob_epsilon;
    std::discrete_distribution<> policy_prob(policy.begin(), policy.end()); // sample by probability
    const int action = policy_prob(rnd_engine); //https://cpplover.blogspot.jp/2011/03/blog-post_07.html
    /* ---------- show the current state ---------- */
    std::cerr << "Q-Value" << std::endl;
    for (int a = 0; a < env.GetNumActions(); a++) {
      std::cerr << qfunc.Call(state, a) << std::endl;
    }
    std::cerr << "action : " << action << std::endl;
    env.Show();
    /* ---------- act according to the policy ---------- */
    env.Move(action);
    std::this_thread::sleep_for(std::chrono::milliseconds(300));
  }
  return 0;
}
9 13
1 1 1 1 1 1 1 1 1 1 1 1 1
1 2 1 0 0 0 0 0 0 0 0 0 1
1 0 1 0 1 1 1 1 1 1 1 0 1
1 0 0 0 0 0 0 0 0 0 1 0 1
1 1 1 1 0 1 0 1 1 0 1 0 1
1 0 0 0 0 1 0 0 1 0 1 1 1
1 0 1 1 1 1 0 1 1 0 1 3 1
1 0 0 0 0 1 0 0 1 0 0 0 1
1 1 1 1 1 1 1 1 1 1 1 1 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment