// Created April 22, 2020 05:01
// Gist: cocomoff/2b9e85acd2f885240fa40ab22945e0e0
// とてもシンプルなボルツマンマシン的なやつの計算
// (A very simple calculation for a Boltzmann-machine-like model.)
/// A very small Boltzmann machine over `n` units labelled `1, 2, ..., n`.
///
/// For a configuration `x` (where `x[k]` is the state of unit `k + 1`) the
/// energy is `E(x) = -(Σ_i b_i·[x_i == 1] + Σ_(i,j) w_ij·x_i·x_j)` and the
/// partition function is `Z = Σ_x exp(-E(x))`.
#[derive(Debug, Clone, PartialEq)]
struct Boltzmann {
    /// Number of units; units are labelled `1, 2, ..., n`.
    n: usize,
    /// Edges as pairs of 1-based unit labels.
    edges: Vec<(usize, usize)>,
    /// Per-unit bias: `bias[k]` belongs to unit `k + 1`.
    bias: Vec<f64>,
    /// Per-edge weight: `weights[k]` belongs to `edges[k]`.
    weights: Vec<f64>,
}

impl Boltzmann {
    /// Energy `E(x)` of a single configuration `x` (note the leading minus).
    ///
    /// `x[k]` is the state of unit `k + 1`; only the first `n` entries feed
    /// the bias term. A unit's bias contributes only when its state is exactly
    /// `1`, while each edge contributes `w_ij * x_i * x_j` using the raw
    /// states. For binary (0/1) states both forms agree with the textbook
    /// `-(Σ b_i x_i + Σ w_ij x_i x_j)`.
    ///
    /// # Panics
    /// If `x` has fewer than `n` entries, if `bias` has fewer than `n`
    /// entries, if `edges` and `weights` differ in length, or if an edge
    /// endpoint is `0` or lies outside `x`.
    fn phi(&self, x: &[usize]) -> f64 {
        assert!(
            x.len() >= self.n,
            "configuration has {} entries but the machine has {} units",
            x.len(),
            self.n
        );
        // A silent mismatch would drop weights or misalign them with edges.
        assert_eq!(
            self.edges.len(),
            self.weights.len(),
            "expected one weight per edge"
        );

        // Edges use 1-based unit labels; translate them to 0-based indices
        // into `x` with explicit, readable failures instead of an underflow.
        let state = |unit: usize| -> f64 {
            let idx = unit
                .checked_sub(1)
                .expect("edge endpoints must be 1-based unit labels");
            match x.get(idx) {
                Some(&v) => v as f64,
                None => panic!("edge endpoint {} is outside the configuration", unit),
            }
        };

        let mut total = 0.0;
        for (&xi, &b) in x[..self.n].iter().zip(&self.bias[..self.n]) {
            total += if xi == 1 { b } else { 0.0 };
        }
        for (&(i, j), &w) in self.edges.iter().zip(&self.weights) {
            total += w * state(i) * state(j);
        }
        -total
    }

    /// Partition function `Z = Σ_x exp(-E(x))` over the given configurations.
    ///
    /// Despite its name this returns `Z`, not an energy; pass all `2^n`
    /// binary configurations to obtain the model's true partition function.
    /// As a trace, prints one line per configuration: the state, its energy
    /// `E(x)` and its unnormalised weight `exp(-E(x))`.
    ///
    /// # Panics
    /// Under the same conditions as [`Boltzmann::phi`].
    pub fn energy(&self, all_x: &[Vec<usize>]) -> f64 {
        let mut z = 0.0;
        for x in all_x {
            let e = self.phi(x);
            let weight = (-e).exp();
            println!("{:?} {} {}", x, e, weight);
            z += weight;
        }
        z
    }
}
/* | |
example | |
(1) (2) | |
| \ | | |
| \ | | |
| \ | | |
(4)---(3) | |
Th(1,2,3,4) = (0.5, 0.1, 0.3, 0.4) | |
Th(13,14,23,34) = (0.6, 0.3, 0.9, 0.2) | |
*/ | |
fn main() { | |
// n = 4の例 | |
let all_x = vec![ | |
[0,0,0,0], [0,0,0,1], [0,0,1,0], [0,0,1,1], | |
[0,1,0,0], [0,1,0,1], [0,1,1,0], [0,1,1,1], | |
[1,0,0,0], [1,0,0,1], [1,0,1,0], [1,0,1,1], | |
[1,1,0,0], [1,1,0,1], [1,1,1,0], [1,1,1,1], | |
]; | |
let mut all_x_vec = Vec::new(); | |
for a in &all_x { all_x_vec.push(a.to_vec()); } | |
// 例の分配関数を求める | |
let n = 4; | |
let edges = vec![(1, 3), (1, 4), (2, 3), (3, 4)]; | |
let bias = vec![0.5, 0.1, 0.3, 0.4]; | |
let weights = vec![0.6, 0.3, 0.9, 0.2]; | |
let bm = Boltzmann {n:n, edges:edges, bias:bias, weights:weights}; | |
let energy = bm.energy(&all_x_vec); | |
println!("{:?}", bm); | |
println!("{}", energy); | |
} |
// Sign up for free to join this conversation on GitHub.
// Already have an account? Sign in to comment.