Skip to content

Instantly share code, notes, and snippets.

@t-abe
Created July 10, 2011 10:26
Show Gist options
  • Save t-abe/1074448 to your computer and use it in GitHub Desktop.
Save t-abe/1074448 to your computer and use it in GitHub Desktop.
BLASのラッパー
#pragma once
/* using MKL10 */
#pragma comment(lib, "mkl_intel_lp64.lib")
#pragma comment(lib, "mkl_intel_thread.lib")
#pragma comment(lib, "mkl_core.lib")
#pragma comment(lib, "libiomp5md.lib")
#include <mkl.h>
#include <cmath>
/*
#include <clapack.h>
#include <cblas.h>
*/
/*
 * Type-overloaded C++ wrappers around the CBLAS / LAPACK routines provided by
 * Intel MKL, plus a few helper routines (the "original" section) built on top.
 *
 * Vector routines use the BLAS stride convention: `n` is the NUMBER OF
 * ELEMENTS and `incx` / `incy` is the distance between consecutive elements,
 * i.e. the elements are x[0], x[incx], ..., x[(n-1)*incx] for incx > 0.
 * gemm and the LAPACK wrappers operate on column-major matrices.
 */
namespace blas {

/* ---- BLAS level 1 ---- */

/** Index of the element with the largest absolute value (cblas_isamax).
 *  NOTE(review): the CBLAS interface returns a 0-based index -- confirm for
 *  the MKL version in use. */
static int iamax(const int n, const float* x, const int incx){
    return static_cast<int>(cblas_isamax(n, x, incx));
}
/** Double-precision overload of iamax (cblas_idamax). */
static int iamax(const int n, const double* x, const int incx){
    return static_cast<int>(cblas_idamax(n, x, incx));
}

/** Sum of ABSOLUTE values, sum(|x_i|) (BLAS ?asum). This is not the signed
 *  sum; use sigma() for that. */
static float sum(const int n, const float* x, const int incx){
    return cblas_sasum(n, x, incx);
}
/** Double-precision overload of sum (cblas_dasum). */
static double sum(const int n, const double* x, const int incx){
    return cblas_dasum(n, x, incx);
}

/** Dot product sum(x_i * y_i) of two strided vectors (cblas_sdot). */
static float dot(const int n, const float* x, const int incx, const float* y, const int incy){
    return cblas_sdot(n, x, incx, y, incy);
}
/** Double-precision overload of dot (cblas_ddot). */
static double dot(const int n, const double* x, const int incx, const double* y, const int incy){
    return cblas_ddot(n, x, incx, y, incy);
}

/** y := alpha * x + y, in place on y (cblas_saxpy). */
static void axpy(const int n, const float alpha, const float* x, const int incx, float* y, const int incy){
    cblas_saxpy(n, alpha, x, incx, y, incy);
}
/** Double-precision overload of axpy (cblas_daxpy). */
static void axpy(const int n, const double alpha, const double* x, const int incx, double* y, const int incy){
    cblas_daxpy(n, alpha, x, incx, y, incy);
}

/* ---- BLAS level 3 ---- */

/**
 * General matrix multiply on column-major storage (cblas_sgemm):
 *   z := alpha * op(x) * op(y) + beta * z
 * where op(x) is x^T when transx is true and x otherwise (likewise for y).
 * op(x) is m-by-k, op(y) is k-by-n, z is m-by-n; lda / ldb / ldc are the
 * leading dimensions of x, y and z as stored.
 */
static void gemm(const bool transx, const bool transy, const int m, const int n, const int k,
                 const float alpha, const float* x, const int lda,
                 const float* y, const int ldb,
                 const float beta, float* z, const int ldc)
{
    cblas_sgemm(CblasColMajor,
                transx ? CblasTrans : CblasNoTrans, transy ? CblasTrans : CblasNoTrans,
                m, n, k,
                alpha, x, lda, y, ldb, beta, z, ldc);
}
/** Double-precision overload of gemm (cblas_dgemm); same contract. */
static void gemm(const bool transx, const bool transy, const int m, const int n, const int k,
                 const double alpha,
                 const double* x, const int lda,
                 const double* y, const int ldb,
                 const double beta, double* z, const int ldc)
{
    cblas_dgemm(CblasColMajor,
                transx ? CblasTrans : CblasNoTrans, transy ? CblasTrans : CblasNoTrans,
                m, n, k,
                alpha, x, lda, y, ldb, beta, z, ldc);
}

/* ---- BLAS level 1 (continued) ---- */

/** x := alpha * x, in place (cblas_sscal). */
static void scal(const int n, const float alpha, float* x, const int incx){
    cblas_sscal(n, alpha, x, incx);
}
/** Double-precision overload of scal (cblas_dscal). */
static void scal(const int n, const double alpha, double* x, const int incx){
    cblas_dscal(n, alpha, x, incx);
}

/** y := x, element-wise strided copy (cblas_scopy). */
static void copy(const int n, const float* x, const int incx, float* y, const int incy){
    cblas_scopy(n, x, incx, y, incy);
}
/** Double-precision overload of copy (cblas_dcopy). */
static void copy(const int n, const double* x, const int incx, double* y, const int incy){
    cblas_dcopy(n, x, incx, y, incy);
}

/** Euclidean norm sqrt(sum(x_i^2)) (cblas_snrm2). */
static float nrm2(const int n, const float* x, const int incx){
    return cblas_snrm2(n, x, incx);
}
/** Double-precision overload of nrm2 (cblas_dnrm2). */
static double nrm2(const int n, const double* x, const int incx){
    return cblas_dnrm2(n, x, incx);
}

/* ---- LAPACK ----
 * The Fortran-style entry points take every argument by pointer, so the
 * by-value parameters below are forwarded by address. `info` follows the
 * LAPACK convention: 0 on success, < 0 if argument -info was illegal,
 * > 0 for a routine-specific numerical failure. */

/** In-place LU factorization with partial pivoting of the m-by-n matrix a.
 *  ipiv receives min(m, n) pivot indices (1-based, Fortran convention). */
static void getrf(int m, int n, float* a, int lda, int* ipiv, int* info){
    sgetrf(&m, &n, a, &lda, ipiv, info);
}
/** Double-precision overload of getrf (dgetrf). */
static void getrf(int m, int n, double* a, int lda, int* ipiv, int* info){
    dgetrf(&m, &n, a, &lda, ipiv, info);
}

/** In-place inverse of an n-by-n matrix from its getrf factorization
 *  (a, ipiv). work holds lwork elements; lwork == -1 is a workspace-size
 *  query whose answer is returned in work[0]. */
static void getri(int n, float* a, int lda, int* ipiv, float* work, int lwork, int* info){
    sgetri(&n, a, &lda, ipiv, work, &lwork, info);
}
/** Double-precision overload of getri (dgetri). */
static void getri(int n, double* a, int lda, int* ipiv, double* work, int lwork, int* info){
    dgetri(&n, a, &lda, ipiv, work, &lwork, info);
}

/** Singular value decomposition a = U * diag(s) * VT of the m-by-n matrix a
 *  (a is overwritten). jobu / jobvt select how much of U / VT is computed
 *  ('A', 'S', 'O' or 'N'); lwork == -1 is a workspace-size query. */
static void gesvd(
    char jobu, char jobvt, int m, int n, float* a, int lda,
    float* s, float* u, int ldu, float* vt, int ldvt,
    float* work, int lwork, int* info
){
    sgesvd(&jobu, &jobvt, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, info);
}
/** Double-precision overload of gesvd (dgesvd). */
static void gesvd(
    char jobu, char jobvt, int m, int n, double* a, int lda,
    double* s, double* u, int ldu, double* vt, int ldvt,
    double* work, int lwork, int* info
){
    dgesvd(&jobu, &jobvt, &m, &n, a, &lda,
           s, u, &ldu, vt, &ldvt, work, &lwork, info);
}

/** Eigenvalues (ascending, into w) and, when jobz == 'V', eigenvectors
 *  (overwriting a) of the symmetric n-by-n matrix a; uplo ('U' / 'L')
 *  selects which triangle is read. lwork == -1 is a workspace query. */
static void syev(char jobz, char uplo, int n, float* a, int lda, float* w, float* work, int lwork, int* info){
    ssyev(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info);
}
/** Double-precision overload of syev (dsyev). */
static void syev(char jobz, char uplo, int n, double* a, int lda, double* w, double* work, int lwork, int* info){
    dsyev(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info);
}

/* ---- original helpers ----
 * These follow the same BLAS stride convention as the wrappers above:
 * n counts elements. A negative stride walks the vector backwards starting
 * at offset (1 - n) * inc, as reference BLAS does; a zero stride reuses
 * element 0 (previously incx == 0 looped forever). */

/** Signed sum of the n strided elements of x (0 when n <= 0). */
static float sigma(const int n, const float* x, const int incx){
    float s = 0;
    int ix = (incx < 0) ? (1 - n) * incx : 0;
    for(int k = 0; k < n; ++k, ix += incx)
        s += x[ix];
    return s;
}
/** Double-precision overload of sigma. */
static double sigma(const int n, const double* x, const int incx){
    double s = 0;
    int ix = (incx < 0) ? (1 - n) * incx : 0;
    for(int k = 0; k < n; ++k, ix += incx)
        s += x[ix];
    return s;
}

/** Scale x in place to unit Euclidean norm. A zero vector (or n <= 0) is
 *  left untouched; scaling it by 1/0 would turn every element into NaN. */
static void normalize(const int n, float* x, const int incx){
    const float norm = nrm2(n, x, incx);
    if(norm > 0)
        scal(n, 1.0f / norm, x, incx);
}
/** Double-precision overload of normalize. */
static void normalize(const int n, double* x, const int incx){
    const double norm = nrm2(n, x, incx);
    if(norm > 0)
        scal(n, 1.0 / norm, x, incx);
}

/** True when every pair of corresponding strided elements of x and y
 *  differs by at most e. A NaN difference counts as "not equal".
 *  std::fabs is used deliberately: an unqualified abs() can resolve to
 *  int abs(int) and truncate the difference, ignoring e. */
static bool equal(const int n, const float* x, const int incx, const float* y, const int incy, const float e){
    int ix = (incx < 0) ? (1 - n) * incx : 0;
    int iy = (incy < 0) ? (1 - n) * incy : 0;
    for(int k = 0; k < n; ++k, ix += incx, iy += incy)
        if(!(std::fabs(x[ix] - y[iy]) <= e)) return false;
    return true;
}
/** Double-precision overload of equal; same contract. */
static bool equal(const int n, const double* x, const int incx, const double* y, const int incy, const double e){
    int ix = (incx < 0) ? (1 - n) * incx : 0;
    int iy = (incy < 0) ? (1 - n) * incy : 0;
    for(int k = 0; k < n; ++k, ix += incx, iy += incy)
        if(!(std::fabs(x[ix] - y[iy]) <= e)) return false;
    return true;
}

}  // namespace blas
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment