Skip to content

Instantly share code, notes, and snippets.

@mbillingr
Last active August 16, 2023 13:48
Show Gist options
  • Save mbillingr/567d8c9dd3b01fe06e3995265e3a8be7 to your computer and use it in GitHub Desktop.
Automatic differentiation on tensors prototype
use std::collections::HashMap;
use std::rc::Rc;
fn main() {
let x = &Tensor::new(vec![
vec![0.0, 0.0],
vec![1.0, 0.0],
vec![2.0, 0.0],
vec![0.0, 1.0],
vec![1.0, 1.0],
vec![2.0, 1.0],
vec![0.0, 2.0],
vec![1.0, 2.0],
vec![2.0, 2.0],
]);
let y = &Tensor::new(vec![-5.0, 5.0, 15.0, -7.0, 3.0, 13.0, -9.0, 1.0, 11.0]);
println!("X: {}", x);
println!("Y: {}", y);
let w0 = Tensor::new(vec![0.0, 0.0]);
let b0 = Tensor::new(0.0);
//println!("{:?}", Sum::real_fn(&sqr(dot(&w, &x) + b - y)));
for t in gradient_descent(
&|t| {
let w = t[0].clone();
let b = t[1].clone();
sum(sqr(dot(w, x) + b - y))
},
vec![w0, b0],
1000,
0.01,
) {
println!("{}", t);
}
}
/// Run `n` steps of fixed-rate gradient descent on `f`, starting from
/// the parameters in `theta`, and return the final parameters.
fn gradient_descent(
    f: &impl Fn(&[Dual]) -> Dual,
    mut theta: Vec<Tensor>,
    n: usize,
    alpha: f64,
) -> Vec<Tensor> {
    for _ in 0..n {
        let grad = gradient_of(f, &theta);
        // Step each parameter against its gradient, scaled by alpha.
        for (param, g) in theta.iter_mut().zip(grad) {
            let step = g * alpha;
            *param = &*param - &step;
        }
    }
    theta
}
/// Compute the gradient of `f` with respect to each tensor in `theta`
/// by reverse-mode automatic differentiation.
///
/// Panics (on `unwrap`) if `f` never used one of its parameters, since
/// that parameter's id then never appears in the gradient map.
fn gradient_of(f: &impl Fn(&[Dual]) -> Dual, theta: &[Tensor]) -> Vec<Tensor> {
    // Wrap each parameter as a leaf node; gradients are keyed by the
    // leaf's allocation address (`Dual::id`).
    let wrt: Vec<Dual> = theta.iter().cloned().map(Dual::end).collect();
    let g = f(&wrt);
    // Seed the backward pass with d(output)/d(output) = 1.
    let mut sigma = g.grad().compute(&Tensor::scalar(1.0), HashMap::new());
    wrt.iter().map(|d| sigma.remove(&d.id()).unwrap()).collect()
}
/// A dense n-dimensional array of f64 values.
///
/// The flat buffer is shared (`Rc`) between a tensor and the sub-tensor
/// views produced by `elements()`; `offset` locates this view's first
/// value inside the shared buffer.
#[derive(Debug, Clone)]
struct Tensor {
    /// Flat backing buffer, shared between views.
    store: Rc<Vec<f64>>,
    /// Extent of each axis; empty for a scalar (rank 0).
    shape: Vec<usize>,
    /// Index of this view's first value within `store`.
    offset: usize,
}
/// A value paired with the chain-rule record of how it was computed.
/// Shared via `Rc`, so cloning is cheap and all clones of a node keep
/// the same identity (see `Dual::id`).
#[derive(Debug, Clone)]
struct Dual(Rc<(Tensor, Chain)>);
/// One node of the backward (chain-rule) computation graph.
#[derive(Clone)]
enum Chain {
    /// Leaf node; the usize is the identity of the owning `Dual`
    /// (its allocation address), used as the gradient-map key.
    End(usize),
    /// Unary op: the input plus a gradient function mapping
    /// (input, incoming gradient) -> input gradient.
    Prim1(Dual, fn(&Tensor, &Tensor) -> Tensor),
    /// Binary op: both inputs plus a gradient function mapping
    /// (input a, input b, incoming gradient) -> (grad a, grad b).
    Prim2(
        Dual,
        Dual,
        fn(&Tensor, &Tensor, &Tensor) -> (Tensor, Tensor),
    ),
}
/// A differentiable unary operation on tensors.
trait Prim1 {
    /// Forward computation.
    fn real_fn(t: &Tensor) -> Tensor;
    /// Backward: map the input `t` and incoming gradient `z` to the
    /// gradient with respect to `t`.
    fn grad_fn(t: &Tensor, z: &Tensor) -> Tensor;
    /// Apply the op to a Dual, recording it in the chain.
    fn apply(d: impl Into<Dual>) -> Dual {
        let d = d.into();
        Dual::new(Self::real_fn(&d.real()), Chain::Prim1(d, Self::grad_fn))
    }
}
/// A differentiable binary operation on tensors.
trait Prim2 {
    /// Forward computation.
    fn real_fn(t: &Tensor, u: &Tensor) -> Tensor;
    /// Backward: map the inputs `t`, `u` and incoming gradient `z` to
    /// the pair of gradients with respect to `t` and `u`.
    fn grad_fn(t: &Tensor, u: &Tensor, z: &Tensor) -> (Tensor, Tensor);
    /// Apply the op to two Duals, recording it in the chain.
    fn apply(da: impl Into<Dual>, db: impl Into<Dual>) -> Dual {
        let (da, db) = (da.into(), db.into());
        Dual::new(
            Self::real_fn(&da.real(), &db.real()),
            Chain::Prim2(da, db, Self::grad_fn),
        )
    }
}
// Core interface
impl Dual {
    /// Create a leaf (input) node of the computation graph.
    fn end(t: Tensor) -> Self {
        // The leaf records its own allocation address as its id, so the
        // key inside the chain matches what `id()` later returns. The Rc
        // is still unique here, so `make_mut` patches the chain in place
        // and the address taken above remains valid.
        let mut tmp = Rc::new((t, Chain::End(0)));
        let id = &*tmp as *const _ as usize;
        Rc::make_mut(&mut tmp).1 = Chain::End(id);
        Dual(tmp)
    }
    /// Wrap a computed value together with its chain node.
    fn new(t: Tensor, chain: Chain) -> Self {
        Dual(Rc::new((t, chain)))
    }
    /// Identity of this node: the address of its shared allocation.
    /// Clones share the Rc and therefore report the same id.
    fn id(&self) -> usize {
        &*self.0 as *const _ as usize
    }
    /// The forward (primal) value.
    fn real(&self) -> &Tensor {
        &self.0 .0
    }
    /// The chain node describing how this value was computed.
    fn grad(&self) -> &Chain {
        &self.0 .1
    }
}
impl From<&Dual> for Dual {
fn from(d: &Dual) -> Dual {
d.clone()
}
}
impl From<Tensor> for Dual {
fn from(d: Tensor) -> Dual {
Dual::end(d)
}
}
impl From<&Tensor> for Dual {
fn from(d: &Tensor) -> Dual {
Dual::end(d.clone())
}
}
impl Tensor {
    /// The additive identity as a rank-0 tensor.
    fn zero() -> Self {
        Self::scalar(0.0)
    }
    /// Wrap a single value as a rank-0 (scalar) tensor.
    fn scalar(x: f64) -> Self {
        Tensor {
            store: Rc::new(vec![x]),
            shape: vec![],
            offset: 0,
        }
    }
    /// Build a tensor from any convertible value (f64, nested Vecs, ...).
    fn new(x: impl Into<Tensor>) -> Self {
        x.into()
    }
    /// Rank-0 tensors (empty shape) are scalars.
    fn is_scalar(&self) -> bool {
        self.shape.is_empty()
    }
    /// Length of the outermost axis. Scalars have no axes, so the
    /// question is deliberately left unresolved for them.
    fn len(&self) -> usize {
        if self.is_scalar() {
            todo!("should scalars have length?")
        }
        self.shape[0]
    }
    /// True if this tensor has exactly `r` axes.
    fn is_rank(&self, r: usize) -> bool {
        self.shape.len() == r
    }
    /// True if this tensor has strictly more axes than `other`.
    fn rank_gt(&self, other: &Tensor) -> bool {
        self.shape.len() > other.shape.len()
    }
    /// Extract the single value of a rank-0 tensor.
    ///
    /// # Panics
    /// Panics if the tensor is not a scalar.
    fn as_scalar(&self) -> f64 {
        if !self.is_scalar() {
            panic!("not a scalar")
        }
        self.store[self.offset]
    }
    /// Iterate over the sub-tensors along the outermost axis. Each
    /// sub-tensor is a zero-copy view sharing this tensor's store.
    fn elements(&self) -> impl Iterator<Item = Tensor> + '_ {
        let stride = self.stride();
        (0..self.len()).map(move |i| Tensor {
            store: self.store.clone(),
            shape: self.shape[1..].to_vec(),
            offset: self.offset + i * stride,
        })
    }
    /// Apply `f` to each element of the outer axis (or to the scalar
    /// itself) and collect the results into a new tensor.
    fn map(&self, f: impl Fn(&Tensor) -> Tensor) -> Self {
        if self.is_scalar() {
            f(self)
        } else {
            let v: Vec<_> = self.elements().map(|t| f(&t)).collect();
            Tensor::new(v)
        }
    }
    /// Number of scalar values in one element of the outermost axis.
    fn stride(&self) -> usize {
        self.shape[1..].iter().product()
    }
    /// Total number of scalar values spanned by this tensor.
    fn size(&self) -> usize {
        self.shape.iter().product()
    }
}
impl std::fmt::Display for Tensor {
    /// Scalars print as bare numbers; higher ranks print recursively as
    /// bracketed, space-separated element lists.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        if self.is_scalar() {
            return self.as_scalar().fmt(f);
        }
        f.write_str("[")?;
        let mut first = true;
        for elem in self.elements() {
            if !first {
                f.write_str(" ")?;
            }
            first = false;
            elem.fmt(f)?;
        }
        f.write_str("]")
    }
}
impl From<f64> for Tensor {
    fn from(value: f64) -> Self {
        // A bare f64 maps to a rank-0 tensor.
        Tensor::scalar(value)
    }
}
impl<T: Into<Tensor>> From<Vec<T>> for Tensor {
    /// Stack the converted elements along a new outermost axis.
    ///
    /// All elements are assumed to share one shape (not checked). An
    /// empty Vec now yields a rank-1 tensor of length 0 instead of
    /// panicking on `ts[0]`.
    fn from(v: Vec<T>) -> Self {
        let ts: Vec<Tensor> = v.into_iter().map(Into::into).collect();
        // Outer axis length first, then the shape shared by elements.
        let mut shape = vec![ts.len()];
        if let Some(first) = ts.first() {
            shape.extend(&first.shape);
        }
        let mut store = Vec::new();
        for t in &ts {
            store.extend_from_slice(&t.store[t.offset..t.offset + t.size()]);
        }
        Tensor {
            store: Rc::new(store),
            shape,
            offset: 0,
        }
    }
}
impl FromIterator<Tensor> for Tensor {
    fn from_iter<T: IntoIterator<Item = Tensor>>(it: T) -> Self {
        // Materialize, then reuse the Vec -> Tensor conversion.
        it.into_iter().collect::<Vec<_>>().into()
    }
}
/// Zip two tensors along their outer axis and collect `f` applied to
/// each pair of elements.
fn tmap2(f: impl Fn(&Tensor, &Tensor) -> Tensor, ta: &Tensor, tb: &Tensor) -> Tensor {
    let mut out = Vec::new();
    for (a, b) in ta.elements().zip(tb.elements()) {
        out.push(f(&a, &b));
    }
    out.into()
}
// Internal stuff
impl std::fmt::Debug for Chain {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Chain::End(d) => write!(f, "({:?})", d),
Chain::Prim1(d, _) => write!(f, "Prim1({:?}, ...)", d),
Chain::Prim2(da, db, _) => write!(f, "Prim2({:?}, {:?}, ...)", da, db),
}
}
}
impl Chain {
    /// Backward pass: push the incoming gradient `z` through this node,
    /// accumulating per-leaf gradients in `sigma`, keyed by leaf id.
    fn compute(&self, z: &Tensor, mut sigma: HashMap<usize, Tensor>) -> HashMap<usize, Tensor> {
        match self {
            Chain::End(id) => {
                // Leaf: add z to any gradient already recorded for this
                // id — a leaf can be reached along several paths.
                match sigma.get(id) {
                    None => sigma.insert(*id, z.clone()),
                    Some(g) => sigma.insert(*id, z + g),
                };
                sigma
            }
            Chain::Prim1(d, grad_fn) => {
                // Map z through the op's local derivative, then recurse
                // into the input's chain.
                let ga = grad_fn(d.real(), z);
                d.grad().compute(&ga, sigma)
            }
            Chain::Prim2(da, db, grad_fn) => {
                // Both inputs receive their share of the gradient.
                let (ga, gb) = grad_fn(da.real(), db.real(), z);
                let sigma = da.grad().compute(&ga, sigma);
                db.grad().compute(&gb, sigma)
            }
        }
    }
}
impl std::ops::Add for &Dual {
type Output = Dual;
fn add(self, other: &Dual) -> Self::Output {
Dual::new(
self.real() + other.real(),
Chain::Prim2(self.clone(), other.clone(), Add::grad_fn),
)
}
}
impl<T: Into<Dual>> std::ops::Add<T> for Dual {
type Output = Dual;
fn add(self, other: T) -> Self::Output {
let other = other.into();
Dual::new(
self.real() + other.real(),
Chain::Prim2(self, other, Add::grad_fn),
)
}
}
impl std::ops::Sub for &Dual {
type Output = Dual;
fn sub(self, other: &Dual) -> Self::Output {
Dual::new(
self.real() - other.real(),
Chain::Prim2(self.clone(), other.clone(), Sub::grad_fn),
)
}
}
impl<T: Into<Dual>> std::ops::Sub<T> for Dual {
type Output = Dual;
fn sub(self, other: T) -> Self::Output {
let other = other.into();
Dual::new(
self.real() - other.real(),
Chain::Prim2(self, other, Sub::grad_fn),
)
}
}
impl std::ops::Mul for Dual {
type Output = Dual;
fn mul(self, other: Dual) -> Self::Output {
Dual::new(
self.real() * other.real(),
Chain::Prim2(self, other, Mul::grad_fn),
)
}
}
impl std::ops::Mul for &Dual {
type Output = Dual;
fn mul(self, other: &Dual) -> Self::Output {
Dual::new(
self.real() * other.real(),
Chain::Prim2(self.clone(), other.clone(), Mul::grad_fn),
)
}
}
impl std::ops::Neg for Dual {
type Output = Dual;
fn neg(self) -> Self::Output {
Dual::new(-self.real(), Chain::Prim1(self, Neg::grad_fn))
}
}
impl Sqrt for Dual {
type Output = Dual;
fn sqrt(self) -> Self::Output {
Dual::new(
self.real().sqrt(),
Chain::Prim1(self.clone(), Sqrt_::grad_fn),
)
}
}
// Operator sugar for Tensor: the by-value impls borrow and defer to the
// by-reference impls, which call the extended primitives directly.
impl std::ops::Add for Tensor {
    type Output = Tensor;
    fn add(self, other: Tensor) -> Self::Output {
        &self + &other
    }
}
impl std::ops::Add for &Tensor {
    type Output = Tensor;
    fn add(self, other: &Tensor) -> Self::Output {
        Add::real_fn(self, other)
    }
}
impl std::ops::Sub for Tensor {
    type Output = Tensor;
    fn sub(self, other: Tensor) -> Self::Output {
        &self - &other
    }
}
impl std::ops::Sub for &Tensor {
    type Output = Tensor;
    fn sub(self, other: &Tensor) -> Self::Output {
        Sub::real_fn(self, other)
    }
}
impl std::ops::Mul for Tensor {
    type Output = Tensor;
    fn mul(self, other: Tensor) -> Self::Output {
        &self * &other
    }
}
impl std::ops::Mul for &Tensor {
    type Output = Tensor;
    fn mul(self, other: &Tensor) -> Self::Output {
        Mul::real_fn(self, other)
    }
}
impl std::ops::Neg for Tensor {
    type Output = Tensor;
    fn neg(self) -> Tensor {
        -&self
    }
}
impl std::ops::Neg for &Tensor {
    type Output = Tensor;
    fn neg(self) -> Tensor {
        Neg::real_fn(self)
    }
}
impl Sqrt for &Tensor {
    type Output = Tensor;
    fn sqrt(self) -> Self::Output {
        Sqrt_::real_fn(self)
    }
}
/// Elementwise backward helper for `ext2!`: walk both inputs and the
/// incoming gradient in lockstep along the outer axis, apply the
/// sub-gradient function, and collect one gradient tensor per input.
fn tmap_ext(
    f: impl Fn(&Tensor, &Tensor, &Tensor) -> (Tensor, Tensor),
    ta: &Tensor,
    tb: &Tensor,
    tc: &Tensor,
) -> (Tensor, Tensor) {
    let mut gu: Vec<Tensor> = Vec::with_capacity(ta.len());
    let mut gt: Vec<Tensor> = Vec::with_capacity(ta.len());
    for ((a, b), c) in ta.elements().zip(tb.elements()).zip(tc.elements()) {
        let (gti, gui) = f(&a, &b, &c);
        gt.push(gti);
        gu.push(gui);
    }
    (gt.into(), gu.into())
}
/// Backward helper for the broadcast case where `t` stays fixed while
/// descending into `u`: `u`'s gradient is collected per element, while
/// `t`'s gradient is summed over the broadcast axis (a value reused in
/// several places accumulates gradient from all of them).
fn desc_u_grad(
    g: impl Fn(&Tensor, &Tensor, &Tensor) -> (Tensor, Tensor),
    t: &Tensor,
    u: &Tensor,
    z: &Tensor,
) -> (Tensor, Tensor) {
    let mut gt = Tensor::zero();
    let mut gu: Vec<Tensor> = Vec::with_capacity(u.len());
    for (ui, zi) in u.elements().zip(z.elements()) {
        let (gti, gui) = g(t, &ui, &zi);
        gt = &gt + &gti;
        gu.push(gui);
    }
    // `gt.into()` is an identity conversion (Tensor -> Tensor).
    (gt.into(), gu.into())
}
/// Mirror of `desc_u_grad`: `u` stays fixed while descending into `t`,
/// so `t`'s gradient is collected per element and `u`'s gradient is
/// summed over the broadcast axis.
fn desc_t_grad(
    g: impl Fn(&Tensor, &Tensor, &Tensor) -> (Tensor, Tensor),
    t: &Tensor,
    u: &Tensor,
    z: &Tensor,
) -> (Tensor, Tensor) {
    let mut gt: Vec<Tensor> = Vec::with_capacity(t.len());
    let mut gu = Tensor::zero();
    for (ti, zi) in t.elements().zip(z.elements()) {
        let (gti, gui) = g(&ti, u, &zi);
        gt.push(gti);
        gu = &gu + &gui;
    }
    (gt.into(), gu.into())
}
// Primitive definitions
/// Sum of the scalar elements of a rank-1 tensor.
struct Sum1;
impl Prim1 for Sum1 {
    fn real_fn(t: &Tensor) -> Tensor {
        assert!(t.is_rank(1));
        let total: f64 = t.elements().map(|e| e.as_scalar()).sum();
        Tensor::scalar(total)
    }
    fn grad_fn(t: &Tensor, z: &Tensor) -> Tensor {
        // Each summand contributes with weight 1: broadcast z over t.
        t.map(|_| z.clone())
    }
}
struct Sqrt0;
impl Prim1 for Sqrt0 {
fn real_fn(t: &Tensor) -> Tensor {
Tensor::scalar(t.as_scalar().sqrt())
}
fn grad_fn(t: &Tensor, z: &Tensor) -> Tensor {
z * &Tensor::scalar((2.0 * t.as_scalar().sqrt()))
}
}
/// Scalar negation.
struct Neg0;
impl Prim1 for Neg0 {
    fn real_fn(t: &Tensor) -> Tensor {
        Tensor::scalar(-t.as_scalar())
    }
    fn grad_fn(_t: &Tensor, z: &Tensor) -> Tensor {
        // d(-t)/dt = -1: pass the incoming gradient through, negated.
        -z
    }
}
/// Scalar addition.
struct Add0_0;
impl Prim2 for Add0_0 {
    fn real_fn(t: &Tensor, u: &Tensor) -> Tensor {
        let total = t.as_scalar() + u.as_scalar();
        Tensor::scalar(total)
    }
    fn grad_fn(_t: &Tensor, _u: &Tensor, z: &Tensor) -> (Tensor, Tensor) {
        // Addition passes the incoming gradient through to both inputs.
        (z.clone(), z.clone())
    }
}
/// Scalar subtraction.
struct Sub0_0;
impl Prim2 for Sub0_0 {
    fn real_fn(t: &Tensor, u: &Tensor) -> Tensor {
        let diff = t.as_scalar() - u.as_scalar();
        Tensor::scalar(diff)
    }
    fn grad_fn(_t: &Tensor, _u: &Tensor, z: &Tensor) -> (Tensor, Tensor) {
        // d(t-u)/dt = 1, d(t-u)/du = -1.
        (z.clone(), -z)
    }
}
/// Scalar multiplication.
struct Mul0_0;
impl Prim2 for Mul0_0 {
    fn real_fn(t: &Tensor, u: &Tensor) -> Tensor {
        let product = t.as_scalar() * u.as_scalar();
        Tensor::scalar(product)
    }
    fn grad_fn(t: &Tensor, u: &Tensor, z: &Tensor) -> (Tensor, Tensor) {
        // Product rule: d(tu)/dt = u, d(tu)/du = t.
        (u * z, t * z)
    }
}
/// Inner product of two rank-1 tensors.
struct Dot1_1;
impl Prim2 for Dot1_1 {
    fn real_fn(t: &Tensor, u: &Tensor) -> Tensor {
        assert!(t.is_rank(1));
        assert!(u.is_rank(1));
        let acc: f64 = t
            .elements()
            .zip(u.elements())
            .map(|(te, ue)| te.as_scalar() * ue.as_scalar())
            .sum();
        Tensor::scalar(acc)
    }
    fn grad_fn(t: &Tensor, u: &Tensor, z: &Tensor) -> (Tensor, Tensor) {
        // d(t.u)/dt = u*z and symmetrically for u.
        (u * z, t * z)
    }
}
// Lift a unary primitive defined on rank-`$n` tensors to tensors of any
// higher rank by mapping over the outer axis until rank `$n` is reached.
macro_rules! ext1 {
    ($ext:ident, $prim:ident, $n:expr) => {
        struct $ext;
        impl Prim1 for $ext {
            fn real_fn(t: &Tensor) -> Tensor {
                if t.is_rank($n) {
                    return $prim::real_fn(t);
                }
                // Still above the base rank: descend one axis and recurse.
                return t.map($ext::real_fn);
            }
            fn grad_fn(t: &Tensor, z: &Tensor) -> Tensor {
                if t.is_rank($n) {
                    return $prim::grad_fn(t, z);
                }
                // z carries one gradient slice per outer element of t.
                return tmap2($ext::grad_fn, t, z);
            }
        }
    };
}
// Lift a binary primitive defined on rank-`$n` x rank-`$m` tensors to
// arbitrary ranks with broadcasting: whichever side is still above its
// base rank is descended into, element by element.
macro_rules! ext2 {
    ($ext:ident, $prim:ident, $n:expr, $m:expr) => {
        struct $ext;
        impl Prim2 for $ext {
            fn real_fn(t: &Tensor, u: &Tensor) -> Tensor {
                // Both sides at base rank: apply the primitive directly.
                if t.is_rank($n) && u.is_rank($m) {
                    return $prim::real_fn(t, u);
                }
                // One side at base rank: broadcast it over the other.
                if t.is_rank($n) {
                    return u.map(|eu| $ext::real_fn(t, eu));
                }
                if u.is_rank($m) {
                    return t.map(|et| $ext::real_fn(et, u));
                }
                // Equal outer lengths: walk both sides in lockstep.
                if t.len() == u.len() {
                    return tmap2($ext::real_fn, t, u);
                }
                // Otherwise descend into the higher-ranked side.
                if t.rank_gt(u) {
                    return t.map(|et| $ext::real_fn(et, u));
                }
                if u.rank_gt(t) {
                    return u.map(|eu| $ext::real_fn(t, eu));
                }
                // Incompatible shapes.
                panic!()
            }
            fn grad_fn(t: &Tensor, u: &Tensor, z: &Tensor) -> (Tensor, Tensor) {
                // Mirrors real_fn's case analysis. In the broadcast
                // cases the fixed side's gradient is summed over the
                // axis it was broadcast along (see desc_*_grad).
                if t.is_rank($n) && u.is_rank($m) {
                    return $prim::grad_fn(t, u, z);
                }
                if t.is_rank($n) {
                    return desc_u_grad($ext::grad_fn, t, u, z);
                }
                if u.is_rank($m) {
                    return desc_t_grad($ext::grad_fn, t, u, z);
                }
                if t.len() == u.len() {
                    return tmap_ext($ext::grad_fn, t, u, z);
                }
                if t.rank_gt(u) {
                    return desc_t_grad($ext::grad_fn, t, u, z);
                }
                if u.rank_gt(t) {
                    return desc_u_grad($ext::grad_fn, t, u, z);
                }
                panic!()
            }
        }
    };
}
// Instantiate the rank-extended operations used by the operator impls.
ext1!(Sum, Sum1, 1); // per-slice sum over the innermost rank-1 axis
ext1!(Sqrt_, Sqrt0, 0); // elementwise square root
ext1!(Neg, Neg0, 0); // elementwise negation
ext2!(Add, Add0_0, 0, 0); // elementwise addition with broadcasting
ext2!(Sub, Sub0_0, 0, 0); // elementwise subtraction with broadcasting
ext2!(Mul, Mul0_0, 0, 0); // elementwise multiplication with broadcasting
ext2!(Dot, Dot1_1, 1, 1); // inner product over rank-1 slices
/// Square a value by multiplying it with a clone of itself.
fn sqr<T: Clone>(x: T) -> T
where
    T: std::ops::Mul<Output = T>,
{
    let copy = x.clone();
    copy * x
}
/// Square root as a trait, so Tensor and Dual can both provide it
/// (f64 already has an inherent `sqrt`).
trait Sqrt {
    type Output;
    fn sqrt(self) -> Self::Output;
}
/// Free-function form of `Sqrt::sqrt`.
fn sqrt<T: Sqrt>(x: T) -> T::Output {
    x.sqrt()
}
/// Free-function form of `DifferentiableArithmetic::dot`; the second
/// argument is converted (e.g. `&Tensor` -> `Dual` leaf) as needed.
fn dot<T: DifferentiableArithmetic>(
    a: T,
    b: impl Into<T>,
) -> <T as DifferentiableArithmetic>::Output {
    T::dot(a, b.into())
}
/// Free-function form of `DifferentiableArithmetic::sum`.
fn sum<T: DifferentiableArithmetic>(a: T) -> <T as DifferentiableArithmetic>::Output {
    T::sum(a)
}
/// Operations shared by Tensor (plain evaluation) and Dual (evaluation
/// plus gradient recording), so the same closure can run with either.
trait DifferentiableArithmetic
where
    Self: std::ops::Add<f64, Output = <Self as DifferentiableArithmetic>::Output>,
    Self: std::ops::Mul<f64, Output = <Self as DifferentiableArithmetic>::Output>,
{
    type Output: DifferentiableArithmetic;
    /// Inner product of `a` and `b`.
    fn dot(a: Self, b: Self) -> <Self as DifferentiableArithmetic>::Output;
    /// Sum of `a`'s elements.
    fn sum(a: Self) -> <Self as DifferentiableArithmetic>::Output;
}
impl DifferentiableArithmetic for Tensor {
type Output = Tensor;
fn dot(a: Self, b: Self) -> Tensor {
Dot::real_fn(&a, &b)
}
fn sum(a: Self) -> Tensor {
Sum::real_fn(&a)
}
}
impl DifferentiableArithmetic for &Tensor {
type Output = Tensor;
fn dot(a: Self, b: Self) -> Tensor {
Dot::real_fn(a, b)
}
fn sum(a: Self) -> Tensor {
Sum::real_fn(a)
}
}
impl DifferentiableArithmetic for Dual {
type Output = Dual;
fn dot(a: Self, b: Self) -> Dual {
Dot::apply(a, b)
}
fn sum(a: Self) -> Dual {
Sum::apply(a)
}
}
impl DifferentiableArithmetic for &Dual {
type Output = Dual;
fn dot(a: Self, b: Self) -> Dual {
Dot::apply(a.clone(), b.clone())
}
fn sum(a: Self) -> Dual {
Sum::apply(a.clone())
}
}
impl std::ops::Add<f64> for Tensor {
type Output = Tensor;
fn add(self, x: f64) -> Tensor {
self + Tensor::scalar(x)
}
}
impl std::ops::Add<f64> for &Tensor {
type Output = Tensor;
fn add(self, x: f64) -> Tensor {
self + &Tensor::scalar(x)
}
}
impl std::ops::Mul<f64> for Tensor {
type Output = Tensor;
fn mul(self, x: f64) -> Tensor {
self * Tensor::scalar(x)
}
}
impl std::ops::Mul<f64> for &Tensor {
type Output = Tensor;
fn mul(self, x: f64) -> Tensor {
self * &Tensor::scalar(x)
}
}
impl std::ops::Add<f64> for Dual {
type Output = Dual;
fn add(self, x: f64) -> Dual {
self + Dual::end(Tensor::scalar(x))
}
}
impl std::ops::Add<f64> for &Dual {
type Output = Dual;
fn add(self, x: f64) -> Dual {
self + &Dual::end(Tensor::scalar(x))
}
}
impl std::ops::Mul<f64> for Dual {
type Output = Dual;
fn mul(self, x: f64) -> Dual {
self * Dual::end(Tensor::scalar(x))
}
}
impl std::ops::Mul<f64> for &Dual {
type Output = Dual;
fn mul(self, x: f64) -> Dual {
self * &Dual::end(Tensor::scalar(x))
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment