Last active
August 16, 2023 13:48
-
-
Save mbillingr/567d8c9dd3b01fe06e3995265e3a8be7 to your computer and use it in GitHub Desktop.
Automatic differentiation on tensors prototype
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::collections::HashMap; | |
use std::rc::Rc; | |
fn main() { | |
let x = &Tensor::new(vec![ | |
vec![0.0, 0.0], | |
vec![1.0, 0.0], | |
vec![2.0, 0.0], | |
vec![0.0, 1.0], | |
vec![1.0, 1.0], | |
vec![2.0, 1.0], | |
vec![0.0, 2.0], | |
vec![1.0, 2.0], | |
vec![2.0, 2.0], | |
]); | |
let y = &Tensor::new(vec![-5.0, 5.0, 15.0, -7.0, 3.0, 13.0, -9.0, 1.0, 11.0]); | |
println!("X: {}", x); | |
println!("Y: {}", y); | |
let w0 = Tensor::new(vec![0.0, 0.0]); | |
let b0 = Tensor::new(0.0); | |
//println!("{:?}", Sum::real_fn(&sqr(dot(&w, &x) + b - y))); | |
for t in gradient_descent( | |
&|t| { | |
let w = t[0].clone(); | |
let b = t[1].clone(); | |
sum(sqr(dot(w, x) + b - y)) | |
}, | |
vec![w0, b0], | |
1000, | |
0.01, | |
) { | |
println!("{}", t); | |
} | |
} | |
fn gradient_descent( | |
f: &impl Fn(&[Dual]) -> Dual, | |
mut theta: Vec<Tensor>, | |
n: usize, | |
alpha: f64, | |
) -> Vec<Tensor> { | |
for _ in 0..n { | |
let grd = gradient_of(f, &theta); | |
for (t, g) in theta.iter_mut().zip(grd) { | |
*t = t as &_ - &(g * alpha); | |
} | |
} | |
theta | |
} | |
fn gradient_of(f: &impl Fn(&[Dual]) -> Dual, theta: &[Tensor]) -> Vec<Tensor> { | |
let wrt: Vec<Dual> = theta.iter().cloned().map(Dual::end).collect(); | |
let g = f(&wrt); | |
let mut sigma = g.grad().compute(&Tensor::scalar(1.0), HashMap::new()); | |
wrt.iter().map(|d| sigma.remove(&d.id()).unwrap()).collect() | |
} | |
/// An n-dimensional array of `f64` values laid out row-major in a shared buffer.
///
/// `shape` lists the extent of each dimension and is empty for a scalar.
/// `offset` is the index in `store` of this tensor's first value. This lets
/// sub-tensor views share their parent's buffer.
#[derive(Clone, Debug)]
struct Tensor {
    /// Flat backing storage, shared (via `Rc`) between views.
    store: std::rc::Rc<Vec<f64>>,
    /// Dimension extents, outermost first; empty means rank 0.
    shape: Vec<usize>,
    /// Position of this view's first value inside `store`.
    offset: usize,
}
#[derive(Debug, Clone)] | |
struct Dual(Rc<(Tensor, Chain)>); | |
#[derive(Clone)] | |
enum Chain { | |
End(usize), | |
Prim1(Dual, fn(&Tensor, &Tensor) -> Tensor), | |
Prim2( | |
Dual, | |
Dual, | |
fn(&Tensor, &Tensor, &Tensor) -> (Tensor, Tensor), | |
), | |
} | |
trait Prim1 { | |
fn real_fn(t: &Tensor) -> Tensor; | |
fn grad_fn(t: &Tensor, z: &Tensor) -> Tensor; | |
fn apply(d: impl Into<Dual>) -> Dual { | |
let d = d.into(); | |
Dual::new(Self::real_fn(&d.real()), Chain::Prim1(d, Self::grad_fn)) | |
} | |
} | |
trait Prim2 { | |
fn real_fn(t: &Tensor, u: &Tensor) -> Tensor; | |
fn grad_fn(t: &Tensor, u: &Tensor, z: &Tensor) -> (Tensor, Tensor); | |
fn apply(da: impl Into<Dual>, db: impl Into<Dual>) -> Dual { | |
let (da, db) = (da.into(), db.into()); | |
Dual::new( | |
Self::real_fn(&da.real(), &db.real()), | |
Chain::Prim2(da, db, Self::grad_fn), | |
) | |
} | |
} | |
// Core interface | |
impl Dual { | |
fn end(t: Tensor) -> Self { | |
let mut tmp = Rc::new((t, Chain::End(0))); | |
let id = &*tmp as *const _ as usize; | |
Rc::make_mut(&mut tmp).1 = Chain::End(id); | |
Dual(tmp) | |
} | |
fn new(t: Tensor, chain: Chain) -> Self { | |
Dual(Rc::new((t, chain))) | |
} | |
fn id(&self) -> usize { | |
&*self.0 as *const _ as usize | |
} | |
fn real(&self) -> &Tensor { | |
&self.0 .0 | |
} | |
fn grad(&self) -> &Chain { | |
&self.0 .1 | |
} | |
} | |
impl From<&Dual> for Dual { | |
fn from(d: &Dual) -> Dual { | |
d.clone() | |
} | |
} | |
impl From<Tensor> for Dual { | |
fn from(d: Tensor) -> Dual { | |
Dual::end(d) | |
} | |
} | |
impl From<&Tensor> for Dual { | |
fn from(d: &Tensor) -> Dual { | |
Dual::end(d.clone()) | |
} | |
} | |
impl Tensor { | |
fn zero() -> Self { | |
Self::scalar(0.0) | |
} | |
fn scalar(x: f64) -> Self { | |
Tensor { | |
store: Rc::new(vec![x]), | |
shape: vec![], | |
offset: 0, | |
} | |
} | |
fn new(x: impl Into<Tensor>) -> Self { | |
x.into() | |
} | |
fn is_scalar(&self) -> bool { | |
self.shape.is_empty() | |
} | |
fn len(&self) -> usize { | |
if self.is_scalar() { | |
todo!("should scalars have lentgth?") | |
} | |
self.shape[0] | |
} | |
fn is_rank(&self, r: usize) -> bool { | |
self.shape.len() == r | |
} | |
fn rank_gt(&self, other: &Tensor) -> bool { | |
self.shape.len() > other.shape.len() | |
} | |
fn as_scalar(&self) -> f64 { | |
if !self.is_scalar() { | |
panic!("not a scalar") | |
} | |
self.store[self.offset] | |
} | |
fn elements(&self) -> impl Iterator<Item = Tensor> + '_ { | |
let stride = self.stride(); | |
(0..self.len()).map(move |i| Tensor { | |
store: self.store.clone(), | |
shape: self.shape[1..].to_vec(), | |
offset: self.offset + i * stride, | |
}) | |
} | |
fn map(&self, f: impl Fn(&Tensor) -> Tensor) -> Self { | |
if self.is_scalar() { | |
f(self) | |
} else { | |
let v: Vec<_> = self.elements().map(|t| f(&t)).collect(); | |
Tensor::new(v) | |
} | |
} | |
fn stride(&self) -> usize { | |
let mut stride = 1; | |
for s in &self.shape[1..] { | |
stride *= s; | |
} | |
stride | |
} | |
fn size(&self) -> usize { | |
let mut size = 1; | |
for s in &self.shape { | |
size *= s; | |
} | |
size | |
} | |
} | |
impl std::fmt::Display for Tensor { | |
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { | |
if self.is_scalar() { | |
self.as_scalar().fmt(f) | |
} else { | |
write!(f, "[")?; | |
for (i, e) in self.elements().enumerate() { | |
if i > 0 { | |
write!(f, " ")?; | |
} | |
e.fmt(f)?; | |
} | |
write!(f, "]") | |
} | |
} | |
} | |
impl From<f64> for Tensor { | |
fn from(s: f64) -> Self { | |
Tensor::scalar(s) | |
} | |
} | |
impl<T: Into<Tensor>> From<Vec<T>> for Tensor { | |
fn from(v: Vec<T>) -> Self { | |
let ts = v.into_iter().map(Into::into).collect::<Vec<_>>(); | |
let mut shape = vec![ts.len()]; | |
shape.extend(&ts[0].shape); | |
let mut store = vec![]; | |
for t in ts { | |
store.extend(&t.store[t.offset..t.offset + t.size()]); | |
} | |
Tensor { | |
store: Rc::new(store), | |
shape, | |
offset: 0, | |
} | |
} | |
} | |
impl FromIterator<Tensor> for Tensor { | |
fn from_iter<T: IntoIterator<Item = Tensor>>(it: T) -> Self { | |
let v: Vec<_> = it.into_iter().collect(); | |
v.into() | |
} | |
} | |
fn tmap2(f: impl Fn(&Tensor, &Tensor) -> Tensor, ta: &Tensor, tb: &Tensor) -> Tensor { | |
ta.elements() | |
.zip(tb.elements()) | |
.map(|(a, b)| f(&a, &b)) | |
.collect() | |
} | |
// Internal stuff | |
impl std::fmt::Debug for Chain { | |
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { | |
match self { | |
Chain::End(d) => write!(f, "({:?})", d), | |
Chain::Prim1(d, _) => write!(f, "Prim1({:?}, ...)", d), | |
Chain::Prim2(da, db, _) => write!(f, "Prim2({:?}, {:?}, ...)", da, db), | |
} | |
} | |
} | |
impl Chain { | |
fn compute(&self, z: &Tensor, mut sigma: HashMap<usize, Tensor>) -> HashMap<usize, Tensor> { | |
match self { | |
Chain::End(id) => { | |
match sigma.get(id) { | |
None => sigma.insert(*id, z.clone()), | |
Some(g) => sigma.insert(*id, z + g), | |
}; | |
sigma | |
} | |
Chain::Prim1(d, grad_fn) => { | |
let ga = grad_fn(d.real(), z); | |
d.grad().compute(&ga, sigma) | |
} | |
Chain::Prim2(da, db, grad_fn) => { | |
let (ga, gb) = grad_fn(da.real(), db.real(), z); | |
let sigma = da.grad().compute(&ga, sigma); | |
db.grad().compute(&gb, sigma) | |
} | |
} | |
} | |
} | |
impl std::ops::Add for &Dual { | |
type Output = Dual; | |
fn add(self, other: &Dual) -> Self::Output { | |
Dual::new( | |
self.real() + other.real(), | |
Chain::Prim2(self.clone(), other.clone(), Add::grad_fn), | |
) | |
} | |
} | |
impl<T: Into<Dual>> std::ops::Add<T> for Dual { | |
type Output = Dual; | |
fn add(self, other: T) -> Self::Output { | |
let other = other.into(); | |
Dual::new( | |
self.real() + other.real(), | |
Chain::Prim2(self, other, Add::grad_fn), | |
) | |
} | |
} | |
impl std::ops::Sub for &Dual { | |
type Output = Dual; | |
fn sub(self, other: &Dual) -> Self::Output { | |
Dual::new( | |
self.real() - other.real(), | |
Chain::Prim2(self.clone(), other.clone(), Sub::grad_fn), | |
) | |
} | |
} | |
impl<T: Into<Dual>> std::ops::Sub<T> for Dual { | |
type Output = Dual; | |
fn sub(self, other: T) -> Self::Output { | |
let other = other.into(); | |
Dual::new( | |
self.real() - other.real(), | |
Chain::Prim2(self, other, Sub::grad_fn), | |
) | |
} | |
} | |
impl std::ops::Mul for Dual { | |
type Output = Dual; | |
fn mul(self, other: Dual) -> Self::Output { | |
Dual::new( | |
self.real() * other.real(), | |
Chain::Prim2(self, other, Mul::grad_fn), | |
) | |
} | |
} | |
impl std::ops::Mul for &Dual { | |
type Output = Dual; | |
fn mul(self, other: &Dual) -> Self::Output { | |
Dual::new( | |
self.real() * other.real(), | |
Chain::Prim2(self.clone(), other.clone(), Mul::grad_fn), | |
) | |
} | |
} | |
impl std::ops::Neg for Dual { | |
type Output = Dual; | |
fn neg(self) -> Self::Output { | |
Dual::new(-self.real(), Chain::Prim1(self, Neg::grad_fn)) | |
} | |
} | |
impl Sqrt for Dual { | |
type Output = Dual; | |
fn sqrt(self) -> Self::Output { | |
Dual::new( | |
self.real().sqrt(), | |
Chain::Prim1(self.clone(), Sqrt_::grad_fn), | |
) | |
} | |
} | |
impl std::ops::Add for Tensor { | |
type Output = Tensor; | |
fn add(self, other: Tensor) -> Self::Output { | |
Add::real_fn(&self, &other) | |
} | |
} | |
impl std::ops::Add for &Tensor { | |
type Output = Tensor; | |
fn add(self, other: &Tensor) -> Self::Output { | |
Add::real_fn(self, other) | |
} | |
} | |
impl std::ops::Sub for Tensor { | |
type Output = Tensor; | |
fn sub(self, other: Tensor) -> Self::Output { | |
Sub::real_fn(&self, &other) | |
} | |
} | |
impl std::ops::Sub for &Tensor { | |
type Output = Tensor; | |
fn sub(self, other: &Tensor) -> Self::Output { | |
Sub::real_fn(self, other) | |
} | |
} | |
impl std::ops::Mul for Tensor { | |
type Output = Tensor; | |
fn mul(self, other: Tensor) -> Self::Output { | |
Mul::real_fn(&self, &other) | |
} | |
} | |
impl std::ops::Mul for &Tensor { | |
type Output = Tensor; | |
fn mul(self, other: &Tensor) -> Self::Output { | |
Mul::real_fn(self, other) | |
} | |
} | |
impl std::ops::Neg for Tensor { | |
type Output = Tensor; | |
fn neg(self) -> Tensor { | |
Neg::real_fn(&self) | |
} | |
} | |
impl std::ops::Neg for &Tensor { | |
type Output = Tensor; | |
fn neg(self) -> Tensor { | |
Neg::real_fn(self) | |
} | |
} | |
impl Sqrt for &Tensor { | |
type Output = Tensor; | |
fn sqrt(self) -> Self::Output { | |
Sqrt_::real_fn(self) | |
} | |
} | |
fn tmap_ext( | |
f: impl Fn(&Tensor, &Tensor, &Tensor) -> (Tensor, Tensor), | |
ta: &Tensor, | |
tb: &Tensor, | |
tc: &Tensor, | |
) -> (Tensor, Tensor) { | |
let mut gu: Vec<Tensor> = Vec::with_capacity(ta.len()); | |
let mut gt: Vec<Tensor> = Vec::with_capacity(ta.len()); | |
for ((a, b), c) in ta.elements().zip(tb.elements()).zip(tc.elements()) { | |
let (gti, gui) = f(&a, &b, &c); | |
gt.push(gti); | |
gu.push(gui); | |
} | |
(gt.into(), gu.into()) | |
} | |
fn desc_u_grad( | |
g: impl Fn(&Tensor, &Tensor, &Tensor) -> (Tensor, Tensor), | |
t: &Tensor, | |
u: &Tensor, | |
z: &Tensor, | |
) -> (Tensor, Tensor) { | |
let mut gt = Tensor::zero(); | |
let mut gu: Vec<Tensor> = Vec::with_capacity(u.len()); | |
for (ui, zi) in u.elements().zip(z.elements()) { | |
let (gti, gui) = g(t, &ui, &zi); | |
gt = > + >i; | |
gu.push(gui); | |
} | |
(gt.into(), gu.into()) | |
} | |
fn desc_t_grad( | |
g: impl Fn(&Tensor, &Tensor, &Tensor) -> (Tensor, Tensor), | |
t: &Tensor, | |
u: &Tensor, | |
z: &Tensor, | |
) -> (Tensor, Tensor) { | |
let mut gt: Vec<Tensor> = Vec::with_capacity(t.len()); | |
let mut gu = Tensor::zero(); | |
for (ti, zi) in t.elements().zip(z.elements()) { | |
let (gti, gui) = g(&ti, u, &zi); | |
gt.push(gti); | |
gu = &gu + &gui; | |
} | |
(gt.into(), gu.into()) | |
} | |
// Primitive definitions | |
struct Sum1; | |
impl Prim1 for Sum1 { | |
fn real_fn(t: &Tensor) -> Tensor { | |
assert!(t.is_rank(1)); | |
let mut acc = 0.0; | |
for te in t.elements() { | |
acc += te.as_scalar(); | |
} | |
Tensor::scalar(acc) | |
} | |
fn grad_fn(t: &Tensor, z: &Tensor) -> Tensor { | |
t.map(|_| z.clone()) | |
} | |
} | |
struct Sqrt0; | |
impl Prim1 for Sqrt0 { | |
fn real_fn(t: &Tensor) -> Tensor { | |
Tensor::scalar(t.as_scalar().sqrt()) | |
} | |
fn grad_fn(t: &Tensor, z: &Tensor) -> Tensor { | |
z * &Tensor::scalar((2.0 * t.as_scalar().sqrt())) | |
} | |
} | |
struct Neg0; | |
impl Prim1 for Neg0 { | |
fn real_fn(t: &Tensor) -> Tensor { | |
Tensor::scalar(-t.as_scalar()) | |
} | |
fn grad_fn(t: &Tensor, z: &Tensor) -> Tensor { | |
-z.clone() | |
} | |
} | |
struct Add0_0; | |
impl Prim2 for Add0_0 { | |
fn real_fn(t: &Tensor, u: &Tensor) -> Tensor { | |
Tensor::scalar(t.as_scalar() + u.as_scalar()) | |
} | |
fn grad_fn(t: &Tensor, u: &Tensor, z: &Tensor) -> (Tensor, Tensor) { | |
(z.clone(), z.clone()) | |
} | |
} | |
struct Sub0_0; | |
impl Prim2 for Sub0_0 { | |
fn real_fn(t: &Tensor, u: &Tensor) -> Tensor { | |
Tensor::scalar(t.as_scalar() - u.as_scalar()) | |
} | |
fn grad_fn(t: &Tensor, u: &Tensor, z: &Tensor) -> (Tensor, Tensor) { | |
(z.clone(), -z.clone()) | |
} | |
} | |
struct Mul0_0; | |
impl Prim2 for Mul0_0 { | |
fn real_fn(t: &Tensor, u: &Tensor) -> Tensor { | |
Tensor::scalar(t.as_scalar() * u.as_scalar()) | |
} | |
fn grad_fn(t: &Tensor, u: &Tensor, z: &Tensor) -> (Tensor, Tensor) { | |
(u * z, t * z) | |
} | |
} | |
struct Dot1_1; | |
impl Prim2 for Dot1_1 { | |
fn real_fn(t: &Tensor, u: &Tensor) -> Tensor { | |
assert!(t.is_rank(1)); | |
assert!(u.is_rank(1)); | |
let mut acc = 0.0; | |
for (te, ue) in t.elements().zip(u.elements()) { | |
acc += te.as_scalar() * ue.as_scalar(); | |
} | |
Tensor::scalar(acc) | |
} | |
fn grad_fn(t: &Tensor, u: &Tensor, z: &Tensor) -> (Tensor, Tensor) { | |
(u * z, t * z) | |
} | |
} | |
// Defines `$ext`, a rank-polymorphic extension of the unary primitive `$prim`.
// `$prim` only accepts tensors of rank `$n`; higher-rank inputs are processed
// element by element along the outermost dimension until that rank is reached.
// `grad_fn` mirrors the same descent so gradients line up with `real_fn`.
macro_rules! ext1 {
    ($ext:ident, $prim:ident, $n:expr) => {
        struct $ext;
        impl $ext {
            // Rejects inputs below the base rank. Descending into such a tensor
            // would recurse forever, because `Tensor::map` hands a scalar
            // straight back to the function being mapped.
            fn check_rank(t: &Tensor) {
                let (rank, base): (usize, usize) = (t.shape.len(), $n);
                assert!(
                    rank >= base,
                    "{}: argument rank {} is below the base rank {}",
                    stringify!($ext),
                    rank,
                    base
                );
            }
        }
        impl Prim1 for $ext {
            fn real_fn(t: &Tensor) -> Tensor {
                if t.is_rank($n) {
                    return $prim::real_fn(t);
                }
                Self::check_rank(t);
                t.map($ext::real_fn)
            }
            fn grad_fn(t: &Tensor, z: &Tensor) -> Tensor {
                if t.is_rank($n) {
                    return $prim::grad_fn(t, z);
                }
                Self::check_rank(t);
                tmap2($ext::grad_fn, t, z)
            }
        }
    };
}
// Defines `$ext`, a rank-polymorphic extension of the binary primitive `$prim`,
// which only accepts a first argument of rank `$n` and a second of rank `$m`.
// Higher-rank arguments are handled by descending into their outermost
// dimension until both reach their base ranks:
//   * one argument already at its base rank -> map over the other one;
//   * both above and of equal length        -> pair their elements up;
//   * otherwise                             -> descend into the higher-rank one.
// `grad_fn` mirrors the same descent so gradients line up with `real_fn`.
macro_rules! ext2 {
    ($ext:ident, $prim:ident, $n:expr, $m:expr) => {
        struct $ext;
        impl $ext {
            // Rejects arguments below their base ranks. Descending into such a
            // tensor would recurse forever, because `Tensor::map` hands a
            // scalar straight back to the function being mapped.
            fn check_ranks(t: &Tensor, u: &Tensor) {
                let (rank_t, rank_u) = (t.shape.len(), u.shape.len());
                let (base_t, base_u): (usize, usize) = ($n, $m);
                assert!(
                    rank_t >= base_t && rank_u >= base_u,
                    "{}: argument ranks ({}, {}) are below the base ranks ({}, {})",
                    stringify!($ext),
                    rank_t,
                    rank_u,
                    base_t,
                    base_u
                );
            }
        }
        impl Prim2 for $ext {
            fn real_fn(t: &Tensor, u: &Tensor) -> Tensor {
                if t.is_rank($n) && u.is_rank($m) {
                    return $prim::real_fn(t, u);
                }
                Self::check_ranks(t, u);
                if t.is_rank($n) {
                    return u.map(|eu| $ext::real_fn(t, eu));
                }
                if u.is_rank($m) {
                    return t.map(|et| $ext::real_fn(et, u));
                }
                if t.len() == u.len() {
                    return tmap2($ext::real_fn, t, u);
                }
                if t.rank_gt(u) {
                    return t.map(|et| $ext::real_fn(et, u));
                }
                if u.rank_gt(t) {
                    return u.map(|eu| $ext::real_fn(t, eu));
                }
                panic!(
                    "{}: cannot align arguments of shapes {:?} and {:?}",
                    stringify!($ext),
                    t.shape,
                    u.shape
                )
            }
            fn grad_fn(t: &Tensor, u: &Tensor, z: &Tensor) -> (Tensor, Tensor) {
                if t.is_rank($n) && u.is_rank($m) {
                    return $prim::grad_fn(t, u, z);
                }
                Self::check_ranks(t, u);
                if t.is_rank($n) {
                    return desc_u_grad($ext::grad_fn, t, u, z);
                }
                if u.is_rank($m) {
                    return desc_t_grad($ext::grad_fn, t, u, z);
                }
                if t.len() == u.len() {
                    return tmap_ext($ext::grad_fn, t, u, z);
                }
                if t.rank_gt(u) {
                    return desc_t_grad($ext::grad_fn, t, u, z);
                }
                if u.rank_gt(t) {
                    return desc_u_grad($ext::grad_fn, t, u, z);
                }
                panic!(
                    "{}: cannot align arguments of shapes {:?} and {:?}",
                    stringify!($ext),
                    t.shape,
                    u.shape
                )
            }
        }
    };
}
ext1!(Sum, Sum1, 1); | |
ext1!(Sqrt_, Sqrt0, 0); | |
ext1!(Neg, Neg0, 0); | |
ext2!(Add, Add0_0, 0, 0); | |
ext2!(Sub, Sub0_0, 0, 0); | |
ext2!(Mul, Mul0_0, 0, 0); | |
ext2!(Dot, Dot1_1, 1, 1); | |
/// Squares `x` by multiplying a copy of it with itself.
fn sqr<T>(x: T) -> T
where
    T: Clone + std::ops::Mul<Output = T>,
{
    let copy = x.clone();
    copy * x
}
/// Square root as an operation; implemented for `&Tensor` and `Dual`.
trait Sqrt {
    /// Type produced by taking the square root.
    type Output;
    /// Consumes `self` and returns its square root.
    fn sqrt(self) -> Self::Output;
}
/// Free-function form of `Sqrt::sqrt`.
fn sqrt<T>(x: T) -> <T as Sqrt>::Output
where
    T: Sqrt,
{
    Sqrt::sqrt(x)
}
/// Dot product of `a` and `b`, converting `b` into `a`'s type first.
fn dot<T: DifferentiableArithmetic>(
    a: T,
    b: impl Into<T>,
) -> <T as DifferentiableArithmetic>::Output {
    let rhs: T = b.into();
    <T as DifferentiableArithmetic>::dot(a, rhs)
}
/// Sum of the entries of `a`.
fn sum<T: DifferentiableArithmetic>(a: T) -> <T as DifferentiableArithmetic>::Output {
    <T as DifferentiableArithmetic>::sum(a)
}
/// Arithmetic shared by plain values and differentiable ones: adding or
/// multiplying by an `f64`, plus `dot` and `sum`, all producing `Output`.
trait DifferentiableArithmetic
where
    Self: std::ops::Add<f64, Output = <Self as DifferentiableArithmetic>::Output>,
    Self: std::ops::Mul<f64, Output = <Self as DifferentiableArithmetic>::Output>,
{
    /// Result type of every operation; itself supports the same arithmetic.
    type Output: DifferentiableArithmetic;
    /// Dot product of `a` and `b`.
    fn dot(a: Self, b: Self) -> <Self as DifferentiableArithmetic>::Output;
    /// Sum of the entries of `a`.
    fn sum(a: Self) -> <Self as DifferentiableArithmetic>::Output;
}
impl DifferentiableArithmetic for Tensor { | |
type Output = Tensor; | |
fn dot(a: Self, b: Self) -> Tensor { | |
Dot::real_fn(&a, &b) | |
} | |
fn sum(a: Self) -> Tensor { | |
Sum::real_fn(&a) | |
} | |
} | |
impl DifferentiableArithmetic for &Tensor { | |
type Output = Tensor; | |
fn dot(a: Self, b: Self) -> Tensor { | |
Dot::real_fn(a, b) | |
} | |
fn sum(a: Self) -> Tensor { | |
Sum::real_fn(a) | |
} | |
} | |
impl DifferentiableArithmetic for Dual { | |
type Output = Dual; | |
fn dot(a: Self, b: Self) -> Dual { | |
Dot::apply(a, b) | |
} | |
fn sum(a: Self) -> Dual { | |
Sum::apply(a) | |
} | |
} | |
impl DifferentiableArithmetic for &Dual { | |
type Output = Dual; | |
fn dot(a: Self, b: Self) -> Dual { | |
Dot::apply(a.clone(), b.clone()) | |
} | |
fn sum(a: Self) -> Dual { | |
Sum::apply(a.clone()) | |
} | |
} | |
impl std::ops::Add<f64> for Tensor { | |
type Output = Tensor; | |
fn add(self, x: f64) -> Tensor { | |
self + Tensor::scalar(x) | |
} | |
} | |
impl std::ops::Add<f64> for &Tensor { | |
type Output = Tensor; | |
fn add(self, x: f64) -> Tensor { | |
self + &Tensor::scalar(x) | |
} | |
} | |
impl std::ops::Mul<f64> for Tensor { | |
type Output = Tensor; | |
fn mul(self, x: f64) -> Tensor { | |
self * Tensor::scalar(x) | |
} | |
} | |
impl std::ops::Mul<f64> for &Tensor { | |
type Output = Tensor; | |
fn mul(self, x: f64) -> Tensor { | |
self * &Tensor::scalar(x) | |
} | |
} | |
impl std::ops::Add<f64> for Dual { | |
type Output = Dual; | |
fn add(self, x: f64) -> Dual { | |
self + Dual::end(Tensor::scalar(x)) | |
} | |
} | |
impl std::ops::Add<f64> for &Dual { | |
type Output = Dual; | |
fn add(self, x: f64) -> Dual { | |
self + &Dual::end(Tensor::scalar(x)) | |
} | |
} | |
impl std::ops::Mul<f64> for Dual { | |
type Output = Dual; | |
fn mul(self, x: f64) -> Dual { | |
self * Dual::end(Tensor::scalar(x)) | |
} | |
} | |
impl std::ops::Mul<f64> for &Dual { | |
type Output = Dual; | |
fn mul(self, x: f64) -> Dual { | |
self * &Dual::end(Tensor::scalar(x)) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment