@inferrna
Last active August 4, 2020 16:06
An attempt to build a multilayer perceptron based on https://github.com/raskr/rust-autograd
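The code trains the network to recover a normalized angle in [0, 1) from its (sin, cos) encoding, using mean squared error and the Adam optimizer. No manifest is included; the code presumably assumes a Cargo.toml whose only direct dependency is autograd (around version 1.0, the current release when the gist was last active), since ndarray and rand are reached through autograd's re-exports.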
// Everything is reached through autograd's re-exports (ndarray and rand
// included), so no direct dependency on those crates is needed. Using the
// re-exported rand also guarantees that the RNG and the SliceRandom trait
// come from the same crate version.
extern crate autograd as ag;

use ag::ndarray::s;
use ag::ndarray_ext as arr;
use ag::optimizers::adam;
use ag::rand::seq::SliceRandom;
use ag::rand::RngCore;
use ag::tensor::Variable;
use ag::NdArray;
use std::ops::{Add, Div, Mul};
const TAU: f32 = 6.28318530717958647692528676655900577_f32;

/// Random permutation of 0..size, used to visit mini-batches in random order.
fn get_permutation(size: usize) -> Vec<usize> {
    let mut perm: Vec<usize> = (0..size).collect();
    perm.shuffle(&mut ag::rand::thread_rng());
    perm
}
fn main() {
    let rng = ag::ndarray_ext::ArrayRng::<f32>::default();
    let w_arr = arr::into_shared(rng.glorot_uniform(&[2, 1]));
    //Additional layer(s)
    let w2_arr = arr::into_shared(rng.glorot_uniform(&[1, 32]));
    let w3_arr = arr::into_shared(rng.glorot_uniform(&[32, 1]));
    let b_arr = arr::into_shared(arr::zeros(&[1, 1]));
    let adam_state = adam::AdamState::new(&[&w_arr, &w2_arr, &w3_arr, &b_arr]);

    let num_samples = 10000;
    let max_epoch = 3000;
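    // Each sample encodes an angle twice: the input is the point (sin, cos)
    // on the unit circle rescaled from [-1, 1] to [0, 1], and the target is
    // the normalized angle i itself. E.g. i = 0.25 (a quarter turn) gives
    // sin(TAU * 0.25) = 1 and cos(TAU * 0.25) = 0, i.e. the input [1.0, 0.5]
    // with target 0.25. The network has to decode the angle back from its
    // circle coordinates.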
    //All data
    let (mut x_values, mut y_values): (Vec<Vec<f32>>, Vec<f32>) = (0..num_samples)
        .map(|i| i as f32 / num_samples as f32)
        .map(|i| {
            (vec![i.mul(TAU).sin().add(1.0).div(2.0),
                  i.mul(TAU).cos().add(1.0).div(2.0)],
             i)
        })
        .unzip();
    //Test data: move a random 1% of the samples into a held-out test set
    let (x_test, y_test) = {
        let mut rnd = ag::rand::thread_rng();
        let mut x_test: Vec<Vec<f32>> = vec![];
        let mut y_test: Vec<f32> = vec![];
        for _ in 0..num_samples / 100 {
            let idx = rnd.next_u64() as usize % x_values.len();
            x_test.push(x_values.remove(idx));
            y_test.push(y_values.remove(idx));
        }
        (x_test, y_test)
    };
    let train_samples = y_values.len();
    let test_samples = y_test.len();
    let x_train: Vec<f32> = x_values.into_iter().flatten().collect();
    let x_test: Vec<f32> = x_test.into_iter().flatten().collect();

    //Check data correctness
    let sanity_test_idx = 69;
    assert_eq!(y_values[sanity_test_idx].mul(TAU).sin().add(1.0).div(2.0), x_train[sanity_test_idx * 2]);
    assert_eq!(y_values[sanity_test_idx].mul(TAU).cos().add(1.0).div(2.0), x_train[sanity_test_idx * 2 + 1]);
    assert_eq!(y_test[sanity_test_idx].mul(TAU).sin().add(1.0).div(2.0), x_test[sanity_test_idx * 2]);
    assert_eq!(y_test[sanity_test_idx].mul(TAU).cos().add(1.0).div(2.0), x_test[sanity_test_idx * 2 + 1]);
    //Convert data to ndarray
    let as_arr = NdArray::from_shape_vec;
    let y_train = as_arr(ag::ndarray::IxDyn(&[train_samples, 1]), y_values).unwrap();
    let x_train = as_arr(ag::ndarray::IxDyn(&[train_samples, 2]), x_train).unwrap();
    let y_test = as_arr(ag::ndarray::IxDyn(&[test_samples, 1]), y_test).unwrap();
    let x_test = as_arr(ag::ndarray::IxDyn(&[test_samples, 2]), x_test).unwrap();
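    // Note: autograd builds a fresh graph inside each `ag::with` scope, but
    // the trainable state lives in the shared arrays (w_arr .. b_arr) created
    // with into_shared; the Adam update ops mutate those arrays in place, so
    // re-declaring the variables every epoch does not reset the weights.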
    //Train
    for epoch in 0..max_epoch {
        ag::with(|g| {
            let w = g.variable(w_arr.clone());
            let w2 = g.variable(w2_arr.clone());
            let w3 = g.variable(w3_arr.clone());
            let b = g.variable(b_arr.clone());
            let x = g.placeholder(&[-1, 2]);
            let y = g.placeholder(&[-1, 1]);

            // Forward pass: three matmuls, sigmoid on the output only
            let xw = g.matmul(x, w);
            let xww2 = g.matmul(xw, w2);
            let z = g.sigmoid(g.matmul(xww2, w3) + b);
            let mean_loss = g.reduce_mean(g.square(g.sub(z, &y)), &[0, 1], false);

            if epoch % 1000 == 0 || (epoch < 1000 && epoch % 10 == 0) {
                let loss = mean_loss
                    .eval(&[x.given(x_train.view()), y.given(y_train.view())])
                    .unwrap();
                println!("Epoch {}, train error: {:?}", epoch, loss.view());
            }

            let grads = &g.grad(&[&mean_loss], &[w, w2, w3, b]);
            let update_ops: &[ag::Tensor<f32>] =
                &adam::Adam::default().compute_updates(&[w, w2, w3, b], grads, &adam_state, g);

            // One pass over all mini-batches, visited in random order
            let batch_size = 50isize;
            let num_batches = train_samples / batch_size as usize;
            for i in get_permutation(num_batches) {
                let i = i as isize * batch_size;
                let y_batch = y_train.slice(s![i..i + batch_size, ..]).into_dyn();
                let x_batch = x_train.slice(s![i..i + batch_size, ..]).into_dyn();
                g.eval(update_ops, &[x.given(x_batch), y.given(y_batch)]);
            }
        });
    }
    //Test
    ag::with(|g| {
        let w = g.variable(w_arr.clone());
        let w2 = g.variable(w2_arr.clone());
        let w3 = g.variable(w3_arr.clone());
        let b = g.variable(b_arr.clone());
        let x = g.placeholder(&[-1, 2]);
        let y = g.placeholder(&[-1, 1]);

        // -- test: same forward pass as in training --
        let xw = g.matmul(x, w);
        let xww2 = g.matmul(xw, w2);
        let z = g.sigmoid(g.matmul(xww2, w3) + b);
        let predictions = z;
        // Mean squared error on the held-out set (the gist called this
        // "accuracy", but it is the same squared-error loss as in training)
        let test_loss = g.reduce_mean(g.square(g.sub(predictions, &y)), &[0, 1], false);
        let loss = test_loss
            .eval(&[x.given(x_test.view()), y.given(y_test.view())])
            .unwrap();
        let values = predictions
            .eval(&[x.given(x_test.view()), y.given(y_test.view())])
            .unwrap();
        println!(
            "test error: {:?}, result values = \n{:?}\noriginal values = \n{:?}",
            loss.view(), values.view(), y_test.view()
        );
    })
}
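Note that the forward pass applies no activation between the three matmuls, so w, w2 and w3 collapse into a single 2x1 linear map before the final sigmoid: the extra layers add parameters but no expressive power, and the 2->1 first layer bottlenecks the input to one scalar. A minimal sketch of the forward pass with hidden nonlinearities, assuming autograd exposes tanh as a graph op in the same way as the sigmoid used above:

    // Hypothetical variant: a nonlinearity after each hidden matmul makes
    // the stack a genuine MLP instead of a factored linear map.
    let h1 = g.tanh(g.matmul(x, w));          // [N, 1]
    let h2 = g.tanh(g.matmul(h1, w2));        // [N, 32]
    let z = g.sigmoid(g.matmul(h2, w3) + b);  // [N, 1]

Widening the first layer (e.g. glorot_uniform(&[2, 32]) with matching downstream shapes) would be the usual companion change, since a single linear projection of (sin, cos) is ambiguous over a full turn.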