Skip to content

Instantly share code, notes, and snippets.

@lissahyacinth
Created March 25, 2020 14:39
Show Gist options
  • Save lissahyacinth/9379c3f10a1d8ac816a3889c28d825ef to your computer and use it in GitHub Desktop.
Save lissahyacinth/9379c3f10a1d8ac816a3889c28d825ef to your computer and use it in GitHub Desktop.
#![cfg(feature = "native")] // required for data i/o
extern crate coaster_blas as co_blas;
extern crate coaster as co;
use std::fmt;
use crate::co::backend::{Backend, IBackend};
use crate::co::framework::{IFramework};
use crate::co::plugin::numeric_helpers::{cast, Float, NumCast};
use crate::co::tensor::{SharedTensor,ITensorDesc};
use crate::co_blas::plugin::*;
use crate::co_blas::transpose::Transpose;
#[cfg(feature = "native")]
use crate::co::frameworks::Native;
#[cfg(feature = "cuda")]
use crate::co::frameworks::Cuda;
/// Creates the default `Native` backend.
///
/// # Panics
///
/// Panics if `Backend::<Native>::default()` returns an error.
#[cfg(feature = "native")]
fn get_native_backend() -> Backend<Native> {
    let created = Backend::<Native>::default();
    created.unwrap()
}
/// Creates the default `Cuda` backend.
///
/// # Panics
///
/// Panics if `Backend::<Cuda>::default()` returns an error.
#[cfg(feature = "cuda")]
fn get_cuda_backend() -> Backend<Cuda> {
    let created = Backend::<Cuda>::default();
    created.unwrap()
}
// TODO reuse the coaster-nn methods
/// Overwrites `xs` with `data`, converting every `f64` into `T`.
///
/// The write goes through the native backend's device, so the tensor's
/// native copy becomes the current one.
///
/// `::std::marker::Copy` is spelled out in full because the glob import of
/// the BLAS plugin traits brings a different `Copy` trait into scope.
///
/// # Panics
///
/// Panics if the tensor's element count differs from `data.len()`, if the
/// tensor cannot be opened for writing on the native device, or if a value
/// cannot be represented as `T`.
pub fn write_to_tensor<T>(xs: &mut SharedTensor<T>, data: &[f64])
where
    T: ::std::marker::Copy + NumCast,
{
    assert_eq!(xs.desc().size(), data.len());

    let backend = get_native_backend();
    let device = backend.device();
    let memory = xs.write_only(device).unwrap();
    let slots = memory.as_mut_slice::<T>();

    data.iter()
        .enumerate()
        .for_each(|(idx, &value)| slots[idx] = cast::<_, T>(value).unwrap());
}
// TODO reuse the coaster-nn methods
/// Asserts that `xs`, read back on the native backend, equals `data` element-wise.
///
/// Each pair `(a, b)` passes when `a == b` exactly, or when
/// `|a - b| <= eps(T) * epsilon_mul * (|a| + |b|) / 2`, where `eps(T)` is the
/// machine epsilon of `T`. With `epsilon_mul == 0.` the comparison is exact.
///
/// # Panics
///
/// Panics if the tensor cannot be read on the native device, if the lengths
/// differ, or if any pair is outside the tolerance (including NaN results).
pub fn tensor_assert_eq<T>(xs: &SharedTensor<T>, data: &[f64], epsilon_mul: f64)
where
    T: Float + fmt::Debug + PartialEq + NumCast,
{
    // The tolerance is relative to T's machine epsilon. Previously this was
    // `0. * epsilon_mul`, which silently discarded the caller's tolerance.
    let e = cast::<_, f64>(T::epsilon()).unwrap() * epsilon_mul;
    let native = get_native_backend();
    let native_dev = native.device();
    let mem = xs.read(native_dev).unwrap();
    let mem_slice = mem.as_slice::<T>();
    assert_eq!(mem_slice.len(), data.len());
    for (x1, x2) in mem_slice.iter().zip(data.iter()) {
        let x1_t = cast::<_, f64>(*x1).unwrap();
        // Exact matches (including equal infinities) always pass.
        if x1_t == *x2 {
            continue;
        }
        let diff = (x1_t - x2).abs();
        let max_diff = e * (x1_t.abs() + x2.abs()) * 0.5;
        // NaN never compares greater than anything, so it must be rejected
        // explicitly or a NaN result would be accepted.
        if diff.is_nan() || diff > max_diff {
            panic!(
                "Results differ: {:?} != {:?} ({:.2?} in {:?} and {:?})",
                x1_t,
                x2,
                diff / max_diff,
                mem_slice,
                data
            );
        }
    }
}
/// Checks `Asum`: the absolute-value sum of `[1, -2, 3]` must be `6`.
pub fn test_asum<T, F>(backend: &Backend<F>)
where
    T: Float + fmt::Debug,
    F: IFramework,
    Backend<F>: Asum<T> + IBackend,
{
    let mut input = SharedTensor::<T>::new(&[3]);
    let mut total = SharedTensor::<T>::new(&[1]);
    let values = [1.0, -2.0, 3.0];
    let expected = [6.0];

    write_to_tensor(&mut input, &values);
    backend.asum(&input, &mut total).unwrap();
    tensor_assert_eq(&total, &expected, 0.);
}
/// Checks `Nrm2`: the Euclidean norm of `[1, 2, 2]` must be `3`.
pub fn test_nrm2<T, F>(backend: &Backend<F>)
where
    T: Float + fmt::Debug,
    F: IFramework,
    Backend<F>: Nrm2<T> + IBackend,
{
    let mut vector = SharedTensor::<T>::new(&[3]);
    let mut norm = SharedTensor::<T>::new(&[1]);
    let values = [1., 2., 2.];
    let expected = [3.0];

    write_to_tensor(&mut vector, &values);
    backend.nrm2(&vector, &mut norm).unwrap();
    tensor_assert_eq(&norm, &expected, 0.);
}
/// Checks the BLAS `Copy` plugin: copying `[1, 2, 3]` must reproduce it exactly.
pub fn test_copy<T, F>(backend: &Backend<F>)
where
    T: Float + fmt::Debug,
    F: IFramework,
    Backend<F>: Copy<T> + IBackend,
{
    let mut source = SharedTensor::<T>::new(&[3]);
    let mut destination = SharedTensor::<T>::new(&[3]);
    let values = [1., 2., 3.];

    write_to_tensor(&mut source, &values);
    backend.copy(&source, &mut destination).unwrap();
    // The destination must hold exactly what was written to the source.
    tensor_assert_eq(&destination, &values, 0.);
}
/// Checks `Dot`: `[1, 2, 3] · [1, 2, 3]` must be `14`.
pub fn test_dot<T, F>(backend: &Backend<F>)
where
    T: Float + fmt::Debug,
    F: IFramework,
    Backend<F>: Dot<T> + IBackend,
{
    let mut lhs = SharedTensor::<T>::new(&[3]);
    let mut rhs = SharedTensor::<T>::new(&[3]);
    let mut product = SharedTensor::<T>::new(&[1]);
    // Both operands hold the same values.
    let values = [1., 2., 3.];

    write_to_tensor(&mut lhs, &values);
    write_to_tensor(&mut rhs, &values);
    backend.dot(&lhs, &rhs, &mut product).unwrap();
    tensor_assert_eq(&product, &[14.0], 0.);
}
/// Checks `Axpy`: `y <- a * x + y` with `a = 2`, `x = y = [1, 2, 3]`
/// must leave `y == [3, 6, 9]`.
///
/// The backend is explicitly synchronized before the result is read back.
pub fn test_axpy<T, F>(backend: &Backend<F>)
where
    T: Float + fmt::Debug,
    F: IFramework,
    Backend<F>: Axpy<T> + IBackend,
{
    let mut scale = SharedTensor::<T>::new(&[1]);
    let mut x_vec = SharedTensor::<T>::new(&[3]);
    let mut y_vec = SharedTensor::<T>::new(&[3]);
    let base = [1., 2., 3.];

    write_to_tensor(&mut scale, &[2.]);
    write_to_tensor(&mut x_vec, &base);
    write_to_tensor(&mut y_vec, &base);

    backend.axpy(&scale, &x_vec, &mut y_vec).unwrap();
    backend.synchronize().unwrap();

    tensor_assert_eq(&y_vec, &[3.0, 6.0, 9.0], 0.);
}
/// Runs the BLAS smoke tests against the CUDA backend.
///
/// Each `test_*` helper is generic over the element type `T`, and `T` cannot
/// be inferred from the `&Backend<F>` argument alone. It is therefore fixed
/// explicitly here; without that, the calls fail with "type annotations needed".
/// Only `f32` is exercised.
/// NOTE(review): this assumes the CUDA backend implements the BLAS plugin
/// traits for `f32`. Add `f64` calls if the CUDA plugin supports it.
fn main() {
    // `get_cuda_backend` only exists when the `cuda` feature is enabled.
    #[cfg(feature = "cuda")]
    {
        let backend = get_cuda_backend();
        test_asum::<f32, _>(&backend);
        test_nrm2::<f32, _>(&backend);
        test_axpy::<f32, _>(&backend);
        test_dot::<f32, _>(&backend);
        test_copy::<f32, _>(&backend);
    }

    #[cfg(not(feature = "cuda"))]
    eprintln!("CUDA BLAS checks skipped: built without the `cuda` feature");
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment