/// Augmented 3x3 system `[A | b]`, stored row-major: each row holds three
/// coefficients followed by its right-hand-side value.
type Mat = [[f64; 4]; 3];
/// Precision passed to `chopped` for the chopped-arithmetic runs in `main`.
const P: Option<f64> = Some(0.1);
/// Applies the file's "chopped" arithmetic to a single value.
///
/// With `p == None` the value is returned unchanged. With `Some(p)` the value
/// is first scaled by `1 + p`, then truncated toward zero to a whole multiple
/// of `p / 10`. For `p = 0.1` that means: multiply by 1.1 and keep two decimal
/// places, so `1.0` becomes `1.1` and `1/3` becomes `0.36`.
///
/// NOTE(review): the `1 + p` scaling is applied on every call, so each chopped
/// operation in `reduce`/`backprop` compounds it; confirm that is the intended
/// model. Also, `trunc` on a quotient that lands just below an integer (e.g.
/// `28.999999999999996`) drops a whole unit, so a value that is decimal-exact
/// on paper can lose one unit in the last kept place.
fn chopped(exact: f64, p: Option<f64>) -> f64 {
    match p {
        None => exact,
        Some(precision) => {
            // Size of one unit in the last kept decimal place.
            let step = precision / 10.0;
            let scaled = exact * (1.0 + precision);
            (scaled / step).trunc() * step
        }
    }
}
/// One Gaussian-elimination step on the augmented matrix `mat`.
///
/// Uses `mat[row][column]` as the pivot to clear column `column` in every row
/// below `row`, and returns the updated copy (`mat` is taken by value).
///
/// For each lower row:
/// - the multiplier, every product and every difference go through `chopped`
///   with precision `p`;
/// - the eliminated entry is set to exactly `0.0` rather than computed;
/// - columns left of `column` are not touched (presumably already zero from
///   earlier steps).
///
/// NOTE(review): a zero pivot is not checked; the division then produces an
/// infinite or NaN multiplier.
fn reduce(mat: Mat, column: usize, row: usize, p: Option<f64>) -> Mat {
    let pivot_row = &mat[row];
    let mut out = mat;
    for target in (row + 1)..mat.len() {
        let src = &mat[target];
        let factor = chopped(src[column] / pivot_row[column], p);
        out[target][column] = 0.0;
        for k in (column + 1)..src.len() {
            let product = chopped(factor * pivot_row[k], p);
            out[target][k] = chopped(src[k] - product, p);
        }
    }
    out
}
/// Partial pivoting for column `i`.
///
/// Stably reorders rows `i..` of `v` so that `|v[row][i]|` is in descending
/// order, which puts the largest-magnitude candidate pivot in row `i`. Rows
/// above `i` are left untouched.
///
/// # Panics
/// Panics if `i` is greater than the number of rows, or if any compared entry
/// in column `i` is NaN (such entries have no ordering).
fn pivot(mut v: Mat, i: usize) -> Mat {
    v[i..].sort_by(|upper, lower| {
        let a = upper[i].abs();
        let b = lower[i].abs();
        // Comparing `b` against `a` (not `a` against `b`) yields descending order.
        b.partial_cmp(&a)
            .expect("pivot: column entries must not be NaN")
    });
    v
}
/// Solves the upper-triangular augmented system `v = [U | b]` by back
/// substitution and returns `x`, where `x[i]` is the `i`-th unknown.
///
/// Every multiplication, subtraction and division goes through `chopped` with
/// precision `p`, so `p == None` gives plain `f64` arithmetic.
///
/// Only entries on or above the diagonal (plus the right-hand-side column) are
/// read; anything below the diagonal is ignored. A zero diagonal entry is not
/// checked and yields an infinite or NaN component rather than a panic.
fn backprop(v: Mat, p: Option<f64>) -> [f64; 3] {
    let mut x = [0.0; 3];
    // Index of the right-hand-side column `b`.
    let rhs = v[0].len() - 1;
    // Walk rows bottom-up so every unknown right of the diagonal is already
    // solved by the time row `r` needs it.
    for r in (0..v.len()).rev() {
        let mut b = v[r][rhs];
        for c in (r + 1)..rhs {
            let t = chopped(v[r][c] * x[c], p);
            b = chopped(b - t, p);
        }
        x[r] = chopped(b / v[r][r], p);
    }
    x
}
/// Runs Gaussian elimination on `mat` twice and prints each intermediate
/// matrix and the resulting solution.
///
/// The first pass uses no pivoting; the second pivots before each elimination
/// step. Every arithmetic step uses precision `precision` (see `chopped`).
fn report_elimination(mat: Mat, precision: Option<f64>) {
    // Without partial pivoting: eliminate column 0, then column 1.
    let r1 = reduce(mat, 0, 0, precision);
    println!("r1 = `{:?}`", r1);
    let r2 = reduce(r1, 1, 1, precision);
    println!("r2 = `{:?}`", r2);
    let b1 = backprop(r2, precision);
    println!("without partial pivoting: b = `{:?}`", b1);

    // With partial pivoting: reorder rows before each elimination step.
    let p1 = pivot(mat, 0);
    println!("p1 = `{:?}`", p1);
    let r1 = reduce(p1, 0, 0, precision);
    println!("r1 = `{:?}`", r1);
    let p2 = pivot(r1, 1);
    println!("p2 = `{:?}`", p2);
    let r2 = reduce(p2, 1, 1, precision);
    println!("r2 = `{:?}`", r2);
    let b2 = backprop(r2, precision);
    println!("with partial pivoting: b = `{:?}`", b2);
}

fn main() {
    // Augmented system [A | b], stored row-major: three coefficient columns
    // followed by the right-hand side.
    let initial: Mat = [
        [1.0, 1.0 / 2.0, 1.0 / 3.0, 1.0],
        [1.0 / 2.0, 1.0 / 3.0, 1.0 / 4.0, 0.0],
        [1.0 / 3.0, 1.0 / 4.0, 1.0 / 3.0, 0.0],
    ];
    let mut initial_chopped: Mat = initial;
    for row in initial_chopped.iter_mut() {
        for x in row.iter_mut() {
            *x = chopped(*x, P);
        }
    }
    println!("initial = `{:?}`", initial);
    println!("initial_chopped = `{:?}`", initial_chopped);
    report_elimination(initial_chopped, P);
    println!();
    println!("******* without chopping: *******");
    // NOTE(review): this pass also starts from `initial_chopped` (not
    // `initial`), so it isolates the effect of chopped arithmetic on the same
    // input data. Confirm that is intended.
    report_elimination(initial_chopped, None);
}
/// With `Some(0.1)`, `chopped` scales by 1.1 and keeps two decimals, so `1.0`
/// maps to `1.1`. With `None`, values pass through untouched.
#[test]
fn chopped_test() {
    // `chopped` takes an `Option<f64>` precision, so it must be wrapped in `Some`.
    assert_eq!(chopped(1.0, Some(0.1)), 1.1);
    assert_eq!(chopped(0.25, None), 0.25);
}
Last active
September 8, 2016 14:18
-
-
Save DiamondLovesYou/e3c910af4fb7d71352f8d6d6dc03d835 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-*- mode: compilation; default-directory: "~/workspace/m781_2/src/" -*- | |
Compilation started at Thu Sep 8 09:16:36 | |
cargo run | |
Compiling num-traits v0.1.35 | |
Compiling m781_2 v0.1.0 (file:///C:/msys64/home/Richard/workspace/m781_2) | |
Finished debug [unoptimized + debuginfo] target(s) in 5.85 secs | |
Running `c:\msys64\home\Richard\workspace\m781_2\target\debug\m781_2.exe` | |
initial = `[[1, 0.5, 0.3333333333333333, 1], [0.5, 0.3333333333333333, 0.25, 0], [0.3333333333333333, 0.25, 0.3333333333333333, 0]]` | |
initial_chopped = `[[1.1, 0.55, 0.36, 1.1], [0.55, 0.36, 0.27, 0], [0.36, 0.27, 0.36, 0]]` | |
r1 = `[[1.1, 0.55, 0.36, 1.1], [0, 0.03, 0.06, -0.72], [0, 0.06, 0.24, -0.47000000000000003]]` | |
r2 = `[[1.1, 0.55, 0.36, 1.1], [0, 0.03, 0.06, -0.72], [0, 0, 0.1, 1.3900000000000001]]` | |
without partial pivoting: b = `[1.07, 1.25, -0.15]` | |
p1 = `[[1.1, 0.55, 0.36, 1.1], [0.55, 0.36, 0.27, 0], [0.36, 0.27, 0.36, 0]]` | |
r1 = `[[1.1, 0.55, 0.36, 1.1], [0, 0.03, 0.06, -0.72], [0, 0.06, 0.24, -0.47000000000000003]]` | |
p2 = `[[1.1, 0.55, 0.36, 1.1], [0, 0.06, 0.24, -0.47000000000000003], [0, 0.03, 0.06, -0.72]]` | |
r2 = `[[1.1, 0.55, 0.36, 1.1], [0, 0.06, 0.24, -0.47000000000000003], [0, 0, -0.08, -0.48]]` | |
with partial pivoting: b = `[1.3, 1.02, -0.15]` | |
******* without chopping: ******* | |
r1 = `[[1.1, 0.55, 0.36, 1.1], [0, 0.08499999999999996, 0.09000000000000002, -0.55], [0, 0.09000000000000002, 0.2421818181818182, -0.36]]` | |
r2 = `[[1.1, 0.55, 0.36, 1.1], [0, 0.08499999999999996, 0.09000000000000002, -0.55], [0, 0, 0.1468877005347593, 0.22235294117647098]]` | |
without partial pivoting: b = `[0.8531122994652407, 0.923909090909091, -0.09000000000000002]` | |
p1 = `[[1.1, 0.55, 0.36, 1.1], [0.55, 0.36, 0.27, 0], [0.36, 0.27, 0.36, 0]]` | |
r1 = `[[1.1, 0.55, 0.36, 1.1], [0, 0.08499999999999996, 0.09000000000000002, -0.55], [0, 0.09000000000000002, 0.2421818181818182, -0.36]]` | |
p2 = `[[1.1, 0.55, 0.36, 1.1], [0, 0.09000000000000002, 0.2421818181818182, -0.36], [0, 0.08499999999999996, 0.09000000000000002, -0.55]]` | |
r2 = `[[1.1, 0.55, 0.36, 1.1], [0, 0.09000000000000002, 0.2421818181818182, -0.36], [0, 0, -0.13872727272727253, -0.2100000000000003]]` | |
with partial pivoting: b = `[1.1387272727272726, 0.8183636363636364, -0.09000000000000002]` | |
Compilation finished at Thu Sep 8 09:16:46 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment