Skip to content

Instantly share code, notes, and snippets.

@wolfv
Last active March 30, 2017 00:42
Show Gist options
  • Save wolfv/6b310a38927867ea4a8a451eb9d6f80e to your computer and use it in GitHub Desktop.
BLAS and LAPACK functions for xtensor
/***************************************************************************
* Copyright (c) 2016, Johan Mabille and Sylvain Corlay *
* *
* Distributed under the terms of the BSD 3-Clause License. *
* *
* The full license is in the file LICENSE, distributed with this software. *
****************************************************************************/
#ifndef XBLAS_HPP
#define XBLAS_HPP

#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <functional>
#include <iostream>
#include <type_traits>
#include <vector>

#include "xtensor/xeval.hpp"

#ifndef USE_CXXLAPACK
#define USE_CXXLAPACK
#endif
#include <flens/flens.cxx>
namespace xt
{
namespace blas
{
/**
 * Calculate the dot product between two vectors
 * @param a vector of n elements
 * @param b vector of n elements
 * @returns scalar result
 */
template <class E1, class E2>
xtensor<typename E1::value_type, 1> dot(const xexpression<E1>& a, const xexpression<E2>& b)
{
    // NOTE(review): res({0}) forwards {0} to the xtensor constructor -- presumably
    // a one-element holder for the scalar result; confirm against xtensor's
    // constructor overloads (shape vs. value initializer).
    xtensor<typename E1::value_type, 1> res({0});
    auto&& lhs = xt::eval(a.derived_cast());
    auto&& rhs = xt::eval(b.derived_cast());
    // cxxblas::dot writes its scalar result into the last (output) argument.
    cxxblas::dot(lhs.size(),
                 lhs.raw_data() + lhs.raw_data_offset(), lhs.strides()[0],
                 rhs.raw_data() + rhs.raw_data_offset(), rhs.strides()[0],
                 res(0));
    return res;
}
/**
* Calculate the 1-norm of a vector
* @param a vector of n elements
* @returns scalar result
*/
template <class E1>
typename E1::value_type asum(const xexpression<E1>& a)
{
typename E1::value_type res;
auto&& ad = xt::eval(a.derived_cast());
cxxblas::asum(ad.size(),
ad.raw_data() + ad.raw_data_offset(), ad.strides()[0],
res);
return res;
}
/**
* Calculate the 2-norm of a vector
* @param a vector of n elements
* @returns scalar result
*/
template <class E1>
typename E1::value_type nrm2(const xexpression<E1>& a)
{
typename E1::value_type res;
auto&& ad = xt::eval(a.derived_cast());
cxxblas::nrm2(ad.size(),
ad.raw_data() + ad.raw_data_offset(), ad.strides()[0],
res);
return res;
}
// Implementation helpers: map a pair of xtensor expression types to the
// concrete container type that should hold a BLAS result.
namespace detail
{
// Overload-resolution probe: selected for any xt::xfunction specialization.
template<class xF, class xR, class fE, class... xE>
std::true_type is_xfunction_impl(const xfunction<xF, xR, fE, xE...>&);
// Fallback probe for every non-xfunction argument.
std::false_type is_xfunction_impl( ... );
// True when T is an xt::xfunction; evaluated via the unevaluated decltype
// of the probes above, so neither probe needs a definition.
template<typename T>
constexpr bool is_xfunction(T&& t) {
return decltype(is_xfunction_impl(t))::value;
}
// Primary template: if either operand's shape is dynamic, the result is a
// dynamically-shaped xarray.
template <class A, class B, class T>
struct get_type_impl {
using type = xarray<T>;
};
// Both shapes are static (std::array): the result is a fixed-rank xtensor.
// NOTE(review): the result rank M is taken from the SECOND operand only --
// confirm this is the intended convention (e.g. in gemv the result rank
// follows the rank-1 rhs vector).
template <class A, std::size_t N, class B, std::size_t M, class T>
struct get_type_impl<std::array<A, N>, std::array<B, M>, T> {
using type = xtensor<T, M>;
};
// Public trait: result container for combining expressions A and B, with
// value_type = std::common_type of the operands' value types.
template <class A, class B>
struct select_xtype
{
using type = typename std::remove_const_t<
typename get_type_impl<typename A::shape_type, typename B::shape_type,
typename std::common_type_t<typename A::value_type,
typename B::value_type>>::type>;
};
}
/**
 * Matrix-vector product: y = alpha * op(A) * x  (BLAS xGEMV with beta = 0).
 * @param A         m-by-n matrix expression
 * @param x         vector expression of n elements (m when transpose == true)
 * @param transpose when true, computes y = alpha * A^T * x
 * @param alpha     scalar multiplier applied to the product
 * @returns vector of m elements (n elements when transpose == true)
 */
template <class E1, class E2>
xt::xarray<typename E1::value_type> gemv(const xexpression<E1>& A, const xexpression<E2>& x,
bool transpose = false,
const xscalar<typename E1::value_type> alpha = 1)
{
    using value_type = typename E1::value_type;
    auto&& dA = xt::eval(A.derived_cast());
    auto&& dx = xt::eval(x.derived_cast());
    using result_type = typename detail::select_xtype<E1, E2>::type;
    // BUGFIX: the result length is determined by A, not by x. The previous
    // code copied x's shape and std::reverse()d it for the transposed case,
    // which is a no-op on a rank-1 shape and yielded a wrong size whenever
    // A is non-square.
    typename result_type::shape_type result_shape =
        { transpose ? dA.shape()[1] : dA.shape()[0] };
    result_type res(result_shape);
    cxxblas::gemv(
        cxxblas::StorageOrder::RowMajor,
        transpose ? cxxblas::Transpose::Trans : cxxblas::Transpose::NoTrans,
        dA.shape()[0], dA.shape()[1],
        alpha(),                        // alpha
        dA.raw_data(), dA.strides()[0], // A and its leading dimension (row stride)
        dx.raw_data(), dx.strides()[0], // x and its increment
        value_type(0),                  // beta -- was the float literal 0.f even for double inputs
        res.data().data(), 1ul          // y, contiguous with unit increment
    );
    return res;
}
// Pattern-matches the lazy expression "scalar * matrix" (an xfunction whose
// functor is std::multiplies and whose left operand is an xscalar) and folds
// the scalar into gemv's alpha argument instead of materializing the product.
template <class xF, class xR, class fE, class... xE, class E2>
std::enable_if_t<std::is_same<fE, xt::xscalar<xR>>::value && std::is_same<xF, std::multiplies<xR>>::value &&
!detail::is_xfunction<E2>(), xt::xarray<xR>>
gemv(const xexpression<xfunction<xF, xR, fE, xE...>>& A, const xexpression<E2>& x, bool transpose = false)
{
    auto alpha = A.derived_cast().left();
    // NOTE: derived_cast() is const here, so the previous std::move silently
    // degraded to a copy (performance-move-const-arg); spell the copy out.
    xt::xarray<xR> dA = A.derived_cast().right();
    return gemv(dA, x, transpose, alpha);
}
// When the right-hand side is itself a lazy xfunction, materialize it first
// so the strided-data gemv overload can operate on concrete storage.
template <class aF, class aR, class... aE, class xF, class xR, class... xE>
xt::xarray<xR>
gemv(const xexpression<xfunction<aF, aR, aE...>>& A, const xexpression<xfunction<xF, xR, xE...>>& x, bool transpose = false)
{
    xt::xarray<xR> evaluated_rhs = x;
    return gemv(A, evaluated_rhs, transpose);
}
/**
 * Calculate the matrix-matrix product of matrix @A and matrix @B
 * @param A matrix of m-by-n elements
 * @param B matrix of n-by-k elements
 * @returns matrix of m-by-k elements
 */
template <class E1, class E2>
xarray<typename E1::value_type> gemm(const xexpression<E1>& A, const xexpression<E2>& B)
{
    using value_type = typename E1::value_type;
    auto&& da = xt::eval(A.derived_cast());
    auto&& db = xt::eval(B.derived_cast());
    using return_type = typename blas::detail::select_xtype<E1, E2>::type;
    // Result is rows(A) x cols(B).
    typename return_type::shape_type s = {da.shape()[0], db.shape()[1]};
    return_type res(s);
    // BLAS xGEMM convention: C(M, N) = alpha * A(M, K) * B(K, N) + beta * C,
    // i.e. M = rows(A), N = cols(B), K = cols(A).
    // BUGFIX: the previous call passed (rows(A), cols(A), rows(B)), which is
    // only correct when all matrices are square.
    cxxblas::gemm(
        cxxblas::StorageOrder::RowMajor,
        cxxblas::Transpose::NoTrans, cxxblas::Transpose::NoTrans,
        da.shape()[0], db.shape()[1], da.shape()[1],
        value_type(1),                  // alpha (was the float literal 1.f)
        da.raw_data(), da.strides()[0],
        db.raw_data(), db.strides()[0],
        value_type(0),                  // beta (was the float literal 0.f)
        res.raw_data(), res.strides()[0]
    );
    return res;
}
/**
 * Solve the linear system A * X = B via LAPACK's xGESV (LU factorization
 * with partial pivoting).
 * @param A square coefficient matrix (n-by-n)
 * @param B right-hand side(s), n-by-nrhs
 * @returns the solution X (same shape as B)
 */
template <class E1, class E2>
xarray<typename E1::value_type> gesv(const xexpression<E1>& A, const xexpression<E2>& B)
{
    // LAPACK expects column-major storage, so copy both operands into
    // column-major containers first.
    const E1& a = A.derived_cast();
    xt::xarray<typename E1::value_type> da(a.shape(), layout::column_major);
    assign_data(da, a, true);
    const E2& b = B.derived_cast();
    xt::xarray<typename E1::value_type> db(b.shape(), layout::column_major);
    assign_data(db, b, true);
    assert(da.dimension() == 2);
    // Pivot indices filled in by LAPACK, one per row of A.
    // (Removed an unused result_type/shape local from the original.)
    std::vector<int> piv(da.shape()[0]);
    // NOTE(review): for a column-major n-by-n matrix, strides()[1] should be
    // the leading dimension n; the !db.strides()[1] branch appears to guard a
    // degenerate single-column rhs whose second stride is 0 -- confirm against
    // xtensor's stride computation.
    cxxlapack::gesv<int>(
        da.shape()[0], db.shape()[1],
        da.raw_data(), da.strides()[1],
        piv.data(),
        db.raw_data(), !db.strides()[1] ? da.shape()[0] : db.strides()[1]
    );
    // xGESV overwrites the right-hand side with the solution in place.
    return db;
}
}
namespace linalg
{
/**
 * Compute a vector norm.
 * @param a   vector expression
 * @param ord order of the norm: 1 (sum of absolute values, BLAS asum) or
 *            2 (Euclidean norm, BLAS nrm2)
 * @returns the norm, or -1 when @p ord is unsupported
 */
template <class E1>
typename E1::value_type norm(const xexpression<E1>& a, int ord = 2)
{
    if (ord == 1)
    {
        return blas::asum(a);
    }
    if (ord == 2)
    {
        return blas::nrm2(a);
    }
    // BUGFIX: diagnostics go to stderr, not stdout (was std::cout).
    // TODO(review): returning -1 is a silent sentinel; consider throwing
    // std::invalid_argument once callers are audited.
    std::cerr << "Norm " << ord << " not implemented!" << std::endl;
    return -1;
}
/**
 * Solve the linear system a * x = b.
 * Thin convenience forwarder to the LAPACK-backed blas::gesv.
 */
template <class E1, class E2>
auto solve(const xexpression<E1>& a, const xexpression<E2>& b)
{
    return blas::gesv(a, b);
}
/**
 * numpy.dot-style product dispatcher, selected by operand rank:
 * - vector . vector -> blas::dot
 * - matrix . vector -> blas::gemv
 * - vector . matrix -> blas::gemv on the transposed matrix (M^T t == t^T M)
 * - matrix . matrix -> blas::gemm
 * @throws std::exception for unsupported rank combinations
 */
template <class T, class O>
typename xt::blas::detail::select_xtype<T, O>::type
dot(const T& t, const O& o) {
    if (t.dimension() == 1 && o.dimension() == 1)
    {
        return xt::blas::dot(t, o);
    }
    if (t.dimension() == 2 && o.dimension() == 1)
    {
        return xt::blas::gemv(t, o);
    }
    if (t.dimension() == 1 && o.dimension() == 2)
    {
        return xt::blas::gemv(o, t, true);
    }
    if (t.dimension() == 2 && o.dimension() == 2)
    {
        // BUGFIX: operands were swapped (gemm(o, t) computed o . t
        // instead of t . o).
        return xt::blas::gemm(t, o);
    }
    throw std::exception();
}
}
template <class E1, class E2>
auto cross(const xexpression<E1>& a, const xexpression<E2>& b)
{
using return_type = xtensor<typename E1::value_type, 1>;
return_type res(typename return_type::shape_type{3});
const E1& da = a.derived_cast();
const E2& db = b.derived_cast();
if (da.size() == 3 && db.size() == 3)
{
res(0) = da(1) * db(2) - da(2) * db(1);
res(1) = da(2) * db(0) - da(0) * db(2);
res(2) = da(0) * db(1) - da(1) * db(0);
}
else if (da.size() == 2 && db.size() == 3)
{
res(0) = da(1) * db(2);
res(1) = -(da(0) * db(2));
res(2) = da(0) * db(1) - da(1) * db(0);
}
else if (da.size() == 3 && db.size() == 2)
{
res(0) = -(da(2) * db(1));
res(1) = da(2) * db(0);
res(2) = da(0) * db(1) - da(1) * db(0);
}
else if (da.size() == 2 && db.size() == 2)
{
res(0) = 0;
res(1) = 0;
res(2) = da(0) * db(1) - da(1) * db(0);
}
else
{
throw std::exception();
}
return res;
}
}
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment