Skip to content

Instantly share code, notes, and snippets.

@t-abe
Created July 10, 2011 10:25
Show Gist options
  • Save t-abe/1074446 to your computer and use it in GitHub Desktop.
Save t-abe/1074446 to your computer and use it in GitHub Desktop.
BLASとLAPACKのラッパー
/* Thin wrappers around BLAS and LAPACK (MKL) routines. */
#pragma once
#include "blas.hpp"
#include <cassert>
#include <cstring>
#include <vector>
// MAX(a, b): evaluates to the larger of a and b.
// Function-like macro: the selected argument is evaluated twice, so do not
// pass expressions with side effects. Left alone if MAX is already defined.
#ifndef MAX
#define MAX(a, b) (((a) > (b)) ? (a) : (b))
#endif
// MIN(a, b): evaluates to the smaller of a and b.
// Same double-evaluation caveat as MAX. Left alone if MIN is already defined.
#ifndef MIN
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
#endif
// ZERO(mat, n): fills the n elements starting at mat with zero.
//
// Expands to a block that creates one zero of type T and hands it to
// blas::copy with a source increment of 0 and a destination increment of 1
// (assuming blas::copy follows BLAS ?copy semantics, this replicates the
// single zero n times). A type named T must be in scope where the macro is
// used. Left alone if ZERO is already defined.
//
// The temporary deliberately avoids a leading-underscore-plus-capital name,
// which is reserved for the implementation. The expansion is a plain braced
// block, so it may be used with or without a trailing semicolon.
#ifndef ZERO
#define ZERO(mat, n) \
{ T wbl_zero_value_ = 0; blas::copy((n), &wbl_zero_value_, 0, (mat), 1); }
#endif
namespace wbl
{
// General matrix-matrix product, forwarded to blas::gemm.
//   op(x) := x  when transx is false, x' when it is true
//   op(y) := y  when transy is false, y' when it is true
//   z <- alpha * op(x) * op(y) + beta * z
//
// x is stored as xrows-by-xcols and y as yrows-by-ycols. The leading
// dimensions handed to blas::gemm are the stored row counts, i.e. the layout
// expected here is the column-major one used by Fortran BLAS.
// z must have room for (rows of op(x)) * (columns of op(y)) elements.
// The inner dimensions of op(x) and op(y) are not checked against each other.
template <class T>
void gemm(const T alpha,
          const bool transx, const T* x, const int xrows, const int xcols,
          const bool transy, const T* y, const int yrows, const int ycols,
          const T beta, T* z
)
{
    // Shape of the product: op(x) is out_rows-by-inner, op(y) is inner-by-out_cols.
    int out_rows = transx ? xcols : xrows;
    int out_cols = transy ? yrows : ycols;
    int inner = transx ? xrows : xcols;  // equals y.rows (or y.cols when transy)

    // Leading dimensions refer to the matrices as stored, so they do not
    // depend on the transpose flags.
    int lead_x = xrows;
    int lead_y = yrows;
    int lead_z = out_rows;

    blas::gemm(transx, transy,
               out_rows, out_cols, inner,
               alpha, x, lead_x,
               y, lead_y,
               beta, z, lead_z);
}
// tr_mat <- mat'
//
// mat is rows-by-cols and tr_mat receives the cols-by-rows transpose; both
// are addressed with the same leading-dimension-equals-row-count indexing.
// tr_mat must have room for rows * cols elements.
template <class T>
void transpose(const T* mat, const int rows, const int cols, T* tr_mat)
{
    // Each contiguous run of `rows` source elements (one column of mat) is
    // scattered with stride `cols`, which makes it one row of tr_mat.
    const T* src_col = mat;
    T* dst_row = tr_mat;
    for (int c = 0; c < cols; ++c) {
        blas::copy(rows, src_col, 1, dst_row, cols);
        src_col += rows;
        ++dst_row;
    }
}
// tmat <- mat^-1
//
// Inverts a square matrix: mat is copied into tmat, factorized in place with
// blas::getrf and then inverted in place with blas::getri (presumably thin
// wrappers over the LAPACK routines of the same names -- confirm in blas.hpp).
//
// mat  : input matrix, rows * cols elements; never written.
// rows : row count; must equal cols (checked with assert).
// cols : column count.
// tmat : output buffer with room for rows * cols elements.
//
// There is no error channel: if the factorization reports a non-zero status
// the inversion step is skipped and tmat does not hold an inverse.
// An empty matrix (rows <= 0) is a no-op.
template <class T>
void inverse(const T* mat, const int rows, const int cols, T* tmat)
{
    assert(rows == cols);
    if (rows <= 0)
        return;  // nothing to invert; also keeps the buffers below non-empty

    blas::copy(rows * cols, mat, 1, tmat, 1);

    // LU decomposition. The scratch buffers are vectors so they are released
    // even if a later allocation throws.
    int info = 0;
    std::vector<int> ipiv(rows);
    blas::getrf(rows, rows, tmat, rows, &ipiv[0], &info);
    if (info != 0)
        return;  // no usable factorization, so there is nothing for getri to invert

    // inverse; workspace sized as 64 elements per row, as before
    int lwork = rows * 64;
    std::vector<T> work(lwork);
    blas::getri(rows, tmat, rows, &ipiv[0],
                &work[0], lwork, &info);
}
// Singular value decomposition via blas::gesvd (presumably a wrapper over
// LAPACK ?gesvd -- confirm in blas.hpp), so that mat ~ u * d * vt.
//
// mat  : input, rows * cols elements; copied, never written.
// u    : output, rows * cols elements. Receives the copy of mat, which
//        gesvd then overwrites in place (job 'O'). When rows < cols, the
//        trailing cols - rows columns are cleared to zero afterwards.
// d    : output, cols * cols elements; zero everywhere except the diagonal,
//        which receives the values gesvd returns in its singular-value array.
// vt   : output, cols * cols elements; zeroed first, then filled by gesvd
//        (job 'S') with leading dimension cols.
//
// The status returned by gesvd is not reported to the caller.
// If rows <= 0 or cols <= 0, gesvd is not called; d and vt are still zeroed.
template <class T>
void svd(const T* mat, const int rows, const int cols, T* u, T* d, T* vt)
{
    blas::copy(rows * cols, mat, 1, u, 1);
    ZERO(d, cols * cols);
    ZERO(vt, cols * cols);
    if (rows <= 0 || cols <= 0)
        return;  // nothing to decompose; keeps the buffers below non-empty

    // For rows >= cols this is the documented ?gesvd minimum
    // max(3*min(m,n) + max(m,n), 5*min(m,n)); for rows < cols it is larger
    // than that minimum.
    int lwork = MAX( 3 * cols + rows,
                     5 * cols );

    // Vectors release the scratch memory on every exit path and start out
    // zero-filled, so entries of vd that gesvd does not write stay 0.
    std::vector<T> work(lwork);
    std::vector<T> vd(cols);

    int info = 0;
    blas::gesvd('O', 'S', rows, cols, u, rows,
                &vd[0], NULL, 1, vt, cols,
                &work[0], lwork, &info);

    if( rows < cols ){
        // Columns rows..cols-1 of u are contiguous (leading dimension is
        // rows), so they can be cleared with a single memset.
        std::memset(u + rows * rows, 0, sizeof(T) * rows * (cols - rows));
    }

    for(int i = 0; i < cols; i++)
        d[i * cols + i] = vd[i];
}
// Pseudo-inverse with an explicit rank: inv <- Vt' * D+ * U', where
// U, D, Vt come from wbl::svd(mat) and D+ inverts the first `rank` diagonal
// entries of D and zeroes the rest.
//
// mat  : input, rows * cols elements; rows >= cols is required (assert).
// inv  : output, cols * rows elements.
// rank : number of leading singular values to keep. The default -1 means
//        "all" (cols). Any value outside [0, cols] is treated the same way,
//        so the diagonal of D is never indexed out of range.
//
// A kept singular value that is exactly zero maps to zero rather than 1/0.
// An empty matrix (cols <= 0) is a no-op.
template <class T>
void pinv(const T* mat, const int rows, const int cols, T* inv, int rank = -1)
{
    assert(rows >= cols);
    if( cols <= 0 )
        return;  // nothing to compute; keeps the buffers below non-empty
    if( rank < 0 || rank > cols ) rank = cols;

    // Vectors instead of new[]/delete[] so nothing leaks if an allocation
    // (or anything called below) throws.
    std::vector<T> U(rows * cols);
    std::vector<T> D(cols * cols);
    std::vector<T> Vt(cols * cols);
    wbl::svd(mat, rows, cols, &U[0], &D[0], &Vt[0]);

    // Build D+ on the diagonal of D.
    for( int i = 0; i < cols; i++ ){
        T& sv = D[i * cols + i];
        sv = (i < rank && sv != 0) ? T(1) / sv : T(0);
    }

    // t <- Vt' * D+   (cols x cols), then inv <- t * U'   (cols x rows)
    std::vector<T> t(cols * cols);
    gemm<T>(1, true, &Vt[0], cols, cols, false, &D[0], cols, cols, 0, &t[0]);
    gemm<T>(1, false, &t[0], cols, cols, true, &U[0], rows, cols, 0, inv);
}
// Pseudo-inverse with a singular-value threshold: inv <- Vt' * D+ * U',
// where U, D, Vt come from wbl::svd(mat) and D+ inverts every diagonal entry
// of D that is greater than thr and zeroes the others.
//
// mat : input, rows * cols elements; rows >= cols is required (assert).
// inv : output, cols * rows elements.
// thr : cut-off; singular values <= thr are treated as zero.
//
// A singular value that is exactly zero is always mapped to zero, even when
// thr is negative, so no division by zero occurs.
// An empty matrix (cols <= 0) is a no-op.
template <class T>
void pinv(const T* mat, const int rows, const int cols, T* inv, const T thr)
{
    assert(rows >= cols);
    if( cols <= 0 )
        return;  // nothing to compute; keeps the buffers below non-empty

    // Vectors instead of new[]/delete[] so nothing leaks if an allocation
    // (or anything called below) throws.
    std::vector<T> U(rows * cols);
    std::vector<T> D(cols * cols);
    std::vector<T> Vt(cols * cols);
    wbl::svd(mat, rows, cols, &U[0], &D[0], &Vt[0]);

    // Build D+ on the diagonal of D.
    for( int i = 0; i < cols; i++ ){
        T& sv = D[i * cols + i];
        sv = (sv > thr && sv != 0) ? T(1) / sv : T(0);
    }

    // t <- Vt' * D+   (cols x cols), then inv <- t * U'   (cols x rows)
    std::vector<T> t(cols * cols);
    gemm<T>(1, true, &Vt[0], cols, cols, false, &D[0], cols, cols, 0, &t[0]);
    gemm<T>(1, false, &t[0], cols, cols, true, &U[0], rows, cols, 0, inv);
}
//double det(const double*, const int rows, const int cols);
//double trace(const double*, const int rows, const int cols);
//int rank(const double*, const int rows, const int cols, const double);
//void eig(const double*, const int rows, const int cols, double*, double*);
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment