Skip to content

Instantly share code, notes, and snippets.

@JonathanRaiman
Last active August 29, 2015 14:24
Show Gist options
  • Save JonathanRaiman/937bc144880c39f4e2dd to your computer and use it in GitHub Desktop.
Save JonathanRaiman/937bc144880c39f4e2dd to your computer and use it in GitHub Desktop.
Mshadow using Eigen's dot product engine
#ifndef DALI_MATH_MSHADOW_EIGEN_DOT_H
#define DALI_MATH_MSHADOW_EIGEN_DOT_H
#if MSHADOW_USE_EIGEN_DOT
#include <Eigen/Eigen>
#include <mshadow/tensor.h>
#include <cblas.h>
// Eigen backend for dot products in mshadow.
// Note: using this backend causes Adagrad to be slower.
namespace mshadow {
namespace expr {
template<>
struct BLASEngine<cpu> {
typedef Eigen::Map<Eigen::Matrix<float,Eigen::Dynamic,Eigen::Dynamic, Eigen::RowMajor> > eigen_float_mat_t;
typedef Eigen::Map<Eigen::Matrix<double,Eigen::Dynamic,Eigen::Dynamic, Eigen::RowMajor> > eigen_double_mat_t;
typedef Eigen::Map<Eigen::VectorXf > eigen_float_vector_t;
typedef Eigen::Map<Eigen::VectorXd > eigen_double_vector_t;
inline static void SetStream(Stream<cpu> *stream) {
}
inline static void gemm(bool trans_rhs, bool trans_lhs,
int m, int n, int k, float alpha,
const float *rhs_ptr, int ldb, const float *lhs_ptr, int lda,
float beta, float *C, int ldc) {
// collect sizes from arguments.
// Thanks blas!
const int rhs_size1 = trans_rhs ? k : m;
const int rhs_size0 = trans_rhs ? m : k;
const int lhs_size1 = trans_lhs ? n : k;
const int lhs_size0 = trans_lhs ? k : n;
const eigen_float_mat_t lhs(
const_cast<float*>(lhs_ptr), lhs_size0, lhs_size1
);
const eigen_float_mat_t rhs(
const_cast<float*>(rhs_ptr), rhs_size0, rhs_size1
);
eigen_float_mat_t dst(
C,
trans_lhs ? lhs_size1 : lhs_size0,
trans_rhs ? rhs_size0 : rhs_size1
);
if (trans_lhs && trans_rhs) {
if (beta != 0.0f) {
dst.noalias() = alpha * lhs.transpose() * rhs.transpose() + dst * beta;
} else {
dst.noalias() = alpha * lhs.transpose() * rhs.transpose();
}
} else if (trans_lhs && !trans_rhs) {
if (beta != 0.0f) {
dst.noalias() = alpha * lhs.transpose() * rhs + dst * beta;
} else {
dst.noalias() = alpha * lhs.transpose() * rhs;
}
} else if (!trans_lhs && trans_rhs) {
if (beta != 0.0f) {
dst.noalias() = alpha * lhs * rhs.transpose() + dst * beta;
} else {
dst.noalias() = alpha * lhs * rhs.transpose();
}
} else if (!trans_lhs && !trans_rhs) {
if (beta != 0.0f) {
dst.noalias() = alpha * lhs * rhs + dst * beta;
} else {
dst.noalias() = alpha * lhs * rhs;
}
}
}
inline static CBLAS_TRANSPOSE GetT(bool t) {
return t ? CblasTrans : CblasNoTrans;
}
inline static void gemm(bool trans_rhs, bool trans_lhs,
int m, int n, int k, double alpha,
const double *rhs_ptr, int ldb, const double *lhs_ptr, int lda,
double beta, double *C, int ldc) {
// collect sizes from arguments.
// Thanks blas!
const int rhs_size1 = trans_rhs ? k : m;
const int rhs_size0 = trans_rhs ? m : k;
const int lhs_size1 = trans_lhs ? n : k;
const int lhs_size0 = trans_lhs ? k : n;
const eigen_double_mat_t lhs(
const_cast<double*>(lhs_ptr), lhs_size0, lhs_size1
);
const eigen_double_mat_t rhs(
const_cast<double*>(rhs_ptr), rhs_size0, rhs_size1
);
eigen_double_mat_t dst(
C,
trans_lhs ? lhs_size1 : lhs_size0,
trans_rhs ? rhs_size0 : rhs_size1
);
if (trans_lhs && trans_rhs) {
if (beta != 0.0f) {
dst.noalias() = alpha * lhs.transpose() * rhs.transpose() + dst * beta;
} else {
dst.noalias() = alpha * lhs.transpose() * rhs.transpose();
}
} else if (trans_lhs && !trans_rhs) {
if (beta != 0.0f) {
dst.noalias() = alpha * lhs.transpose() * rhs + dst * beta;
} else {
dst.noalias() = alpha * lhs.transpose() * rhs;
}
} else if (!trans_lhs && trans_rhs) {
if (beta != 0.0f) {
dst.noalias() = alpha * lhs * rhs.transpose() + dst * beta;
} else {
dst.noalias() = alpha * lhs * rhs.transpose();
}
} else if (!trans_lhs && !trans_rhs) {
if (beta != 0.0f) {
dst.noalias() = alpha * lhs * rhs + dst * beta;
} else {
dst.noalias() = alpha * lhs * rhs;
}
}
}
inline static void gemv(bool trans_rhs, int rhs_size1, int rhs_size0,
float alpha,
const float *rhs_ptr, int rhs_stride,
const float *lhs_ptr, int lhs_stride,
float beta,
float *dst_ptr, int dst_stride) {
const eigen_float_mat_t rhs(
const_cast<float*>(rhs_ptr), rhs_size0, rhs_size1
);
const eigen_float_vector_t lhs(
const_cast<float*>(X), n
);
eigen_float_vector_t Y_eigen(
Y, m
);
if (trans) {
if (beta != 0.0f) {
Y_eigen.noalias() = alpha * A_eigen.transpose() * X_eigen + Y_eigen * beta;
} else {
Y_eigen.noalias() = alpha * A_eigen.transpose() * X_eigen;
}
} else {
if (beta != 0.0f) {
Y_eigen.noalias() = alpha * A_eigen * X_eigen + Y_eigen * beta;
} else {
Y_eigen.noalias() = alpha * A_eigen * X_eigen;
}
}
}
inline static void gemv(bool trans, int n, int m, double alpha,
const double *X, int lda,
const double *A, int incX,
double beta, double *Y, int incY) {
const eigen_double_mat_t A_eigen(
const_cast<double*>(A), trans ? n : m, trans ? m : n
);
const eigen_double_vector_t X_eigen(
const_cast<double*>(X), n
);
eigen_double_vector_t Y_eigen(
Y, m
);
if (trans) {
if (beta != 0.0f) {
Y_eigen = alpha * A_eigen.transpose() * X_eigen + Y_eigen * beta;
} else {
Y_eigen.noalias() = alpha * A_eigen.transpose() * X_eigen;
}
} else {
if (beta != 0.0f) {
Y_eigen = alpha * A_eigen * X_eigen + Y_eigen * beta;
} else {
Y_eigen.noalias() = alpha * A_eigen * X_eigen;
}
}
}
// outer product
inline static void ger(int n, int m, float alpha,
const float *Y, int incY,
const float *X, int incX, float *A, int lda) {
eigen_float_mat_t A_eigen(
A, m, n
);
const eigen_float_vector_t X_eigen(
const_cast<float*>(X), m
);
const eigen_float_vector_t Y_eigen(
const_cast<float*>(Y), n
);
A_eigen.noalias() += alpha * X_eigen * Y_eigen.transpose();
}
inline static void ger(int n, int m, double alpha,
const double *Y, int incX,
const double *X, int incY, double *A, int lda) {
eigen_double_mat_t A_eigen(
A, m, n
);
const eigen_double_vector_t X_eigen(
const_cast<double*>(X), m
);
const eigen_double_vector_t Y_eigen(
const_cast<double*>(Y), n
);
A_eigen.noalias() += alpha * X_eigen * Y_eigen.transpose();
}
};
} // namespace expr
} // namespace mshadow
#endif
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment