Skip to content

Instantly share code, notes, and snippets.

@cocomoff
Last active April 8, 2020 13:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cocomoff/1e708c2c3b47450f000a6d2176381593 to your computer and use it in GitHub Desktop.
Rust practice notebook: ε-greedy vs. uniform-random multi-armed bandit strategies
use rand::distributions::{Bernoulli, Distribution};
use rand::prelude::*;
// An Arm has the following capability:
// - drawing it yields an observable reward
trait Arm {
// Pull the arm once and return the observed reward.
fn draw(&self) -> f64;
}
// An arm whose reward follows a Bernoulli distribution:
// draw() returns 1.0 with probability p and 0.0 otherwise.
#[derive(Debug)]
struct BernoulliArm {
// Success probability; assumed to lie in [0, 1] — the `unwrap` on
// `Bernoulli::new` in draw() panics otherwise.
p: f64,
}
impl Arm for BernoulliArm {
    /// Draw the arm once: sample a Bernoulli(p) trial and map success
    /// to a reward of 1.0, failure to 0.0.
    ///
    /// Panics if `p` is outside [0, 1] (the `unwrap` on `Bernoulli::new`).
    fn draw(&self) -> f64 {
        let dist = Bernoulli::new(self.p).unwrap();
        // Tail expression instead of an explicit `return`.
        if dist.sample(&mut rand::thread_rng()) {
            1.0
        } else {
            0.0
        }
    }
}
//
// An Algo (bandit algorithm) provides:
// - construction / initialization
// - observing a reward and updating internal state
// - selecting an action (arm index)
// - printing internal state to stdout
//
trait Algo {
// Construct the algorithm for `n_arms` arms; `eps` is the exploration
// rate (ignored by implementations that do not need it, e.g. Random).
fn new(eps: f64, n_arms: usize) -> Self;
// Record that `chosen_arm` was pulled and produced `reward`.
fn update(&mut self, chosen_arm: usize, reward: f64);
// Pick the index of the next arm to pull.
fn select_arm(&self) -> usize;
// Dump internal state (counts / value estimates) to stdout.
fn print_state(&self);
}
// ε-greedy strategy: with probability eps pick a uniformly random arm
// (explore), otherwise pick the arm with the highest estimated value
// (exploit).
struct EpsGreedy {
// Exploration probability ε.
eps: f64,
// Number of arms.
n_arms: usize,
// Per-arm pull counts.
count: Vec<usize>,
// Per-arm running mean of observed rewards.
value: Vec<f64>,
}
impl Algo for EpsGreedy {
    /// Construct an ε-greedy learner over `n_arms` arms with exploration
    /// rate `eps`; pull counts and value estimates start at zero.
    fn new(eps: f64, n_arms: usize) -> EpsGreedy {
        // Field-init shorthand; tail expression instead of `return`.
        EpsGreedy {
            eps,
            n_arms,
            count: vec![0; n_arms],
            value: vec![0.0; n_arms],
        }
    }

    /// Print ε, then the per-arm pull counts and value estimates.
    fn print_state(&self) {
        println!("{:?}", self.eps);
        println!("{:?}", self.count);
        println!("{:?}", self.value);
    }

    /// ε-greedy selection: explore (uniform random arm) with probability
    /// ε, otherwise exploit the current best value estimate.
    ///
    /// Panics if `eps` is outside [0, 1] (the `unwrap` on `Bernoulli::new`).
    fn select_arm(&self) -> usize {
        let explore = Bernoulli::new(self.eps)
            .unwrap()
            .sample(&mut rand::thread_rng());
        if explore {
            // Explore: uniform over [0, n_arms).
            thread_rng().gen_range(0, self.n_arms)
        } else {
            // Exploit: first arm with the highest estimated value
            // (strict `>` keeps the earliest index on ties, matching
            // the original scan order).
            let mut best = 0;
            for j in 1..self.n_arms {
                if self.value[j] > self.value[best] {
                    best = j;
                }
            }
            best
        }
    }

    /// Incremental-mean update of the chosen arm's estimate:
    /// new = ((n-1)/n) * old + (1/n) * reward, where n is the pull
    /// count *after* this observation.
    fn update(&mut self, chosen_arm: usize, reward: f64) {
        self.count[chosen_arm] += 1;
        let n = self.count[chosen_arm] as f64;
        let value = self.value[chosen_arm];
        self.value[chosen_arm] = ((n - 1f64) / n) * value + (1f64 / n) * reward;
    }
}
// Baseline strategy: picks an arm uniformly at random, ignoring all
// reward history (counts/values are tracked for reporting but never
// consulted when selecting).
struct Random {
// Number of arms.
n_arms: usize,
// Per-arm pull counts.
count: Vec<usize>,
// Per-arm running mean of observed rewards.
value: Vec<f64>,
}
impl Algo for Random {
    /// Construct a uniform-random learner over `n_arms` arms.
    /// The exploration rate is irrelevant here, so the trait-required
    /// parameter is named `_eps` to mark it (and silence the
    /// unused-variable warning).
    fn new(_eps: f64, n_arms: usize) -> Random {
        Random {
            n_arms,
            count: vec![0; n_arms],
            value: vec![0.0; n_arms],
        }
    }

    /// Print the arm count, then per-arm pull counts and value estimates.
    fn print_state(&self) {
        println!("Random out of {:?} arms", self.n_arms);
        println!("{:?}", self.count);
        println!("{:?}", self.value);
    }

    /// Uniform-random selection over [0, n_arms); reward history is
    /// deliberately ignored.
    fn select_arm(&self) -> usize {
        thread_rng().gen_range(0, self.n_arms)
    }

    /// Incremental-mean update of the chosen arm's estimate (kept for
    /// reporting only; select_arm never reads it).
    fn update(&mut self, chosen_arm: usize, reward: f64) {
        self.count[chosen_arm] += 1;
        let n = self.count[chosen_arm] as f64;
        let value = self.value[chosen_arm];
        self.value[chosen_arm] = ((n - 1f64) / n) * value + (1f64 / n) * reward;
    }
}
fn main() {
// arm
let list_theta = vec![0.1, 0.1, 0.1, 0.1, 0.9];
let n_arms = list_theta.len();
let mut arms = Vec::new();
for theta in list_theta {
arms.push(BernoulliArm { p: theta });
}
// epsilon-greedy strategy
let t_horizon = 500;
let list_eps = vec![0.025, 0.05, 0.1, 0.2, 0.4];
for _eps in list_eps {
// EpsGreedy
let mut alg = EpsGreedy::new(_eps, n_arms);
for _t in 0..t_horizon {
let chosen_arm = alg.select_arm();
let reward = arms[chosen_arm].draw();
alg.update(chosen_arm, reward);
}
alg.print_state();
}
// randoms trategy (_eps is not used = 0.0)
let mut alg = Random::new(0.0f64, n_arms);
for _t in 0..t_horizon {
let chosen_arm = alg.select_arm();
let reward = arms[chosen_arm].draw();
alg.update(chosen_arm, reward);
}
alg.print_state();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment