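// Smoke tests for the coaster-blas level-1 BLAS routines (asum, nrm2, copy,
// dot, axpy) on coaster SharedTensors; main() runs them all on the CUDA backend.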
#![cfg(feature = "native")] // required for data i/o

extern crate coaster_blas as co_blas;
extern crate coaster as co;

use std::fmt;

use crate::co::backend::{Backend, IBackend};
use crate::co::framework::IFramework;
use crate::co::plugin::numeric_helpers::{cast, Float, NumCast};
use crate::co::tensor::{ITensorDesc, SharedTensor};
use crate::co_blas::plugin::*;
use crate::co_blas::transpose::Transpose;

#[cfg(feature = "native")]
use crate::co::frameworks::Native;
#[cfg(feature = "cuda")]
use crate::co::frameworks::Cuda;

#[cfg(feature = "native")]
fn get_native_backend() -> Backend<Native> {
    Backend::<Native>::default().unwrap()
}

#[cfg(feature = "cuda")]
fn get_cuda_backend() -> Backend<Cuda> {
    Backend::<Cuda>::default().unwrap()
}
// TODO reuse the coaster-nn methods
/// Copies `data` into `xs`, casting each `f64` element to `T` through the
/// native (CPU) backend.
pub fn write_to_tensor<T>(xs: &mut SharedTensor<T>, data: &[f64])
    where T: ::std::marker::Copy + NumCast {
    assert_eq!(xs.desc().size(), data.len());
    let native = get_native_backend();
    let native_dev = native.device();
    let mem = xs.write_only(native_dev).unwrap();
    let mem_buffer = mem.as_mut_slice::<T>();
    for (i, x) in data.iter().enumerate() {
        mem_buffer[i] = cast::<_, T>(*x).unwrap();
    }
}
// TODO reuse the coaster-nn methods
/// Asserts that `xs` matches `data` to within a symmetric relative error of
/// `1e-14 * epsilon_mul`; with `epsilon_mul = 0.` the comparison is exact.
pub fn tensor_assert_eq<T>(xs: &SharedTensor<T>, data: &[f64], epsilon_mul: f64)
    where T: Float + fmt::Debug + PartialEq + NumCast {
    // Base tolerance, scaled by the caller-supplied multiplier.
    let e = 1e-14 * epsilon_mul;
    let native = get_native_backend();
    let native_dev = native.device();
    let mem = xs.read(native_dev).unwrap();
    let mem_slice = mem.as_slice::<T>();
    assert_eq!(mem_slice.len(), data.len());
    for (x1, x2) in mem_slice.iter().zip(data.iter()) {
        let x1_t = cast::<_, f64>(*x1).unwrap();
        let diff = (x1_t - x2).abs();
        let max_diff = e * (x1_t.abs() + x2.abs()) * 0.5;
        if diff > max_diff {
            panic!("Results differ: {:?} != {:?} ({:.2?}) in {:?} and {:?}",
                   x1_t, x2, diff / max_diff, mem_slice, data);
        }
    }
}
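
// A minimal sketch of how the generic helpers below compose into a standalone
// unit test on the native (CPU) backend, assuming coaster-blas implements
// Asum<f32> for Backend<Native> when the "native" feature is enabled.
#[cfg(all(feature = "native", test))]
mod native_smoke {
    use super::*;

    #[test]
    fn asum_f32_native() {
        let backend = get_native_backend();
        test_asum::<f32, Native>(&backend);
    }
}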
pub fn test_asum<T, F>(backend: &Backend<F>)
    where T: Float + fmt::Debug,
          F: IFramework,
          Backend<F>: Asum<T> + IBackend {
    let mut x = SharedTensor::<T>::new(&[3]);
    let mut result = SharedTensor::<T>::new(&[1]);
    write_to_tensor(&mut x, &[1.0, -2.0, 3.0]);
    backend.asum(&x, &mut result).unwrap();
    tensor_assert_eq(&result, &[6.0], 0.);
}
pub fn test_nrm2<T, F>(backend: &Backend<F>)
    where T: Float + fmt::Debug,
          F: IFramework,
          Backend<F>: Nrm2<T> + IBackend {
    let mut x = SharedTensor::<T>::new(&[3]);
    let mut result = SharedTensor::<T>::new(&[1]);
    write_to_tensor(&mut x, &[1., 2., 2.]);
    backend.nrm2(&x, &mut result).unwrap();
    tensor_assert_eq(&result, &[3.0], 0.);
}
pub fn test_copy<T, F>(backend: &Backend<F>)
    where T: Float + fmt::Debug,
          F: IFramework,
          Backend<F>: Copy<T> + IBackend {
    let mut x = SharedTensor::<T>::new(&[3]);
    let mut y = SharedTensor::<T>::new(&[3]);
    write_to_tensor(&mut x, &[1., 2., 3.]);
    backend.copy(&x, &mut y).unwrap();
    tensor_assert_eq(&y, &[1.0, 2.0, 3.0], 0.);
}
pub fn test_dot<T, F>(backend: &Backend<F>)
    where T: Float + fmt::Debug,
          F: IFramework,
          Backend<F>: Dot<T> + IBackend {
    let mut x = SharedTensor::<T>::new(&[3]);
    let mut y = SharedTensor::<T>::new(&[3]);
    let mut result = SharedTensor::<T>::new(&[1]);
    write_to_tensor(&mut x, &[1., 2., 3.]);
    write_to_tensor(&mut y, &[1., 2., 3.]);
    backend.dot(&x, &y, &mut result).unwrap();
    tensor_assert_eq(&result, &[14.0], 0.);
}
pub fn test_axpy<T, F>(backend: &Backend<F>)
    where T: Float + fmt::Debug,
          F: IFramework,
          Backend<F>: Axpy<T> + IBackend {
    let mut a = SharedTensor::<T>::new(&[1]);
    let mut x = SharedTensor::<T>::new(&[3]);
    let mut y = SharedTensor::<T>::new(&[3]);
    write_to_tensor(&mut a, &[2.]);
    write_to_tensor(&mut x, &[1., 2., 3.]);
    write_to_tensor(&mut y, &[1., 2., 3.]);
    // y <- a * x + y
    backend.axpy(&a, &x, &mut y).unwrap();
    backend.synchronize().unwrap();
    tensor_assert_eq(&y, &[3.0, 6.0, 9.0], 0.);
}
// main() drives every test on the CUDA backend, so it is only compiled when
// the "cuda" feature is enabled (get_cuda_backend is gated on it too). The
// element type has to be named explicitly, since nothing else pins down T.
#[cfg(feature = "cuda")]
fn main() {
    let x = get_cuda_backend();
    test_asum::<f32, _>(&x);
    test_nrm2::<f32, _>(&x);
    test_axpy::<f32, _>(&x);
    test_dot::<f32, _>(&x);
    test_copy::<f32, _>(&x);
}

#[cfg(not(feature = "cuda"))]
fn main() {
    println!("enable the \"cuda\" feature to run these tests");
}
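
For reference, a Cargo.toml along the following lines should let this build as a standalone binary. The crate versions are placeholders rather than values taken from the gist, and the feature declarations simply forward to the coaster crates so that the #[cfg(feature = ...)] gates above resolve:

[package]
name = "coaster-blas-smoke"   # hypothetical crate name
version = "0.1.0"
edition = "2018"

[features]
default = ["native"]
native = ["coaster/native", "coaster-blas/native"]
cuda = ["coaster/cuda", "coaster-blas/cuda"]

[dependencies]
coaster = "*"      # placeholder; pin to your installed version
coaster-blas = "*" # placeholder; pin to your installed version

With that in place, `cargo run --features cuda` exercises the CUDA path in main().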