Experimenting with reverse auto-differentiation (which is just a generalized form of back-propagation) using an abstract DSL in Rust & ArrayFire
// Experimenting with reverse auto-differentiation (which is just a generalized form of back-propagation) using:
// - a DSL to build expressions without assuming how they will be differentiated
// - values backed by ArrayFire, so computations can run on CPU or GPU backends (the DSL isn't abstract over values)
// - a stateful context of differentiation in which the DSL expression is evaluated; it keeps track of partially computed values so nothing is recomputed for each differentiation
// Nothing very new here, just playing with those concepts to see how Rust behaves in trickier contexts
//
// Inspired by different projects doing the same:
// - auto-differentiation in Haskell: https://hackage.haskell.org/package/ad
// - arrayfire-rust & arrayfire-ml: https://github.com/arrayfire/arrayfire-ml
// - rugrads in Rust: https://github.com/AtheMathmo/rugrads
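// A minimal sketch of the kind of expression graph such a DSL can build:
// leaves are variables, inner nodes are operations, and nothing about
// differentiation is baked in yet. These types are hypothetical, for
// illustration only; the actual `var!`/`expr` machinery used below is
// not shown in this gist and may differ.
use std::rc::Rc;

#[derive(Clone)]
enum Expr {
    // A named leaf variable holding a value
    Var(&'static str),
    // Sum of two sub-expressions
    Add(Rc<Expr>, Rc<Expr>),
    // Sine of a sub-expression
    Sin(Rc<Expr>),
}

// Building u = x + sin(x + y) as a graph, without deciding yet
// how (or with respect to what) it will be differentiated;
// this mirrors the demo expression built below.
let x = Rc::new(Expr::Var("x"));
let y = Rc::new(Expr::Var("y"));
let z = Rc::new(Expr::Add(x.clone(), y.clone()));
let t = Rc::new(Expr::Sin(z.clone()));
let u = Expr::Add(x.clone(), t.clone());
let _ = &u;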
// First alpha sample demo'ing the eventual API
// Running on the OPENCL backend (could also be CPU)
af::set_backend(af::Backend::OPENCL);
af::info();
// Create a few arrays of data
let d_dims = Dim4::new(&[1, 6, 1, 1]);
let d_input: [i32; 6] = [1, 2, 3, 4, 5, 6];
let d = Array::new(&d_input, d_dims);
let d2 = Array::new(&d_input, d_dims);
// Building the expression (independently of any ctx)
var!(x, array(d));
var!(y, array(d2));
var!(z, expr(&x + &y)); // x is used here
var!(t, expr(Sin(&z)));
var!(u, expr(&x + &t)); // x is used here too
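// The expression built above is u = x + sin(x + y): x reaches u through
// two paths (directly, and through sin(x + y)), so reverse-mode
// differentiation must sum the contributions of both paths in du/dx.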
// Create a stateful context of differentiation
let mut ctx = Ctx::new();
// Evaluate the expression in the differentiation Ctx
let ev = u.eval(&ctx);
// Differentiate directly with respect to the variables x/y/z/t.
// Because Ctx is stateful, partial grads are computed once and cached:
// the first differentiation might take longer, but later ones are much faster
let dz = d!(ev / d(z)).unwrap();
af_print!("d/dz:", dz);
let dx = d!(ev / d(x)).unwrap();
af_print!("d/dx:", dx);
let dy = d!(ev / d(y)).unwrap();
af_print!("d/dy:", dy);
let dt = d!(ev / d(t)).unwrap();
af_print!("d/dt:", dt);
// Next steps:
// - the current context is not threadsafe, but the auto-differentiation process could be parallelized in many ways... TBC...
// - back-propagation can be optimized when differentiating on a single variable, to build only the partial gradients it needs & not the whole graph
// - the current API to build expressions is clunky due to references & ownership; will improve that later
// - build bigger expressions in a procedural way & test how it behaves
// - ... then naturally build neural networks with that & play a bit ;)
/*
d/dz:
[1 6 1 1]
-0.4161 -0.6536 0.9602 -0.1455 -0.8391 0.8439
d/dx:
[1 6 1 1]
0.5839 0.3464 1.9602 0.8545 0.1609 1.8439
d/dy:
[1 6 1 1]
-0.4161 -0.6536 0.9602 -0.1455 -0.8391 0.8439
d/dt:
[1 6 1 1]
1 1 1 1 1 1
*/
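// Sanity check of the output against the closed forms, with z = x + y:
//   u = x + sin(x + y)
//   du/dz = cos(z)       e.g. cos(1 + 1) = cos(2) ≈ -0.4161
//   du/dx = 1 + cos(z)   x contributes both directly and through sin(z)
//   du/dy = cos(z)
//   du/dt = 1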