Created
July 10, 2011 10:26
-
-
Save t-abe/1074448 to your computer and use it in GitHub Desktop.
BLASのラッパー (BLAS wrapper)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once | |
/* using MKL10 */ | |
#pragma comment(lib, "mkl_intel_lp64.lib") | |
#pragma comment(lib, "mkl_intel_thread.lib") | |
#pragma comment(lib, "mkl_core.lib") | |
#pragma comment(lib, "libiomp5md.lib") | |
#include <mkl.h> | |
#include <cmath> | |
/* | |
#include <clapack.h> | |
#include <cblas.h> | |
*/ | |
namespace blas { | |
static int iamax(const int n, const float* x, const int incx){ | |
return (int)cblas_isamax(n, x, incx); | |
} | |
static int iamax(const int n, const double* x, const int incx){ | |
return (int)cblas_idamax(n, x, incx); | |
} | |
static float sum(const int n, const float* x, const int incx){ | |
return cblas_sasum(n, x, incx); | |
} | |
static double sum(const int n, const double* x, const int incx){ | |
return cblas_dasum(n, x, incx); | |
} | |
static float dot(const int n, const float* x, const int incx, const float* y, const int incy){ | |
return cblas_sdot(n, x, incx, y, incy); | |
} | |
static double dot(const int n, const double* x, const int incx, const double* y, const int incy){ | |
return cblas_ddot(n, x, incx, y, incy); | |
} | |
static void axpy(const int n, const float alpha, const float* x, const int incx, float* y, const int incy){ | |
cblas_saxpy(n, alpha, x, incx, y, incy); | |
} | |
static void axpy(const int n, const double alpha, const double* x, const int incx, double* y, const int incy){ | |
cblas_daxpy(n, alpha, x, incx, y, incy); | |
} | |
static void gemm(const bool transx, const bool transy, const int m, const int n, const int k, | |
const float alpha, const float* x, const int lda, | |
const float* y, const int ldb, | |
const float beta, float* z, const int ldc) | |
{ | |
cblas_sgemm(CblasColMajor, | |
transx ? CblasTrans : CblasNoTrans, transy ? CblasTrans : CblasNoTrans, | |
m, n, k, | |
alpha, x, lda, y, ldb, beta, z, ldc); | |
} | |
static void gemm(const bool transx, const bool transy, const int m, const int n, const int k, | |
const double alpha, | |
const double* x, const int lda, | |
const double* y, const int ldb, | |
const double beta, double* z, const int ldc) | |
{ | |
cblas_dgemm(CblasColMajor, | |
transx ? CblasTrans : CblasNoTrans, transy ? CblasTrans : CblasNoTrans, | |
m, n, k, | |
alpha, x, lda, y, ldb, beta, z, ldc); | |
} | |
static void scal(const int n, const float alpha, float* x, const int incx){ | |
cblas_sscal(n, alpha, x, incx); | |
} | |
static void scal(const int n, const double alpha, double* x, const int incx){ | |
cblas_dscal(n, alpha, x, incx); | |
} | |
static void copy(const int n, const float* x, const int incx, float* y, const int incy){ | |
cblas_scopy(n, x, incx, y, incy); | |
} | |
static void copy(const int n, const double* x, const int incx, double* y, const int incy){ | |
cblas_dcopy(n, x, incx, y, incy); | |
} | |
static float nrm2(const int n, const float* x, const int incx){ | |
return cblas_snrm2(n, x, incx); | |
} | |
static double nrm2(const int n, const double* x, const int incx){ | |
return cblas_dnrm2(n, x, incx); | |
} | |
/* lapack */ | |
static void getrf(int m, int n, float* a, int lda, int* ipiv, int* info){ | |
sgetrf(&m, &n, a, &lda, ipiv, info); | |
} | |
static void getrf(int m, int n, double* a, int lda, int* ipiv, int* info){ | |
dgetrf(&m, &n, a, &lda, ipiv, info); | |
} | |
static void getri(int n, float* a, int lda, int* ipiv, float* work, int lwork, int* info){ | |
sgetri(&n, a, &lda, ipiv, work, &lwork, info); | |
} | |
static void getri(int n, double* a, int lda, int* ipiv, double* work, int lwork, int* info){ | |
dgetri(&n, a, &lda, ipiv, work, &lwork, info); | |
} | |
static void gesvd( | |
char jobu, char jobvt, int m, int n, float* a, int lda, | |
float* s, float* u, int ldu, float* vt, int ldvt, | |
float* work, int lwork, int* info | |
){ | |
sgesvd(&jobu, &jobvt, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, info); | |
} | |
static void gesvd( | |
char jobu, char jobvt, int m, int n, double* a, int lda, | |
double* s, double* u, int ldu, double* vt, int ldvt, | |
double* work, int lwork, int* info | |
){ | |
dgesvd(&jobu, &jobvt, &m, &n, a, &lda, | |
s, u, &ldu, vt, &ldvt, work, &lwork, info); | |
} | |
static void syev(char jobz, char uplo, int n, float* a, int lda, float* w, float* work, int lwork, int* info){ | |
ssyev(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info); | |
} | |
static void syev(char jobz, char uplo, int n, double* a, int lda, double* w, double* work, int lwork, int* info){ | |
dsyev(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info); | |
} | |
/* original */ | |
static float sigma(const int n, const float* x, const int incx){ | |
float s = 0; | |
for(int i=0; i < n; i += incx) | |
s += x[i]; | |
return s; | |
} | |
static double sigma(const int n, const double* x, const int incx){ | |
double s = 0; | |
for(int i=0; i < n; i += incx) | |
s += x[i]; | |
return s; | |
} | |
static void normalize(const int n, float* x, const int incx){ | |
scal(n, 1/nrm2(n, x, incx), x, incx); | |
} | |
static void normalize(const int n, double* x, const int incx){ | |
scal(n, 1/nrm2(n, x, incx), x, incx); | |
} | |
static bool equal(const int n, const float* x, const int incx, const float* y, const int incy, const float e){ | |
for(int i=0, j=0; i < n; i += incx, j += incy) | |
if(abs(x[i] - y[i]) > e) return false; | |
return true; | |
} | |
static bool equal(const int n, const double* x, const int incx, const double* y, const int incy, const double e){ | |
for(int i=0, j=0; i < n; i += incx, j += incy) | |
if(abs(x[i] - y[i]) > e) return false; | |
return true; | |
} | |
}; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment