// Gist by @serialhex, last active December 6, 2016 19:27
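#include <arrayfire.h>

// ArrayFire's array type and its functions (sum, shift, ...) live in namespace af.
using namespace af;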
// Computes the dot products of corresponding vectors stored in two arrays
// of the same shape (e.g. one vector per column when dim = 0).
// The default dim = -1 sums over the first non-singleton dimension;
// otherwise, dim is the dimension along which the vectors lie.
array elem_dot(const array& a, const array& b, const int dim = -1) {
    return sum(a * b, dim);  // element-wise product, then reduce along dim
}
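// A minimal plain-C++ sketch (no ArrayFire) of what elem_dot(a, b, 0)
// computes, intended only for cross-checking small inputs by hand. It assumes
// the vectors are stored as the columns of a matrix given row by row, so that
// dim 0 runs down each column. The name ref_elem_dot_cols is hypothetical.
#include <cstddef>
#include <vector>

// Returns one dot product per column: out[j] = sum over i of a[i][j] * b[i][j].
// Assumes a and b are non-ragged and have the same shape.
std::vector<double> ref_elem_dot_cols(const std::vector<std::vector<double>>& a,
                                      const std::vector<std::vector<double>>& b) {
    std::vector<double> out(a.empty() ? 0 : a[0].size(), 0.0);
    for (std::size_t i = 0; i < a.size(); ++i) {         // walk along dim 0
        for (std::size_t j = 0; j < a[i].size(); ++j) {  // one vector per column
            out[j] += a[i][j] * b[i][j];
        }
    }
    return out;
}
// Example: columns (1, 3, 5) . (1, 1, 1) = 9 and (2, 4, 6) . (0, 1, 2) = 16.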
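// Computes the cross products of corresponding 3-vectors stored in two
// arrays of the same shape. dim (0-3) is the dimension the vectors lie along.
// ***WARNING***
// The cyclic-shift trick below yields the cross product only when the
// vectors have exactly 3 components along dim, so other sizes are rejected.
//
// How it works: shifting by one along dim rotates (x, y, z) to (z, x, y).
// b * shift(a) - a * shift(b) then equals (c_y, c_z, c_x), where c = a x b,
// and one more shift restores the order (c_x, c_y, c_z).
array elem_cross(const array& a, const array& b, const int dim = 0) {
    if (dim < 0 || dim > 3) {
        throw af::exception("elem_cross: dim must be 0, 1, 2 or 3");
    }
    if (a.dims(dim) != 3) {
        throw af::exception("elem_cross: vectors must have 3 components along dim");
    }

    // Rotate the components of each vector by one position along dim.
    array sh_a;
    array sh_b;
    switch (dim) {
        case 0:
            sh_a = shift(a, 1);
            sh_b = shift(b, 1);
            break;
        case 1:
            sh_a = shift(a, 0, 1);
            sh_b = shift(b, 0, 1);
            break;
        case 2:
            sh_a = shift(a, 0, 0, 1);
            sh_b = shift(b, 0, 0, 1);
            break;
        case 3:
            sh_a = shift(a, 0, 0, 0, 1);
            sh_b = shift(b, 0, 0, 0, 1);
            break;
    }

    // Components come out as (c_y, c_z, c_x); shift once more to reorder them.
    array res = (b * sh_a) - (a * sh_b);
    switch (dim) {
        case 0:
            res = shift(res, 1);
            break;
        case 1:
            res = shift(res, 0, 1);
            break;
        case 2:
            res = shift(res, 0, 0, 1);
            break;
        case 3:
            res = shift(res, 0, 0, 0, 1);
            break;
    }
    return res;
}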
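// A minimal plain-C++ sketch (no ArrayFire) of the cyclic-shift trick used by
// elem_cross, for a single 3-vector, to make the index bookkeeping concrete.
// rotate_one mimics a circular shift by one, out[i] = in[(i - 1) mod 3], i.e.
// (x, y, z) -> (z, x, y), which is the behaviour assumed here for af::shift
// with a shift of 1. The names Vec3, rotate_one and ref_cross_via_shift are
// hypothetical, for illustration and testing only.
#include <array>

using Vec3 = std::array<double, 3>;

// Circular shift by one position: (x, y, z) -> (z, x, y).
Vec3 rotate_one(const Vec3& v) {
    return {v[2], v[0], v[1]};
}

// b * shift(a) - a * shift(b) gives (c_y, c_z, c_x) where c = a x b;
// one more rotation restores the order (c_x, c_y, c_z).
Vec3 ref_cross_via_shift(const Vec3& a, const Vec3& b) {
    const Vec3 sa = rotate_one(a);
    const Vec3 sb = rotate_one(b);
    Vec3 r{};
    for (int i = 0; i < 3; ++i) {
        r[i] = b[i] * sa[i] - a[i] * sb[i];
    }
    return rotate_one(r);
}
// Example: (1, 0, 0) x (0, 1, 0) = (0, 0, 1) and (1, 2, 3) x (4, 5, 6) = (-3, 6, -3).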