Last active
March 30, 2017 00:42
-
-
Save wolfv/6b310a38927867ea4a8a451eb9d6f80e to your computer and use it in GitHub Desktop.
BLAS and LAPACK functions for xtensor
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*************************************************************************** | |
* Copyright (c) 2016, Johan Mabille and Sylvain Corlay * | |
* * | |
* Distributed under the terms of the BSD 3-Clause License. * | |
* * | |
* The full license is in the file LICENSE, distributed with this software. * | |
****************************************************************************/ | |
#ifndef XBLAS_HPP | |
#define XBLAS_HPP | |
#include "xtensor/xeval.hpp" | |
#ifndef USE_CXXLAPACK | |
#define USE_CXXLAPACK | |
#endif | |
#include <flens/flens.cxx> | |
namespace xt | |
{ | |
namespace blas | |
{ | |
/** | |
* Calculate the dot product between two vectors | |
* @param a vector of n elements | |
* @param b vector of n elements | |
* @returns scalar result | |
*/ | |
template <class E1, class E2> | |
xtensor<typename E1::value_type, 1> dot(const xexpression<E1>& a, const xexpression<E2>& b) | |
{ | |
xtensor<typename E1::value_type, 1> res({0}); | |
auto&& ad = xt::eval(a.derived_cast()); | |
auto&& bd = xt::eval(b.derived_cast()); | |
cxxblas::dot(ad.size(), | |
ad.raw_data() + ad.raw_data_offset(), ad.strides()[0], | |
bd.raw_data() + bd.raw_data_offset(), bd.strides()[0], | |
res(0)); | |
return res; | |
} | |
/** | |
* Calculate the 1-norm of a vector | |
* @param a vector of n elements | |
* @returns scalar result | |
*/ | |
template <class E1> | |
typename E1::value_type asum(const xexpression<E1>& a) | |
{ | |
typename E1::value_type res; | |
auto&& ad = xt::eval(a.derived_cast()); | |
cxxblas::asum(ad.size(), | |
ad.raw_data() + ad.raw_data_offset(), ad.strides()[0], | |
res); | |
return res; | |
} | |
/** | |
* Calculate the 2-norm of a vector | |
* @param a vector of n elements | |
* @returns scalar result | |
*/ | |
template <class E1> | |
typename E1::value_type nrm2(const xexpression<E1>& a) | |
{ | |
typename E1::value_type res; | |
auto&& ad = xt::eval(a.derived_cast()); | |
cxxblas::nrm2(ad.size(), | |
ad.raw_data() + ad.raw_data_offset(), ad.strides()[0], | |
res); | |
return res; | |
} | |
namespace detail | |
{ | |
template<class xF, class xR, class fE, class... xE> | |
std::true_type is_xfunction_impl(const xfunction<xF, xR, fE, xE...>&); | |
std::false_type is_xfunction_impl( ... ); | |
template<typename T> | |
constexpr bool is_xfunction(T&& t) { | |
return decltype(is_xfunction_impl(t))::value; | |
} | |
template <class A, class B, class T> | |
struct get_type_impl { | |
using type = xarray<T>; | |
}; | |
template <class A, std::size_t N, class B, std::size_t M, class T> | |
struct get_type_impl<std::array<A, N>, std::array<B, M>, T> { | |
using type = xtensor<T, M>; | |
}; | |
template <class A, class B> | |
struct select_xtype | |
{ | |
using type = typename std::remove_const_t< | |
typename get_type_impl<typename A::shape_type, typename B::shape_type, | |
typename std::common_type_t<typename A::value_type, | |
typename B::value_type>>::type>; | |
}; | |
} | |
template <class E1, class E2> | |
xt::xarray<typename E1::value_type> gemv(const xexpression<E1>& A, const xexpression<E2>& x, | |
bool transpose = false, | |
const xscalar<typename E1::value_type> alpha = 1) | |
{ | |
// gemv calculates y = alpha*A*x + beta*y | |
auto&& dA = xt::eval(A.derived_cast()); | |
auto&& dx = xt::eval(x.derived_cast()); | |
using result_type = typename detail::select_xtype<E1, E2>::type; | |
typename result_type::shape_type result_shape(dx.shape()); | |
if (transpose) | |
{ | |
std::reverse(result_shape.begin(), result_shape.end()); | |
} | |
result_type res(result_shape); | |
cxxblas::gemv( | |
cxxblas::StorageOrder::RowMajor, | |
transpose ? cxxblas::Transpose::Trans : cxxblas::Transpose::NoTrans, | |
dA.shape()[0], dA.shape()[1], | |
alpha(), // alpha | |
dA.raw_data(), dA.strides()[0], | |
dx.raw_data(), dx.strides()[0], | |
0.f, // beta | |
res.data().data(), 1ul | |
); | |
return res; | |
} | |
template <class xF, class xR, class fE, class... xE, class E2> | |
std::enable_if_t<std::is_same<fE, xt::xscalar<xR>>::value && std::is_same<xF, std::multiplies<xR>>::value && | |
!detail::is_xfunction<E2>(), xt::xarray<xR>> | |
gemv(const xexpression<xfunction<xF, xR, fE, xE...>>& A, const xexpression<E2>& x, bool transpose = false) | |
{ | |
auto alpha = A.derived_cast().left(); | |
xt::xarray<xR> dA = std::move(A.derived_cast().right()); | |
return gemv(dA, x, transpose, alpha); | |
} | |
template <class aF, class aR, class... aE, class xF, class xR, class... xE> | |
xt::xarray<xR> | |
gemv(const xexpression<xfunction<aF, aR, aE...>>& A, const xexpression<xfunction<xF, xR, xE...>>& x, bool transpose = false) | |
{ | |
xt::xarray<xR> x_eval = x; | |
return gemv(A, x_eval, transpose); | |
} | |
/** | |
* Calculate the matrix-matrix product of matrix @A and matrix @B | |
* @param A matrix of m-by-n elements | |
* @param B matrix of n-by-k elements | |
* @returns matrix of m-by-k elements | |
*/ | |
template <class E1, class E2> | |
xarray<typename E1::value_type> gemm(const xexpression<E1>& A, const xexpression<E2>& B) | |
{ | |
auto&& da = xt::eval(A.derived_cast()); | |
auto&& db = xt::eval(B.derived_cast()); | |
using return_type = typename blas::detail::select_xtype<E1, E2>::type; | |
typename return_type::shape_type s = {da.shape()[0], db.shape()[1]}; | |
return_type res(s); | |
cxxblas::gemm( | |
cxxblas::StorageOrder::RowMajor, | |
cxxblas::Transpose::NoTrans, cxxblas::Transpose::NoTrans, | |
da.shape()[0], da.shape()[1], db.shape()[0], | |
1.f, // alpha | |
da.raw_data(), da.strides()[0], | |
db.raw_data(), db.strides()[0], | |
0.f, // beta | |
res.raw_data(), res.strides()[0] | |
); | |
return res; | |
} | |
template <class E1, class E2> | |
xarray<typename E1::value_type> gesv(const xexpression<E1>& A, const xexpression<E2>& B) | |
{ | |
const E1& a = A.derived_cast(); | |
xt::xarray<typename E1::value_type> da(a.shape(), layout::column_major); | |
assign_data(da, a, true); | |
const E2& b = B.derived_cast(); | |
xt::xarray<typename E1::value_type> db(b.shape(), layout::column_major); | |
assign_data(db, b, true); | |
assert(da.dimension() == 2); | |
using result_type = xarray<typename E1::value_type>; | |
typename result_type::shape_type s = {da.shape()[0]}; | |
std::vector<int> piv(da.shape()[0]); | |
cxxlapack::gesv<int>( | |
da.shape()[0], db.shape()[1], | |
da.raw_data(), da.strides()[1], | |
piv.data(), | |
db.raw_data(), !db.strides()[1] ? da.shape()[0] : db.strides()[1] | |
); | |
return db; | |
} | |
} | |
namespace linalg | |
{ | |
template <class E1> | |
typename E1::value_type norm(const xexpression<E1>& a, int ord = 2) | |
{ | |
if (ord == 1) | |
{ | |
return blas::asum(a); | |
} | |
else if (ord == 2) | |
{ | |
return blas::nrm2(a); | |
} | |
else { | |
std::cout << "Norm " << ord << " not implemented!" << std::endl; | |
} | |
return -1; | |
} | |
template <class E1, class E2> | |
auto solve(const xexpression<E1>& a, const xexpression<E2>& b) | |
{ | |
return blas::gesv(a, b); | |
} | |
template <class T, class O> | |
typename xt::blas::detail::select_xtype<T, O>::type | |
dot(const T& t, const O& o) { | |
if (t.dimension() == 1 && o.dimension() == 1) | |
{ | |
return xt::blas::dot(t, o); | |
} | |
else | |
{ | |
if (t.dimension() == 2 && o.dimension() == 1) | |
{ | |
return xt::blas::gemv(t, o); | |
} | |
else if (t.dimension() == 1 && o.dimension() == 2) | |
{ | |
return xt::blas::gemv(o, t, true); | |
} | |
else if (t.dimension() == 2 && o.dimension() == 2) | |
{ | |
return xt::blas::gemm(o, t); | |
} | |
} | |
throw std::exception(); | |
} | |
} | |
template <class E1, class E2> | |
auto cross(const xexpression<E1>& a, const xexpression<E2>& b) | |
{ | |
using return_type = xtensor<typename E1::value_type, 1>; | |
return_type res(typename return_type::shape_type{3}); | |
const E1& da = a.derived_cast(); | |
const E2& db = b.derived_cast(); | |
if (da.size() == 3 && db.size() == 3) | |
{ | |
res(0) = da(1) * db(2) - da(2) * db(1); | |
res(1) = da(2) * db(0) - da(0) * db(2); | |
res(2) = da(0) * db(1) - da(1) * db(0); | |
} | |
else if (da.size() == 2 && db.size() == 3) | |
{ | |
res(0) = da(1) * db(2); | |
res(1) = -(da(0) * db(2)); | |
res(2) = da(0) * db(1) - da(1) * db(0); | |
} | |
else if (da.size() == 3 && db.size() == 2) | |
{ | |
res(0) = -(da(2) * db(1)); | |
res(1) = da(2) * db(0); | |
res(2) = da(0) * db(1) - da(1) * db(0); | |
} | |
else if (da.size() == 2 && db.size() == 2) | |
{ | |
res(0) = 0; | |
res(1) = 0; | |
res(2) = da(0) * db(1) - da(1) * db(0); | |
} | |
else | |
{ | |
throw std::exception(); | |
} | |
return res; | |
} | |
} | |
#endif |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment