Last active
August 29, 2015 14:24
-
-
Save JonathanRaiman/937bc144880c39f4e2dd to your computer and use it in GitHub Desktop.
Mshadow using Eigen's dot product engine
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef DALI_MATH_MSHADOW_EIGEN_DOT_H | |
#define DALI_MATH_MSHADOW_EIGEN_DOT_H | |
#if MSHADOW_USE_EIGEN_DOT | |
#include <Eigen/Eigen> | |
#include <mshadow/tensor.h> | |
#include <cblas.h> | |
// Eigen backend for the dot-product (BLAS) engine in mshadow.
// Note: enabling this backend causes Adagrad to be slower.
namespace mshadow { | |
namespace expr { | |
template<> | |
struct BLASEngine<cpu> { | |
typedef Eigen::Map<Eigen::Matrix<float,Eigen::Dynamic,Eigen::Dynamic, Eigen::RowMajor> > eigen_float_mat_t; | |
typedef Eigen::Map<Eigen::Matrix<double,Eigen::Dynamic,Eigen::Dynamic, Eigen::RowMajor> > eigen_double_mat_t; | |
typedef Eigen::Map<Eigen::VectorXf > eigen_float_vector_t; | |
typedef Eigen::Map<Eigen::VectorXd > eigen_double_vector_t; | |
inline static void SetStream(Stream<cpu> *stream) { | |
} | |
inline static void gemm(bool trans_rhs, bool trans_lhs, | |
int m, int n, int k, float alpha, | |
const float *rhs_ptr, int ldb, const float *lhs_ptr, int lda, | |
float beta, float *C, int ldc) { | |
// collect sizes from arguments. | |
// Thanks blas! | |
const int rhs_size1 = trans_rhs ? k : m; | |
const int rhs_size0 = trans_rhs ? m : k; | |
const int lhs_size1 = trans_lhs ? n : k; | |
const int lhs_size0 = trans_lhs ? k : n; | |
const eigen_float_mat_t lhs( | |
const_cast<float*>(lhs_ptr), lhs_size0, lhs_size1 | |
); | |
const eigen_float_mat_t rhs( | |
const_cast<float*>(rhs_ptr), rhs_size0, rhs_size1 | |
); | |
eigen_float_mat_t dst( | |
C, | |
trans_lhs ? lhs_size1 : lhs_size0, | |
trans_rhs ? rhs_size0 : rhs_size1 | |
); | |
if (trans_lhs && trans_rhs) { | |
if (beta != 0.0f) { | |
dst.noalias() = alpha * lhs.transpose() * rhs.transpose() + dst * beta; | |
} else { | |
dst.noalias() = alpha * lhs.transpose() * rhs.transpose(); | |
} | |
} else if (trans_lhs && !trans_rhs) { | |
if (beta != 0.0f) { | |
dst.noalias() = alpha * lhs.transpose() * rhs + dst * beta; | |
} else { | |
dst.noalias() = alpha * lhs.transpose() * rhs; | |
} | |
} else if (!trans_lhs && trans_rhs) { | |
if (beta != 0.0f) { | |
dst.noalias() = alpha * lhs * rhs.transpose() + dst * beta; | |
} else { | |
dst.noalias() = alpha * lhs * rhs.transpose(); | |
} | |
} else if (!trans_lhs && !trans_rhs) { | |
if (beta != 0.0f) { | |
dst.noalias() = alpha * lhs * rhs + dst * beta; | |
} else { | |
dst.noalias() = alpha * lhs * rhs; | |
} | |
} | |
} | |
inline static CBLAS_TRANSPOSE GetT(bool t) { | |
return t ? CblasTrans : CblasNoTrans; | |
} | |
inline static void gemm(bool trans_rhs, bool trans_lhs, | |
int m, int n, int k, double alpha, | |
const double *rhs_ptr, int ldb, const double *lhs_ptr, int lda, | |
double beta, double *C, int ldc) { | |
// collect sizes from arguments. | |
// Thanks blas! | |
const int rhs_size1 = trans_rhs ? k : m; | |
const int rhs_size0 = trans_rhs ? m : k; | |
const int lhs_size1 = trans_lhs ? n : k; | |
const int lhs_size0 = trans_lhs ? k : n; | |
const eigen_double_mat_t lhs( | |
const_cast<double*>(lhs_ptr), lhs_size0, lhs_size1 | |
); | |
const eigen_double_mat_t rhs( | |
const_cast<double*>(rhs_ptr), rhs_size0, rhs_size1 | |
); | |
eigen_double_mat_t dst( | |
C, | |
trans_lhs ? lhs_size1 : lhs_size0, | |
trans_rhs ? rhs_size0 : rhs_size1 | |
); | |
if (trans_lhs && trans_rhs) { | |
if (beta != 0.0f) { | |
dst.noalias() = alpha * lhs.transpose() * rhs.transpose() + dst * beta; | |
} else { | |
dst.noalias() = alpha * lhs.transpose() * rhs.transpose(); | |
} | |
} else if (trans_lhs && !trans_rhs) { | |
if (beta != 0.0f) { | |
dst.noalias() = alpha * lhs.transpose() * rhs + dst * beta; | |
} else { | |
dst.noalias() = alpha * lhs.transpose() * rhs; | |
} | |
} else if (!trans_lhs && trans_rhs) { | |
if (beta != 0.0f) { | |
dst.noalias() = alpha * lhs * rhs.transpose() + dst * beta; | |
} else { | |
dst.noalias() = alpha * lhs * rhs.transpose(); | |
} | |
} else if (!trans_lhs && !trans_rhs) { | |
if (beta != 0.0f) { | |
dst.noalias() = alpha * lhs * rhs + dst * beta; | |
} else { | |
dst.noalias() = alpha * lhs * rhs; | |
} | |
} | |
} | |
inline static void gemv(bool trans_rhs, int rhs_size1, int rhs_size0, | |
float alpha, | |
const float *rhs_ptr, int rhs_stride, | |
const float *lhs_ptr, int lhs_stride, | |
float beta, | |
float *dst_ptr, int dst_stride) { | |
const eigen_float_mat_t rhs( | |
const_cast<float*>(rhs_ptr), rhs_size0, rhs_size1 | |
); | |
const eigen_float_vector_t lhs( | |
const_cast<float*>(X), n | |
); | |
eigen_float_vector_t Y_eigen( | |
Y, m | |
); | |
if (trans) { | |
if (beta != 0.0f) { | |
Y_eigen.noalias() = alpha * A_eigen.transpose() * X_eigen + Y_eigen * beta; | |
} else { | |
Y_eigen.noalias() = alpha * A_eigen.transpose() * X_eigen; | |
} | |
} else { | |
if (beta != 0.0f) { | |
Y_eigen.noalias() = alpha * A_eigen * X_eigen + Y_eigen * beta; | |
} else { | |
Y_eigen.noalias() = alpha * A_eigen * X_eigen; | |
} | |
} | |
} | |
inline static void gemv(bool trans, int n, int m, double alpha, | |
const double *X, int lda, | |
const double *A, int incX, | |
double beta, double *Y, int incY) { | |
const eigen_double_mat_t A_eigen( | |
const_cast<double*>(A), trans ? n : m, trans ? m : n | |
); | |
const eigen_double_vector_t X_eigen( | |
const_cast<double*>(X), n | |
); | |
eigen_double_vector_t Y_eigen( | |
Y, m | |
); | |
if (trans) { | |
if (beta != 0.0f) { | |
Y_eigen = alpha * A_eigen.transpose() * X_eigen + Y_eigen * beta; | |
} else { | |
Y_eigen.noalias() = alpha * A_eigen.transpose() * X_eigen; | |
} | |
} else { | |
if (beta != 0.0f) { | |
Y_eigen = alpha * A_eigen * X_eigen + Y_eigen * beta; | |
} else { | |
Y_eigen.noalias() = alpha * A_eigen * X_eigen; | |
} | |
} | |
} | |
// outer product | |
inline static void ger(int n, int m, float alpha, | |
const float *Y, int incY, | |
const float *X, int incX, float *A, int lda) { | |
eigen_float_mat_t A_eigen( | |
A, m, n | |
); | |
const eigen_float_vector_t X_eigen( | |
const_cast<float*>(X), m | |
); | |
const eigen_float_vector_t Y_eigen( | |
const_cast<float*>(Y), n | |
); | |
A_eigen.noalias() += alpha * X_eigen * Y_eigen.transpose(); | |
} | |
inline static void ger(int n, int m, double alpha, | |
const double *Y, int incX, | |
const double *X, int incY, double *A, int lda) { | |
eigen_double_mat_t A_eigen( | |
A, m, n | |
); | |
const eigen_double_vector_t X_eigen( | |
const_cast<double*>(X), m | |
); | |
const eigen_double_vector_t Y_eigen( | |
const_cast<double*>(Y), n | |
); | |
A_eigen.noalias() += alpha * X_eigen * Y_eigen.transpose(); | |
} | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif | |
#endif | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment