Last active
April 8, 2020 13:53
-
-
Save cocomoff/1e708c2c3b47450f000a6d2176381593 to your computer and use it in GitHub Desktop.
Rust何も分からない練習帳
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use rand::distributions::{Bernoulli, Distribution}; | |
use rand::prelude::*; | |
// An Arm provides the following capability:
// - a reward can be observed by pulling it
trait Arm {
    /// Pull the arm once and return the observed reward.
    fn draw(&self) -> f64;
}
// An arm whose reward follows a Bernoulli distribution.
#[derive(Debug)]
struct BernoulliArm {
    p: f64, // probability of observing a reward of 1.0
}
impl Arm for BernoulliArm { | |
fn draw(&self) -> f64 { | |
let d = Bernoulli::new(self.p).unwrap(); | |
let v = d.sample(&mut rand::thread_rng()); | |
return if v { 1.0 } else { 0.0 }; | |
} | |
} | |
//
// An Algo provides the following capabilities:
// - initialise the algorithm
// - observe a reward and update its internal state
// - select an action (arm)
// - print its internal state to stdout
//
trait Algo {
    /// Construct the algorithm. `eps` is the exploration rate and
    /// `n_arms` the number of arms; implementations may ignore `eps`.
    fn new(eps: f64, n_arms: usize) -> Self;
    /// Observe `reward` for having pulled `chosen_arm` and update estimates.
    fn update(&mut self, chosen_arm: usize, reward: f64);
    /// Choose which arm to pull next.
    fn select_arm(&self) -> usize;
    /// Print the internal state to stdout.
    fn print_state(&self);
}
// ε-greedy arm-selection strategy.
struct EpsGreedy {
    eps: f64,          // exploration probability
    n_arms: usize,     // number of arms
    count: Vec<usize>, // pull count per arm
    value: Vec<f64>,   // estimated mean reward per arm
}
impl Algo for EpsGreedy { | |
// コンストラクタ (引いた回数と期待値の) | |
fn new(eps: f64, n_arms: usize) -> EpsGreedy { | |
return EpsGreedy { | |
eps: eps, | |
n_arms: n_arms, | |
count: vec![0; n_arms], | |
value: vec![0.0; n_arms], | |
}; | |
} | |
// 中身出力 | |
fn print_state(&self) { | |
println!("{:?}", self.eps); | |
println!("{:?}", self.count); | |
println!("{:?}", self.value); | |
} | |
// ε-greedy | |
fn select_arm(&self) -> usize { | |
let d = Bernoulli::new(self.eps).unwrap(); | |
let v = d.sample(&mut rand::thread_rng()); | |
let ans = if v { | |
let mut rng = thread_rng(); | |
rng.gen_range(0, self.n_arms) | |
} else { | |
let mut i = 0; | |
let mut maxi = self.value[i]; | |
for j in i + 1..self.n_arms { | |
if maxi < self.value[j] { | |
maxi = self.value[j]; | |
i = j; | |
} | |
} | |
i | |
}; | |
ans | |
} | |
fn update(&mut self, chosen_arm: usize, reward: f64) { | |
self.count[chosen_arm] += 1; | |
let n = self.count[chosen_arm] as f64; | |
let value = self.value[chosen_arm]; | |
let new_value = ((n - 1f64) / n) * value + (1f64 / n) * reward; | |
self.value[chosen_arm] = new_value; | |
} | |
} | |
// Strategy that picks an arm uniformly at random (no learning used for selection).
struct Random {
    n_arms: usize,     // number of arms
    count: Vec<usize>, // pull count per arm
    value: Vec<f64>,   // estimated mean reward per arm (tracked, not used to select)
}
impl Algo for Random { | |
// コンストラクタ (引いた回数と期待値の) | |
fn new(eps: f64, n_arms: usize) -> Random { | |
return Random { | |
n_arms: n_arms, | |
count: vec![0; n_arms], | |
value: vec![0.0; n_arms], | |
}; | |
} | |
// 中身出力 | |
fn print_state(&self) { | |
println!("Random out of {:?} arms", self.n_arms); | |
println!("{:?}", self.count); | |
println!("{:?}", self.value); | |
} | |
// ε-greedy | |
fn select_arm(&self) -> usize { | |
let mut rng = thread_rng(); | |
rng.gen_range(0, self.n_arms) | |
} | |
fn update(&mut self, chosen_arm: usize, reward: f64) { | |
self.count[chosen_arm] += 1; | |
let n = self.count[chosen_arm] as f64; | |
let value = self.value[chosen_arm]; | |
let new_value = ((n - 1f64) / n) * value + (1f64 / n) * reward; | |
self.value[chosen_arm] = new_value; | |
} | |
} | |
fn main() { | |
// arm | |
let list_theta = vec![0.1, 0.1, 0.1, 0.1, 0.9]; | |
let n_arms = list_theta.len(); | |
let mut arms = Vec::new(); | |
for theta in list_theta { | |
arms.push(BernoulliArm { p: theta }); | |
} | |
// epsilon-greedy strategy | |
let t_horizon = 500; | |
let list_eps = vec![0.025, 0.05, 0.1, 0.2, 0.4]; | |
for _eps in list_eps { | |
// EpsGreedy | |
let mut alg = EpsGreedy::new(_eps, n_arms); | |
for _t in 0..t_horizon { | |
let chosen_arm = alg.select_arm(); | |
let reward = arms[chosen_arm].draw(); | |
alg.update(chosen_arm, reward); | |
} | |
alg.print_state(); | |
} | |
// randoms trategy (_eps is not used = 0.0) | |
let mut alg = Random::new(0.0f64, n_arms); | |
for _t in 0..t_horizon { | |
let chosen_arm = alg.select_arm(); | |
let reward = arms[chosen_arm].draw(); | |
alg.update(chosen_arm, reward); | |
} | |
alg.print_state(); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment