// Gist shtsoft/ba971c25e5bd559cc19cd07dff9dc761, created March 16, 2023 16:50.
// Note (from the GitHub page): this file may contain bidirectional Unicode text that can be
// interpreted or compiled differently than it appears; review it in an editor that reveals
// hidden Unicode characters.
mod probabilistic_programming { | |
pub mod print_helpers {
    //! Console reporting for the weighted `(weight, outcome)` samples returned by the
    //! inference algorithms.

    /// Prints how the total weight splits between `true` and `false` outcomes, as percentages.
    ///
    /// Each entry of `vec` is a `(weight, outcome)` pair. With an empty `vec` (zero total
    /// weight) both percentages come out as `NaN`.
    pub fn print_bool_distribution(vec: Vec<(f64, bool)>) {
        let (trues, falses) = vec.into_iter().fold(
            (0.0_f64, 0.0_f64),
            |(true_weight, false_weight), (weight, outcome)| {
                if outcome {
                    (true_weight + weight, false_weight)
                } else {
                    (true_weight, false_weight + weight)
                }
            },
        );
        let total = trues + falses;
        println!(
            "BOOL DISTRIBUTION: {:.2}% TRUE vs {:.2}% FALSE",
            trues * 100.0 / total,
            falses * 100.0 / total
        );
    }

    /// Prints the weighted mean of the outcomes in `vec`, given as `(weight, outcome)` pairs.
    ///
    /// With an empty `vec` the mean is `0 / 0` and prints as `NaN`.
    pub fn print_mean(vec: Vec<(f64, f64)>) {
        let (weighted_total, total_weight) = vec.into_iter().fold(
            (0.0_f64, 0.0_f64),
            |(weighted_total, total_weight), (weight, outcome)| {
                (weighted_total + weight * outcome, total_weight + weight)
            },
        );
        println!("MEAN: {:.2}", weighted_total / total_weight);
    }
}
mod inference_algorithms { | |
/// One run of a model: the score handler it currently carries and, once the run has
/// completed, the value the model produced.
pub struct Particle<Handler, Outcome> {
    /// Score handler carried by this run; the inference algorithms read its `weight`.
    pub handler: Handler,
    /// `Some(value)` once the model has returned, `None` while the run is still pending.
    pub outcome: Option<Outcome>,
}

/// Weighted samples `(weight, outcome)` produced by an inference algorithm.
type Posterior<Outcome> = Vec<(f64, Outcome)>;

/// A probabilistic program, run by handing it the score handler to use.
type Model<H, O> = fn(H) -> Particle<H, O>;

/// The remainder of a model after a `score` call, resumed by passing it a handler.
type Continuation<'a, H, O> = Box<dyn Fn(H) -> Particle<H, O> + 'a>;

/// Interpretation of the `score` effect: each handler decides how to record the
/// likelihood `p` and how (or whether) to resume the model through `k`.
pub trait Likelihood<'a, O: Copy + 'a> {
    /// Records likelihood `p` for the current run and produces the resulting particle.
    fn score(self, p: f64, k: Continuation<'a, Self, O>) -> Particle<Self, O>
    where
        Self: Sized;
}

/// Evaluates `$execution` `$number_of_particles` times, pushing each result onto
/// `$particles` in order.
macro_rules! execute {
    ($execution:expr, $number_of_particles:expr, $particles:expr) => {{
        let mut remaining = $number_of_particles;
        while remaining > 0 {
            $particles.push($execution);
            remaining -= 1;
        }
    }};
}

/// Collects `(weight, outcome)` pairs from finished particles.
///
/// Panics (via `unwrap`) if any particle has no outcome yet.
macro_rules! measure {
    ($particles:expr) => {
        $particles
            .iter()
            .map(|particle| {
                let weight = particle.handler.weight;
                let outcome = particle.outcome.unwrap();
                (weight, outcome)
            })
            .collect()
    };
}
pub mod importance_sampling { | |
use super::*; | |
const DEFAULT_WEIGHT: f64 = 1.0; | |
pub struct WeighWeight { | |
weight: f64, | |
} | |
impl WeighWeight { | |
const fn new(weight: f64) -> Self { | |
Self { weight } | |
} | |
} | |
impl<'a, O: Copy + 'a> Likelihood<'a, O> for WeighWeight { | |
fn score(self, p: f64, k: Continuation<'a, Self, O>) -> Particle<Self, O> { | |
k(Self::new(self.weight * p)) | |
} | |
} | |
pub fn importance_sampling<'a, O: Copy + 'a>( | |
model: Model<WeighWeight, O>, | |
number_of_particles: usize, | |
) -> Posterior<O> { | |
let mut particles = Vec::with_capacity(number_of_particles); | |
execute!( | |
model(WeighWeight::new(DEFAULT_WEIGHT)), | |
number_of_particles, | |
particles | |
); | |
measure!(particles) | |
} | |
} | |
pub mod smc { | |
use super::*; | |
use rand::distributions::{Distribution, Uniform}; | |
use rand::thread_rng; | |
const DEFAULT_WEIGHT: f64 = 1.0; | |
pub struct PauseScoring<'a, O: Copy + 'a> { | |
weight: f64, | |
continuation: Option<Continuation<'a, PauseScoring<'a, O>, O>>, | |
multiplicity: usize, | |
} | |
impl<'a, O: Copy + 'a> PauseScoring<'a, O> { | |
fn new(weight: f64) -> Self { | |
Self { | |
weight, | |
continuation: None, | |
multiplicity: 0, | |
} | |
} | |
} | |
impl<'a, O: Copy + 'a> Likelihood<'a, O> for PauseScoring<'a, O> { | |
fn score(self, p: f64, k: Continuation<'a, Self, O>) -> Particle<Self, O> | |
where | |
Self: Sized, | |
{ | |
Particle { | |
handler: Self { | |
weight: self.weight * p, | |
continuation: Some(k), | |
multiplicity: self.multiplicity, | |
}, | |
outcome: None, | |
} | |
} | |
} | |
type Particles<'a, O> = Vec<Particle<PauseScoring<'a, O>, O>>; | |
fn resample<O: Copy>(particles: &mut Particles<O>) { | |
let mut sum_weight = 0.0; | |
for particle in particles.iter_mut() { | |
sum_weight += particle.handler.weight; | |
particle.handler.weight = sum_weight; | |
} | |
for _ in 0..(particles.len()) { | |
let weight_barrier = Uniform::new(0.0, sum_weight).sample(&mut thread_rng()); | |
for particle in particles.iter_mut() { | |
if weight_barrier < particle.handler.weight { | |
particle.handler.multiplicity += 1; | |
break; | |
} | |
} | |
} | |
} | |
fn propagate<O: Copy>(particles: &mut Particles<O>) { | |
let mut propagated_particles = Vec::new(); | |
for particle in particles.iter_mut() { | |
let outcome = particle.outcome; | |
let continuation = particle.handler.continuation.take(); | |
let multiplicity = particle.handler.multiplicity; | |
match (outcome, continuation) { | |
(None, None) => { | |
panic!("Something went terribly wrong!") | |
} | |
(None, Some(k)) => { | |
execute!( | |
k(PauseScoring::new(DEFAULT_WEIGHT)), | |
multiplicity, | |
propagated_particles | |
) | |
} | |
(Some(c), _) => { | |
execute!( | |
Particle { | |
handler: PauseScoring::new(DEFAULT_WEIGHT), | |
outcome: Some(c) | |
}, | |
multiplicity, | |
propagated_particles | |
) | |
} | |
}; | |
} | |
*particles = propagated_particles; | |
} | |
fn filter<O: Copy>(particles: &mut Particles<O>, number_of_filter_steps: usize) { | |
for _ in 0..number_of_filter_steps { | |
resample(particles); | |
propagate(particles); | |
} | |
} | |
fn finalize<O: Copy>(particles: &mut Particles<O>) { | |
for particle in particles.iter_mut() { | |
while particle.outcome.is_none() { | |
let continuation = particle.handler.continuation.take(); | |
let weight = particle.handler.weight; | |
match continuation { | |
None => panic!("Something went terribly wrong!"), | |
Some(k) => *particle = k(PauseScoring::new(weight)), | |
} | |
} | |
} | |
} | |
pub fn smc<'a, O: Copy + 'a>( | |
model: Model<PauseScoring<'a, O>, O>, | |
number_of_particles: usize, | |
number_of_filter_steps: usize, | |
) -> Posterior<O> { | |
let mut particles = Vec::with_capacity(number_of_particles); | |
execute!( | |
model(PauseScoring::new(DEFAULT_WEIGHT)), | |
number_of_particles, | |
particles | |
); | |
filter(&mut particles, number_of_filter_steps); | |
finalize(&mut particles); | |
measure!(particles) | |
} | |
} | |
pub use importance_sampling::*; | |
pub use smc::*; | |
} | |
pub use inference_algorithms::*; | |
pub use print_helpers::*; | |
use rand::distributions::{Bernoulli, Distribution}; | |
use rand::thread_rng; | |
use statrs::distribution::{Continuous, Normal}; | |
pub fn sprinkler<'a, H>(score_handler: H) -> Particle<H, bool> | |
where | |
H: Likelihood<'a, bool>, | |
{ | |
let rain = Bernoulli::new(0.2).unwrap().sample(&mut thread_rng()); | |
let sprinkler = Bernoulli::new(0.1).unwrap().sample(&mut thread_rng()); | |
let probability_lawn_wet = if rain { | |
if sprinkler { | |
0.99 | |
} else { | |
0.7 | |
} | |
} else if sprinkler { | |
0.9 | |
} else { | |
0.01 | |
}; | |
score_handler.score( | |
probability_lawn_wet, | |
Box::new(move |score_handler| Particle { | |
handler: score_handler, | |
outcome: Some(rain), | |
}), | |
) | |
} | |
trait RecursiveScoring { | |
fn fold<'a, H>(mut self, score_handler: H) -> Particle<H, f64> | |
where | |
H: Likelihood<'a, f64>, | |
Self: Copy + Sized + 'a, | |
{ | |
let outcome = self.outcome(); | |
let probability = self.probability(); | |
self = self.step(); | |
let is_base = self.is_base(); | |
if is_base { | |
score_handler.score( | |
probability, | |
Box::new(move |score_handler| Particle { | |
handler: score_handler, | |
outcome, | |
}), | |
) | |
} else { | |
score_handler.score( | |
probability, | |
Box::new(move |score_handler| self.fold(score_handler)), | |
) | |
} | |
} | |
fn is_base(&self) -> bool; | |
fn step(self) -> Self; | |
fn outcome(&self) -> Option<f64>; | |
fn probability(&self) -> f64; | |
} | |
#[derive(Copy, Clone)] | |
struct LinearRegression<'a, const N: usize> { | |
slope: f64, | |
data: &'a [(f64, f64); N], | |
index: usize, | |
variance: f64, | |
} | |
impl<'a, const N: usize> RecursiveScoring for LinearRegression<'a, N> { | |
fn step(mut self) -> Self { | |
self.index -= 1; | |
self | |
} | |
fn is_base(&self) -> bool { | |
self.index == 0 | |
} | |
fn outcome(&self) -> Option<f64> { | |
Some(self.slope) | |
} | |
fn probability(&self) -> f64 { | |
let slope = self.slope; | |
let point = self.data[self.index]; | |
let x = point.0; | |
let y = point.1; | |
let variance = self.variance; | |
Normal::new(slope * x, variance).unwrap().pdf(y) | |
} | |
} | |
pub fn linear_regression<'a, H, const N: usize>( | |
score_handler: H, | |
slope: f64, | |
data: &'a [(f64, f64); N], | |
variance: f64, | |
) -> Particle<H, f64> | |
where | |
H: Likelihood<'a, f64>, | |
{ | |
LinearRegression { | |
slope, | |
data, | |
index: N - 1, | |
variance, | |
} | |
.fold(score_handler) | |
} | |
const SIZE: usize = 3; | |
const DATA: [(f64, f64); SIZE] = [(1.0, 2.2), (2.0, 3.8), (3.0, 5.8)]; | |
const VARIANCE: f64 = 0.25; | |
pub fn linear_regression_example<'a, H>(score_handler: H) -> Particle<H, f64> | |
where | |
H: Likelihood<'a, f64>, | |
{ | |
linear_regression( | |
score_handler, | |
Normal::new(0.0, 3.0).unwrap().sample(&mut thread_rng()), | |
&DATA, | |
VARIANCE, | |
) | |
} | |
} | |
use probabilistic_programming::*; | |
fn main() { | |
println!("------------------------------"); | |
print_bool_distribution(importance_sampling(sprinkler, 1000)); | |
println!("------------------------------"); | |
print_bool_distribution(smc(sprinkler, 1000, 2)); | |
println!("------------------------------"); | |
print_mean(smc(linear_regression_example, 1000, 2)); | |
println!("------------------------------"); | |
} |
// End of gist (GitHub page footer: "Sign up for free to join this conversation on GitHub.
// Already have an account? Sign in to comment").