Created
July 10, 2011 10:25
-
-
Save t-abe/1074446 to your computer and use it in GitHub Desktop.
BLASとLAPACKのラッパー
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* Thin C++ template wrappers around BLAS and LAPACK (MKL) routines. */
#pragma once
#include <algorithm>
#include <cassert>
#include <cstring>
#include <stdexcept>
#include <vector>

#include "blas.hpp"
// MAX(a, b) / MIN(a, b): larger / smaller of two values.
// NOTE: function-like macros -- each argument may be evaluated twice, so do not
// pass expressions with side effects (e.g. MAX(i++, j)).
#ifndef MAX
#define MAX(a, b) (((a) > (b)) ? (a) : (b))
#endif
#ifndef MIN
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
#endif
// ZERO(mat, n): set the first n elements of mat to zero by copying a single zero
// value with source stride 0 into mat (destination stride 1) via blas::copy.
// The expansion declares a local of type `T`, so it only compiles where a type
// named T is in scope (e.g. inside a template with parameter T).
// The local is deliberately not named like `_ZERO`: identifiers starting with an
// underscore followed by an uppercase letter are reserved for the implementation.
// The expansion is a braced block without a trailing ';', so `ZERO(x, n);` is a
// single statement.
#ifndef ZERO
#define ZERO(mat, n) \
    { T wbl_zero_value_ = 0; blas::copy((n), &wbl_zero_value_, 0, (mat), 1); }
#endif
namespace wbl | |
{ | |
// X := x or x' (depends on transx) | |
// Y := y or y' (depends on transy) | |
// z <- alpha * XY + beta * z | |
template <class T> | |
void gemm(const T alpha, | |
const bool transx, const T* x, const int xrows, const int xcols, | |
const bool transy, const T* y, const int yrows, const int ycols, | |
const T beta, T* z | |
) | |
{ | |
int m, n, k; | |
int lda, ldb, ldc; | |
m = !transx ? xrows : xcols; | |
n = !transy ? ycols : yrows; | |
k = !transx ? xcols : xrows; // = y.rows : y.cols | |
lda = xrows; | |
ldb = yrows; | |
ldc = m; | |
blas::gemm( | |
transx, transy, | |
m, n, k, | |
alpha, x, lda, | |
y, ldb, | |
beta, z, ldc); | |
}; | |
// tr_mat <- mat' | |
template <class T> | |
void transpose(const T* mat, const int rows, const int cols, T* tr_mat) | |
{ | |
for(int i=0; i < cols; i++) | |
blas::copy(rows, &mat[rows * i], 1, &tr_mat[i], cols); | |
}; | |
// tmat <- mat^-1 | |
template <class T> | |
void inverse(const T* mat, const int rows, const int cols, T* tmat) | |
{ | |
assert(rows == cols); | |
blas::copy(rows * cols, mat, 1, tmat, 1); | |
// LU decomposition | |
int info; | |
int *ipiv = new int[rows]; | |
blas::getrf(rows, rows, tmat, rows, ipiv, &info); | |
// inverse | |
int lwork = rows * 64; | |
T* work = new T[lwork]; | |
blas::getri(rows, tmat, rows, ipiv, | |
work, lwork, &info); | |
delete[] ipiv; | |
delete[] work; | |
}; | |
// u(r, c), d(c, c), vt(c, c) | |
template <class T> | |
void svd(const T* mat, const int rows, const int cols, T* u, T* d, T* vt) | |
{ | |
int info, i; | |
blas::copy(rows * cols, mat, 1, u, 1); | |
ZERO(d, cols * cols); | |
ZERO(vt, cols * cols); | |
int lwork = MAX( 3 * cols + rows, | |
5 * cols ); | |
T* work = new T[lwork]; | |
T* vd = new T[cols]; | |
ZERO(vd, cols); | |
if( rows < cols ){ | |
blas::gesvd('O', 'S', rows, cols, u, rows, | |
vd, NULL, 1, vt, cols, | |
work, lwork, &info); | |
for(i=0; i < cols - rows; i++) | |
memset(&u[rows * (rows + i)], 0, sizeof(T) * rows); | |
} else { | |
blas::gesvd('O', 'S', rows, cols, u, rows, | |
vd, NULL, 1, vt, cols, | |
work, lwork, &info); | |
} | |
for(i=0; i < cols; i++) | |
d[i * cols + i] = vd[i]; | |
delete[] vd; | |
delete[] work; | |
} | |
template <class T> | |
void pinv(const T* mat, const int rows, const int cols, T* inv, int rank = -1) | |
{ | |
assert(rows >= cols); | |
if( rank == -1 ) rank = cols; | |
T *U, *D, *Vt; | |
U = new T[rows * cols]; | |
D = new T[cols * cols]; | |
Vt = new T[cols * cols]; | |
wbl::svd(mat, rows, cols, U, D, Vt); | |
int i; | |
for( i=0; i < rank; i++ ) | |
D[i * cols + i] = 1/D[i * cols + i]; | |
for( i=rank; i < cols; i++ ) | |
D[i * cols + i] = 0; | |
T *t; | |
t = new T[cols * cols]; | |
gemm<T>(1, true, Vt, cols, cols, false, D, cols, cols, 0, t); | |
gemm<T>(1, false, t, cols, cols, true, U, rows, cols, 0, inv); | |
delete[] U; | |
delete[] D; | |
delete[] Vt; | |
delete[] t; | |
} | |
template <class T> | |
void pinv(const T* mat, const int rows, const int cols, T* inv, const T thr) | |
{ | |
assert(rows >= cols); | |
T *U, *D, *Vt; | |
U = new T[rows * cols]; | |
D = new T[cols * cols]; | |
Vt = new T[cols * cols]; | |
wbl::svd(mat, rows, cols, U, D, Vt); | |
for( int i=0; i < cols; i++ ){ | |
D[i * cols + i] = D[i * cols + i] > thr ? (1/D[i * cols + i]) : 0; | |
} | |
T *t; | |
t = new T[cols * cols]; | |
gemm<T>(1, true, Vt, cols, cols, false, D, cols, cols, 0, t); | |
gemm<T>(1, false, t, cols, cols, true, U, rows, cols, 0, inv); | |
delete[] U; | |
delete[] D; | |
delete[] Vt; | |
delete[] t; | |
} | |
//double det(const double*, const int rows, const int cols); | |
//double trace(const double*, const int rows, const int cols); | |
//int rank(const double*, const int rows, const int cols, const double); | |
//void eig(const double*, const int rows, const int cols, double*, double*); | |
}; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment