Created
July 10, 2011 10:27
-
-
Save t-abe/1074450 to your computer and use it in GitHub Desktop.
Partial Least Squares
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/********
// Multi-class
const int N; // number of samples
const int D; // dimension of feature
const int M; // number of classes
const int P; // number of reduced dimensions
double* X = new double[D * N]; // features (D x N, one sample per column)
double* Y = new double[M * N]; // labels   (M x N, one sample per column)
double* W = new double[D * P]; // projection matrix
double* C = new double[M * P]; // Y projection matrix
pls::compute_weight_multiclass<double>(X, Y, N, D, M, P, W, C);
// projected features can be computed by W'X ("'" means "transpose")
// Y <- CW'X
// Single-class
const int N; // number of samples
const int D; // dimension of feature
const int P; // number of reduced dimensions
double* X = new double[D * N]; // features (D x N, one sample per column)
double* y = new double[1 * N]; // labels
double* W = new double[D * P]; // projection matrix
pls::compute_weight<double>(X, y, N, D, P, W);
********/
#pragma once
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <wbl.hpp>
#include <blas.hpp>
// NIPALS convergence criterion: the inner iteration of
// compute_weight_multiclass stops once the 2-norm of the change in the
// weight vector between two sweeps drops below this value.
#define PLS_EPS 1e-6
// When false, compute_weight_multiclass prints the per-sweep convergence
// delta to stdout.
#define PLS_QUIET false
// Element (r, c) of a column-major matrix `mat` that has `rows` rows.
// Expands to an lvalue, so `&MAT_VAL(...)` yields the element's address.
#define MAT_VAL(mat, r, c, rows) \
    (mat)[(r) + (c) * (rows)]
// Overwrite the first n elements of `mat` with zero by copying a single zero
// value with source stride 0.  Requires a type named T in the enclosing scope.
// NOTE(review): assumes blas::copy accepts a zero source increment — confirm
// against blas.hpp.
#ifndef ZERO
#define ZERO(mat, n) \
    { T pls_zero_value_ = 0; blas::copy(n, &pls_zero_value_, 0, mat, 1); };
#endif
// Print an error message with the source location to stderr and terminate the
// process with exit status 1.
#define CROAK(msg) \
    { fprintf(stderr, "error: %s\n\tFile: %s, Line %d\n", msg, __FILE__,__LINE__); exit(1); };
// Print a warning message with the source location to stderr and continue.
#define CARP(msg) \
    { fprintf(stderr, "warn: %s\n\tFile: %s, Line %d\n", msg, __FILE__,__LINE__); };
namespace pls | |
{ | |
template <class T> | |
void compute_weight_multiclass( | |
const T* X, const T* Y, | |
const int n, /* number of samples */ | |
const int d, /* number of dimensions of a feature vector */ | |
const int m, /* number of classes */ | |
const int p, /* number of reduced dimensions */ | |
T* W, T* C | |
) | |
{ | |
T* E = new T[d * n]; | |
T* F = new T[m * n]; | |
blas::copy(d * n, X, 1, E, 1); | |
blas::copy(m * n, Y, 1, F, 1); | |
/* centerize E & F */ | |
centerize(E, d, n); | |
centerize(F, m, n); | |
T* w = new T[d * 1]; | |
T* w_new = new T[d * 1]; | |
T* t = new T[n * 1]; | |
T* c = new T[m * 1]; | |
T* u = new T[n * 1]; | |
T* v = new T[d * 1]; | |
T* vt = new T[d * n]; | |
T* ct = new T[m * n]; | |
for( int q=0; q < p; q++ ){ | |
// |E|, |F| | |
//printf("|E| = %g\n", blas::sum(d*n, E, 1)); | |
//printf("|F| = %g\n", blas::sum(m*n, F, 1)); | |
//printf("%g\t%g\n", blas::sum(d*n, E, 1), blas::sum(m*n, F, 1)); | |
blas::copy(n, F, m, u, 1); | |
if( blas::nrm2(n, u, 1) == 0 ){ | |
u[0]++; u[n-1]--; | |
} | |
ZERO(w, d); | |
while(1){ | |
// (1) w = Eu, normalize w | |
wbl::gemm<T>(1, false, E, d, n, false, u, n, 1, 0, w_new); | |
blas::normalize(d, w_new, 1); | |
{ /* convergence check */ | |
blas::axpy(d, -1, w, 1, w_new, 1); | |
if( !PLS_QUIET ) | |
printf("%g\n", blas::nrm2(d, w_new, 1)); | |
if( blas::nrm2(d, w_new, 1) < PLS_EPS ) | |
break; | |
blas::axpy(d, 1, w_new, 1, w, 1); | |
} | |
// (2) t = E'w, normalize t | |
wbl::gemm<T>(1, true, E, d, n, false, w, d, 1, 0, t); | |
blas::normalize(n, t, 1); | |
// (3) c = Ft, normalize c | |
wbl::gemm<T>(1, false, F, m, n, false, t, n, 1, 0, c); | |
blas::normalize(m, c, 1); | |
// (4) u = F'c | |
wbl::gemm<T>(1, true, F, m, n, false, c, m, 1, 0, u); | |
} | |
// deflate E, E <- E - (Et)t' | |
wbl::gemm<T>(1, false, E, d, n, false, t, n, 1, 0, v); | |
wbl::gemm<T>(1, false, v, d, 1, true, t, n, 1, 0, vt); | |
blas::axpy(d * n, -1, vt, 1, E, 1); | |
// deflate F, F <- F - (Ft=c)t' | |
wbl::gemm<T>(1, false, F, m, n, false, t, n, 1, 0, c); | |
wbl::gemm<T>(1, false, c, m, 1, true, t, n, 1, 0, ct); | |
blas::axpy(m * n, -1, ct, 1, F, 1); | |
// copy w -> W | |
blas::copy(d, w, 1, &MAT_VAL(W, 0, q, d), 1); | |
// copy c -> C | |
if( C != NULL ) | |
blas::copy(m, c, 1, &MAT_VAL(C, 0, q, m), 1); | |
} | |
delete[] v; | |
delete[] vt; | |
delete[] ct; | |
delete[] w; | |
delete[] t; | |
delete[] c; | |
delete[] u; | |
delete[] E; | |
delete[] F; | |
} | |
template <class T> | |
void compute_weight( | |
const T* X, const T* y, | |
const int n, /* number of samples */ | |
const int d, /* number of dimensions of a feature vector */ | |
const int p, /* number of reduced dimensions */ | |
T* W | |
) | |
{ | |
T* E = new T[d * n]; | |
T* f = new T[1 * n]; | |
blas::copy(d * n, X, 1, E, 1); | |
blas::copy(1 * n, y, 1, f, 1); | |
/* centerize E & f */ | |
centerize(E, d, n); | |
centerize(f, 1, n); | |
T* w = new T[d * 1]; | |
T* t = new T[n * 1]; | |
T* v = new T[d * 1]; | |
T* vt = new T[d * n]; | |
for( int q=0; q < p; q++ ){ | |
// (1) w = Ef', normalize w | |
wbl::gemm<float>(1, false, E, d, n, true, f, 1, n, 0, w); | |
blas::normalize(d, w, 1); | |
// (2) t = E'w, normalize t | |
wbl::gemm<float>(1, true, E, d, n, false, w, d, 1, 0, t); | |
blas::normalize(n, t, 1); | |
// (3) deflate E, E <- E - (Et)t' | |
wbl::gemm<float>(1, false, E, d, n, false, t, n, 1, 0, v); | |
wbl::gemm<float>(1, false, v, d, 1, true, t, n, 1, 0, vt); | |
blas::axpy(d * n, -1, vt, 1, E, 1); | |
// copy w -> W | |
blas::copy(d, w, 1, &MAT_VAL(W, 0, q, d), 1); | |
} | |
delete[] v; | |
delete[] vt; | |
delete[] w; | |
delete[] t; | |
delete[] E; | |
delete[] f; | |
} | |
template <class T> | |
void centerize(T* E, const int rows, const int cols) | |
{ | |
std::vector<T> mE(rows, 0); | |
for( int c=0; c < cols; c++ ) | |
blas::axpy(rows, 1, &E[rows*c], 1, &mE[0], 1); | |
blas::scal(rows, -1/(T)cols, &mE[0], 1); | |
for( int c=0; c < cols; c++ ) | |
blas::axpy(rows, 1, &mE[0], 1, &E[rows*c], 1); | |
} | |
}; | |
#undef ZERO |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment