Skip to content

Instantly share code, notes, and snippets.

@jerry73204
Created September 18, 2019 06:36
Show Gist options
  • Save jerry73204/86f6bf3848d6cba4b44f32e33a371802 to your computer and use it in GitHub Desktop.
Save jerry73204/86f6bf3848d6cba4b44f32e33a371802 to your computer and use it in GitHub Desktop.
Demo for tch-rs #112: checking tensor sizes at compile time with typenum type-level integers.
use std::{
marker::PhantomData,
ops::{Add, Shl},
};
use typenum::{consts::*, Double, Sum, Unsigned};
/// A mock tensor whose length is carried in its type.
///
/// `Size` is meant to be a `typenum` type-level unsigned integer (e.g. `U3`).
/// The `Size: Unsigned` requirement is placed on the `impl` blocks that need
/// it rather than on the struct itself, so the type can be named freely in
/// signatures and type annotations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Tensor<Size> {
    /// Runtime copy of the type-level size. It is not strictly necessary, but
    /// it lets the methods cross-check the type-level arithmetic at runtime.
    size: usize,
    /// Marks `Size` as used without storing a value of that type.
    _phantom: PhantomData<Size>,
}
impl<Size> Tensor<Size>
where
    Size: Unsigned,
{
    /// Creates a tensor whose runtime `size` equals the type-level `Size`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            size: Size::USIZE,
            _phantom: PhantomData,
        }
    }

    /// Returns a tensor twice as long as `self`, typed as `Tensor<Double<Size>>`.
    ///
    /// # Panics
    ///
    /// Panics if `2 * self.size` disagrees with the type-level result
    /// `Double<Size>`. This is a runtime sanity check of the type-level arithmetic.
    #[must_use]
    pub fn dup(&self) -> Tensor<Double<Size>>
    where
        Size: Shl<B1>,
        Double<Size>: Unsigned,
    {
        // The explicit turbofish is optional; inference from the return type works too.
        let doubled = Tensor::<Double<Size>>::new();
        assert_eq!(self.size * 2, Double::<Size>::USIZE);
        doubled
    }

    /// Concatenates `self` with `rhs`, yielding a `Tensor<Sum<Size, RhsSize>>`.
    ///
    /// # Panics
    ///
    /// Panics if `self.size + rhs.size` disagrees with the type-level result
    /// `Sum<Size, RhsSize>`. This is a runtime sanity check of the type-level arithmetic.
    #[must_use]
    pub fn concat<RhsSize>(&self, rhs: &Tensor<RhsSize>) -> Tensor<Sum<Size, RhsSize>>
    where
        Size: Add<RhsSize>,
        RhsSize: Unsigned,
        Sum<Size, RhsSize>: Unsigned,
    {
        // The explicit turbofish is optional; inference from the return type works too.
        let joined = Tensor::<Sum<Size, RhsSize>>::new();
        assert_eq!(self.size + rhs.size, Sum::<Size, RhsSize>::USIZE);
        joined
    }
}

/// `Tensor::default()` is equivalent to `Tensor::new()`.
impl<Size> Default for Tensor<Size>
where
    Size: Unsigned,
{
    fn default() -> Self {
        Self::new()
    }
}
/// Exercises the type-level size arithmetic.
///
/// Each `let _: Tensor<...> = ...;` binding is a compile-time assertion. The
/// program builds only if the annotated type (left) and the type produced by
/// the method calls (right) resolve to the same `typenum` number. The methods
/// also re-check the sizes at runtime.
fn main() {
    let t1 = Tensor::<U3>::new(); // size M = 3
    let t2 = Tensor::<U2>::new(); // size N = 2

    // left = N + M, right = M + N (addition commutes)
    let _: Tensor<Sum<U2, U3>> = t1.concat(&t2);

    // left = 2 * N + 2 * M, right = (M + N) * 2 (doubling distributes over addition)
    let _: Tensor<Sum<Double<U2>, Double<U3>>> = t1.concat(&t2).dup();

    // left = (N + M) * 2, right = 2 * M + 2 * N
    let _: Tensor<Double<Sum<U2, U3>>> = t1.dup().concat(&t2.dup());

    // Reaching this line means every static assertion above type-checked
    // and every runtime size check passed.
    println!("wheeeeee!");
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment