Skip to content

Instantly share code, notes, and snippets.

@japaric
Created April 29, 2014 14:14
Show Gist options
  • Save japaric/11401627 to your computer and use it in GitHub Desktop.
Save japaric/11401627 to your computer and use it in GitHub Desktop.
Vector dot product: BLAS accelerated implementation + naive implementation
use std::iter::AdditiveIterator;
use std::num::Zero;
// Foreign declarations for the two BLAS dot-product routines, resolved at link
// time against a library named "blas". Every argument is passed by pointer,
// including the scalar length and increments.
// NOTE(review): the trailing underscore looks like Fortran symbol naming, and
// many BLAS builds use 32-bit integers while `int` is pointer-sized -- confirm
// the integer width matches the linked library before enabling these calls.
#[link(name = "blas")]
extern {
// f64 dot product of `x` and `y`. `N` is presumably the element count and
// `inc_x`/`inc_y` the strides: the disabled callers in this file pass `len()`
// and `1` respectively.
fn ddot_(N: *int, x: *f64, inc_x: *int, y: *f64, inc_y: *int) -> f64;
// f32 counterpart of `ddot_`, with the same argument layout.
fn sdot_(N: *int, x: *f32, inc_x: *int, y: *f32, inc_y: *int) -> f32;
}
// Dot product between the implementing type and a `Vec<T>`, producing a single `T`.
trait VectorDot<T> {
// Returns the dot product of `self` and `rhs`.
// The only active implementation in this file panics when the lengths differ.
fn dot(&self, rhs: &Vec<T>) -> T;
}
// FIXME: how can we tell the compiler to use the BLAS-accelerated implementations
// for f32 and f64, and fall back to this generic implementation in all other cases?
// Generic implementation: computes the dot product element by element for any
// `T` that supports addition and multiplication and has a zero value.
impl<T: Add<T, T> + Mul<T, T> + Zero> VectorDot<T> for Vec<T> {
    // Multiplies the elements of `self` and `rhs` pairwise and sums the products.
    // Panics with "dimension mismatch" if the two vectors differ in length.
    fn dot(&self, rhs: &Vec<T>) -> T {
        assert!(self.len() == rhs.len(), "dimension mismatch");

        // Walk both vectors in lockstep, one pair of elements at a time.
        let pairs = self.as_slice().iter().zip(rhs.as_slice().iter());
        pairs.map(|(lhs_elem, rhs_elem)| lhs_elem.mul(rhs_elem)).sum()
    }
}
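// Error: conflicting implementations for trait VectorDot -- this impl overlaps
// the generic `impl<T>` for `Vec<T>` above, whose bounds f32 also satisfies.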
//impl
//VectorDot<f32>
//for Vec<f32> {
//fn dot(&self, rhs: &Vec<f32>) -> f32 {
//assert!(self.len() == rhs.len(), "dimension mismatch");
//unsafe {
//sdot_(&(self.len() as int) as *_,
//self.as_ptr(),
//&1 as *_,
//rhs.as_ptr(),
//&1 as *_)
//}
//}
//}
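// Error: conflicting implementations for trait VectorDot -- this impl overlaps
// the generic `impl<T>` for `Vec<T>` above, whose bounds f64 also satisfies.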
//impl
//VectorDot<f64>
//for Vec<f64> {
//fn dot(&self, rhs: &Vec<f64>) -> f64 {
//assert!(self.len() == rhs.len(), "dimension mismatch");
//unsafe {
//ddot_(&(self.len() as int) as *_,
//self.as_ptr(),
//&1 as *_,
//rhs.as_ptr(),
//&1 as *_)
//}
//}
//}
// Exercises `dot` on f64, f32 and integer vectors; each pair of inputs gives
// 1*3 + 2*2 + 3*1 = 10.
fn main() {
    // f64: intended for the BLAS path. While the specialised impls in this file
    // stay commented out, this call goes through the generic implementation.
    let doubles_a: Vec<f64> = vec!(1.0, 2.0, 3.0);
    let doubles_b: Vec<f64> = vec!(3.0, 2.0, 1.0);
    assert_eq!(doubles_a.dot(&doubles_b), 10.0);

    // f32: intended for the BLAS path; same caveat as the f64 case.
    let singles_a: Vec<f32> = vec!(1.0, 2.0, 3.0);
    let singles_b: Vec<f32> = vec!(3.0, 2.0, 1.0);
    assert_eq!(singles_a.dot(&singles_b), 10.0);

    // Integers: always handled by the generic fallback.
    let ints_a = vec!(1, 2, 3);
    let ints_b = vec!(3, 2, 1);
    assert_eq!(ints_a.dot(&ints_b), 10);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment