Last active
September 6, 2020 15:45
-
-
Save yytyd/a454c4f6ac23d93897efd038c4e6816b to your computer and use it in GitHub Desktop.
minituna の実装を移行して理解する
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use crate::minituna_v1::Objective; | |
use crate::minituna_v1::Trial; | |
use crate::minituna_v1::TrialError; | |
struct Quadratic; | |
impl Objective for Quadratic { | |
fn objective(&self, trial: Trial) -> Result<f64, TrialError> { | |
let x = trial.suggest_uniform("x", 0.0, 10.0); | |
let y = trial.suggest_uniform("y", 0.0, 10.0); | |
match (x, y) { | |
(Ok(x1), Ok(y1)) => Ok((x1 - 3.0).powi(2) + (y1 - 5.0).powi(2)), | |
(_, Err(ey1)) => Err(ey1), | |
(Err(ex1), _) => Err(ex1), | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use ordered_float::OrderedFloat; | |
use rand::prelude::*; | |
use rand::rngs::StdRng; | |
use rand::SeedableRng; | |
use std::{cell::RefCell, collections::HashMap}; | |
/// Error returned when an operation on a trial cannot be carried out.
///
/// The error carries a human-readable message only. It implements
/// [`std::fmt::Display`] and [`std::error::Error`] so it can be printed,
/// boxed as `Box<dyn Error>`, and used with `Result::unwrap`/`expect`.
#[derive(Debug, Clone, PartialEq)]
pub struct TrialError {
    message: String,
}

impl TrialError {
    /// Creates an error that owns a copy of `message`.
    pub fn new(message: &str) -> TrialError {
        TrialError {
            message: String::from(message),
        }
    }
}

impl std::fmt::Display for TrialError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for TrialError {}
/// Lifecycle state recorded for a trial.
///
/// The type is a plain fieldless enum, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrialState {
    /// The trial has been created and has not reached a final state yet.
    Running,
    /// The trial finished and its objective value was stored.
    Completed,
    /// The trial ended without a usable objective value.
    Failed,
}
#[derive(Clone)] | |
pub struct FrozenTrial { | |
trial_id: u32, | |
state: TrialState, | |
value: Option<OrderedFloat<f64>>, | |
params: HashMap<String, f64>, | |
} | |
impl FrozenTrial { | |
pub fn new(trial_id: u32) -> FrozenTrial { | |
FrozenTrial { | |
trial_id, | |
state: TrialState::Running, | |
value: None, | |
params: HashMap::new(), | |
} | |
} | |
pub fn is_finished(&self) -> bool { | |
self.state != TrialState::Running | |
} | |
} | |
#[derive(Clone)] | |
pub struct Storage { | |
trials: Vec<FrozenTrial>, | |
} | |
impl Storage { | |
pub fn create_new_trial(&mut self) -> u32 { | |
let trial_id = self.trials.len() as u32; | |
let trial = FrozenTrial::new(trial_id); | |
self.trials.push(trial); | |
trial_id | |
} | |
pub fn get_trial(&self, trial_id: u32) -> Option<FrozenTrial> { | |
self.trials.get(trial_id as usize).map(|v| v.clone()) | |
} | |
pub fn get_best_trial(&self) -> Option<FrozenTrial> { | |
let completed_trials: Vec<FrozenTrial> = self | |
.trials | |
.iter() | |
.filter(|trial| trial.state == TrialState::Completed) | |
.map(|v| v.clone()) | |
.collect(); | |
let best_trial = completed_trials.into_iter().min_by_key(|t| t.value); | |
best_trial | |
} | |
pub fn set_trial_value(&mut self, trial_id: u32, value: f64) -> Result<(), TrialError> { | |
let maybe_trial = self.trials.get_mut(trial_id as usize); | |
if let Some(trial) = maybe_trial { | |
if !trial.is_finished() { | |
return Err(TrialError::new("cannot update finished trial")); | |
} | |
trial.value = Some(OrderedFloat::from(value)); // TODO いけてんの?? | |
} | |
Ok(()) | |
} | |
pub fn set_trial_state(&mut self, trial_id: u32, state: TrialState) -> Result<(), TrialError> { | |
let maybe_trial = self.trials.get_mut(trial_id as usize); | |
if let Some(trial) = maybe_trial { | |
if !trial.is_finished() { | |
return Err(TrialError::new("cannot update finished trial")); | |
} | |
trial.state = state; | |
} | |
Ok(()) | |
} | |
pub fn set_trial_param( | |
&mut self, | |
trial_id: u32, | |
name: &str, | |
value: f64, | |
) -> Result<(), TrialError> { | |
let maybe_trial = self.trials.get_mut(trial_id as usize); | |
if let Some(trial) = maybe_trial { | |
if !trial.is_finished() { | |
return Err(TrialError::new("cannot update finished trial")); | |
} | |
trial.params.insert(name.to_string(), value); | |
} | |
Ok(()) | |
} | |
} | |
pub struct Trial { | |
study: RefCell<Study>, | |
trial_id: u32, | |
state: TrialState, | |
} | |
impl Trial { | |
pub fn new(trial_id: u32, study: &Study) -> Self { | |
Trial { | |
study: RefCell::new(study.clone()), // TODO check performance | |
trial_id, | |
state: TrialState::Running, | |
} | |
} | |
pub fn suggest_uniform(&self, name: &str, low: f64, high: f64) -> Result<f64, TrialError> { | |
let maybe_trial = self.study.borrow().storage.get_trial(self.trial_id); | |
if let Some(trial) = maybe_trial { | |
let mut distribution = HashMap::new(); | |
distribution.insert(String::from("low"), low); | |
distribution.insert(String::from("high"), high); | |
let param = self.study.borrow_mut().sampler.sample_independent( | |
&self.study.borrow(), | |
&trial, | |
name, | |
distribution, | |
); | |
match self | |
.study | |
.borrow_mut() | |
.storage | |
.set_trial_param(self.trial_id, name, param) | |
{ | |
Ok(_) => Ok(param), | |
Err(err) => Err(err), | |
} | |
} else { | |
Err(TrialError::new("Not found specific trial")) | |
} | |
} | |
} | |
#[derive(Clone)] | |
pub struct Sampler { | |
rng: StdRng, | |
} | |
impl Sampler { | |
pub fn new(seed: u64) -> Self { | |
let rng = SeedableRng::seed_from_u64(seed); | |
Sampler { rng } | |
} | |
pub fn sample_independent( | |
&mut self, | |
_study: &Study, | |
_trial: &FrozenTrial, | |
_name: &str, | |
distribution: HashMap<String, f64>, | |
) -> f64 { | |
assert!(distribution.get("low").is_some()); | |
assert!(distribution.get("high").is_some()); | |
self.rng.gen_range( | |
distribution.get("low").unwrap(), | |
distribution.get("high").unwrap(), | |
) | |
} | |
} | |
#[derive(Clone)] | |
pub struct Study { | |
storage: Storage, | |
sampler: Sampler, | |
} | |
impl Study { | |
pub fn optimize<T: Objective>(&mut self, objective: T, n_trials: u32) { | |
for _ in 0..n_trials { | |
let trial_id = self.storage.create_new_trial(); | |
let trial = Trial::new(trial_id, self); | |
let value = objective.objective(trial); | |
let result = value | |
.and_then(|v| self.storage.set_trial_value(trial_id, v)) | |
.and_then(|_| { | |
self.storage | |
.set_trial_state(trial_id, TrialState::Completed) | |
}); | |
match result { | |
Ok(()) => (), | |
Err(err) => eprintln!("trial_id={} is failed by {}", trial_id, err.message), | |
} | |
} | |
} | |
pub fn best_trial(self) -> Option<FrozenTrial> { | |
self.storage.get_best_trial() | |
} | |
} | |
pub trait Objective { | |
fn objective(&self, trial: Trial) -> Result<f64, TrialError>; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment