Skip to content

Instantly share code, notes, and snippets.

@neftlon
Created March 21, 2023 10:26
Show Gist options
  • Save neftlon/f4cae6d6a2a389d156b62b6eafd16f1a to your computer and use it in GitHub Desktop.
Save neftlon/f4cae6d6a2a389d156b62b6eafd16f1a to your computer and use it in GitHub Desktop.
mini autograd engine in Rust
use std::{
cell::RefCell,
ops::{Add, Mul, Neg},
rc::Rc,
};
/// User-facing handle to one node of the computation graph.
///
/// Cloning a `Var` clones the `Rc` inside it, so the clone refers to the
/// same underlying node (same value, same gradient) rather than a copy.
#[derive(Debug, Clone)]
struct Var {
/// Shared pointer to the node's (value, gradient, parents) tuple.
inner: VarInnerRef,
}
/// Shared, interior-mutable reference to a graph node.
///
/// The tuple fields are, in order: the node's value, the gradient
/// accumulated for it, and a `VarInner` recording the parent nodes.
type VarInnerRef = Rc<RefCell<(f32, f32, VarInner)>>;
/// Records which operation produced a node and which parent nodes it was
/// computed from; the backward pass walks these links.
#[derive(Debug)]
enum VarInner {
/// A plain number with no parents.
Leaf,
/// Negation of the single parent.
Neg(VarInnerRef),
/// Sum of the two parents.
Add(VarInnerRef, VarInnerRef),
/// Product of the two parents.
Mul(VarInnerRef, VarInnerRef),
}
impl Var {
    /// Wraps an already-built graph node in a `Var` handle.
    ///
    /// `inner` is moved into the handle, so no additional reference count is
    /// taken on the node.
    fn new(inner: VarInnerRef) -> Self {
        Self { inner }
    }

    /// Creates a leaf `Var` (no parents) holding `num`, with its gradient
    /// initialised to zero.
    fn num(num: f32) -> Self {
        Self::new(Rc::new(RefCell::new((num, 0.0, VarInner::Leaf))))
    }

    /// Returns the value stored in this `Var`.
    fn val(&self) -> f32 {
        self.inner.borrow().0
    }

    /// Returns the gradient accumulated in this `Var` so far.
    fn grad(&self) -> f32 {
        self.inner.borrow().1
    }

    /// Runs the backward pass, treating this `Var` as the output.
    ///
    /// The output's own gradient is seeded with 1 (the derivative of a value
    /// with respect to itself) and then propagated to every node it was
    /// computed from. Gradients are added to whatever is already stored, so
    /// calling `bwd` repeatedly accumulates.
    fn bwd(&self) {
        // Seed the root: without this the output (and a lone leaf on which
        // `bwd` is called directly) would keep a gradient of 0. The mutable
        // borrow ends with this statement, before the traversal starts.
        self.inner.borrow_mut().1 += 1.;
        VarInner::bwd(&self.inner.borrow(), 1.)
    }
}
impl Add for Var {
    type Output = Self;

    /// Creates a new node whose value is the sum of both operands and which
    /// remembers the operands as its parents. The new node's gradient starts
    /// at zero.
    fn add(self, rhs: Self) -> Self::Output {
        let sum = self.val() + rhs.val();
        // Both operands are consumed, so their node handles can be moved
        // straight into the parent links.
        let parents = VarInner::Add(self.inner, rhs.inner);
        let node = RefCell::new((sum, 0.0, parents));
        Self::new(Rc::new(node))
    }
}
impl Mul for Var {
    type Output = Self;

    /// Creates a new node whose value is the product of both operands and
    /// which remembers the operands as its parents. The new node's gradient
    /// starts at zero.
    fn mul(self, rhs: Self) -> Self::Output {
        let product = self.val() * rhs.val();
        // Both operands are consumed, so their node handles can be moved
        // straight into the parent links.
        let parents = VarInner::Mul(self.inner, rhs.inner);
        let node = RefCell::new((product, 0.0, parents));
        Self::new(Rc::new(node))
    }
}
impl Neg for Var {
    type Output = Self;

    /// Creates a new node holding the negated value of the operand and
    /// remembering the operand as its only parent. The new node's gradient
    /// starts at zero.
    fn neg(self) -> Self::Output {
        let negated = -self.val();
        // The operand is consumed, so its node handle can be moved straight
        // into the parent link.
        let parent = VarInner::Neg(self.inner);
        let node = RefCell::new((negated, 0.0, parent));
        Self::new(Rc::new(node))
    }
}
impl VarInner {
    /// Propagates a gradient from the node `this` to its parents, recursively.
    ///
    /// `upstream` is the gradient of the final output with respect to `this`
    /// along the path currently being walked. For every parent, `upstream`
    /// multiplied by the local derivative of the operation is added to the
    /// parent's gradient and then handed on as that parent's own `upstream`.
    /// Gradients are only ever added to, never reset.
    ///
    /// Every `RefCell` borrow is released before the next one is taken, so a
    /// node may appear as both operands of one operation (`x * x`, `x + x`)
    /// without causing a double mutable borrow.
    fn bwd(this: &(f32, f32, VarInner), upstream: f32) {
        match &this.2 {
            // A leaf has no parents, so the walk stops here.
            Self::Leaf => {}
            Self::Mul(p1, p2) => {
                // d(a*b)/da = b and d(a*b)/db = a. Read both values first;
                // each temporary borrow ends with its statement.
                let p1val = p1.borrow().0;
                let p2val = p2.borrow().0;
                let (up1, up2) = (upstream * p2val, upstream * p1val);
                // Accumulate into both parents before recursing into either.
                p1.borrow_mut().1 += up1;
                p2.borrow_mut().1 += up2;
                Self::bwd(&p1.borrow(), up1);
                Self::bwd(&p2.borrow(), up2);
            }
            Self::Add(p1, p2) => {
                // Addition passes the upstream gradient through unchanged.
                p1.borrow_mut().1 += upstream;
                p2.borrow_mut().1 += upstream;
                Self::bwd(&p1.borrow(), upstream);
                Self::bwd(&p2.borrow(), upstream);
            }
            Self::Neg(p) => {
                // Negation flips the sign of the upstream gradient.
                let up = -upstream;
                p.borrow_mut().1 += up;
                Self::bwd(&p.borrow(), up);
            }
        }
    }
}
/// Demo: builds `r2 = (-x + y * w) * z` from four leaves, runs the backward
/// pass from `r2`, and prints the leaves (values and gradients) with `dbg!`.
fn main() {
    let x = Var::num(1.0);
    let y = Var::num(2.0);
    let z = Var::num(3.0);
    let w = Var::num(4.0);

    // The operators consume their operands, so clones (which share the same
    // underlying node) are passed in to keep the leaves usable afterwards.
    let neg_x = -x.clone();
    let y_times_w = y.clone() * w.clone();
    let r2 = (neg_x + y_times_w) * z.clone();

    r2.bwd();
    dbg!(&x, &y, &z, &w);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment