Skip to content

Instantly share code, notes, and snippets.

@t-abe
Created July 10, 2011 10:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save t-abe/1074450 to your computer and use it in GitHub Desktop.
Save t-abe/1074450 to your computer and use it in GitHub Desktop.
Partial Least Squares
/********
// Multi-class
const int N; // number of samples
const int D; // dimension of feature
const int M; // number of classes
const int P; // number of reduced dimensions
double* X = new double[D * N]; // features
double* Y = new double[M * N]; // labels
double* W = new double[D * P]; // projection matrix
double* C = new double[M * P]; // Y projection matrix
compute_weight_multiclass<double>(X, Y, N, D, M, P, W, C);
// projected features can be computed by W'X ("'" means "transpose")
// Y <- CW'X
// Single-class
const int N; // number of samples
const int D; // dimension of feature
const int P; // number of reduced dimensions
double* X = new double[D * N]; // features
double* y = new double[1 * N]; // labels
double* W = new double[D * P]; // projection matrix
compute_weight<double>(X, y, N, D, P, W);
********/
#pragma once
#include <wbl.hpp>
#include <blas.hpp>
#include <cstddef>
#include <cstdio>
#include <vector>
// NIPALS convergence criterion: the inner iteration of
// compute_weight_multiclass stops once the norm of the change in the weight
// vector between two consecutive iterations drops below this value.
#define PLS_EPS 1e-6
// When false, compute_weight_multiclass prints that per-iteration norm to
// stdout; set to true to silence it.
#define PLS_QUIET false
// Element (r, c) of a column-major matrix `mat` that has `rows` rows.
#define MAT_VAL(mat, r, c, rows) \
(mat)[(r) + (c) * (rows)]
// Sets the first `n` elements of `mat` to zero by copying one zero value with
// a source stride of 0 (presumably blas::copy follows the usual BLAS stride
// semantics -- confirm against blas.hpp). Requires a type named `T` to be in
// scope where the macro is expanded. Only defined here if no ZERO macro
// exists yet. Note that the expansion ends in "};", i.e. a block followed by
// an empty statement.
#ifndef ZERO
#define ZERO(mat, n) \
{ T _ZERO = 0; blas::copy(n, &_ZERO, 0, mat, 1); };
#endif
// Prints `msg` together with the current file and line to stderr, then
// terminates the program with exit status 1.
#define CROAK(msg) \
{ fprintf(stderr, "error: %s\n\tFile: %s, Line %d\n", msg, __FILE__,__LINE__); exit(1); };
// Prints `msg` together with the current file and line to stderr as a
// warning; execution continues.
#define CARP(msg) \
{ fprintf(stderr, "warn: %s\n\tFile: %s, Line %d\n", msg, __FILE__,__LINE__); };
namespace pls
{
// Forward declaration. The calls to centerize() below are unqualified calls
// inside a template; they are resolved against the declarations visible at
// this point (argument-dependent lookup finds nothing for built-in pointer
// arguments), so centerize must be declared before it is used.
template <class T>
void centerize(T* E, const int rows, const int cols);

/**
 * Computes a PLS projection for multi-column responses with a NIPALS-style
 * iteration.
 *
 * All matrices are dense and column-major (one sample per column):
 *   X : d x n  features, read only
 *   Y : m x n  labels, read only
 *   W : d x p  output; column q receives the q-th weight vector
 *   C : m x p  optional output; column q receives c = F t of the q-th
 *              component. Pass NULL (the default) to skip it.
 *
 * X and Y are copied and centered internally, so the caller's buffers are
 * never modified. If any of n, d, m or p is not positive the function returns
 * immediately without writing to W or C.
 *
 * Unless PLS_QUIET is true, the norm of the weight change is printed to
 * stdout on every inner iteration.
 *
 * The wbl::gemm calls are read as
 *   gemm(alpha, transA, A, rowsA, colsA, transB, B, rowsB, colsB, beta, out)
 * (inferred from the call sites -- confirm against wbl.hpp).
 */
template <class T>
void compute_weight_multiclass(
const T* X, const T* Y,
const int n, /* number of samples */
const int d, /* number of dimensions of a feature vector */
const int m, /* number of classes */
const int p, /* number of reduced dimensions */
T* W, T* C = NULL
)
{
    // Nothing to compute for an empty problem. This also guarantees that every
    // buffer below is non-empty, so taking &buf[0] is valid.
    if( n <= 0 || d <= 0 || m <= 0 || p <= 0 )
        return;

    const int dn = d * n;
    const int mn = m * n;

    // Working copies (residual matrices) owned by std::vector, so they are
    // released on every exit path.
    std::vector<T> E(X, X + dn);
    std::vector<T> F(Y, Y + mn);

    /* centerize E & F */
    centerize(&E[0], d, n);
    centerize(&F[0], m, n);

    std::vector<T> w(d);      // current weight vector
    std::vector<T> w_new(d);  // candidate weight vector / difference to w
    std::vector<T> t(n);      // X scores
    std::vector<T> c(m);      // Y weights
    std::vector<T> u(n);      // Y scores
    std::vector<T> v(d);      // E t
    std::vector<T> vt(dn);    // (E t) t'
    std::vector<T> ct(mn);    // (F t) t'

    for( int q = 0; q < p; q++ ){
        // Start from the first row of F: stride m steps from sample to sample.
        blas::copy(n, &F[0], m, &u[0], 1);
        if( blas::nrm2(n, &u[0], 1) == 0 ){
            // An all-zero start vector would make E u vanish; perturb it.
            u[0]++; u[n-1]--;
        }

        w.assign(d, T(0));

        // NOTE(review): this loop has no iteration cap; it only ends once the
        // weight change drops below PLS_EPS.
        while(1){
            // (1) w_new = E u, normalize
            wbl::gemm<T>(1, false, &E[0], d, n, false, &u[0], n, 1, 0, &w_new[0]);
            blas::normalize(d, &w_new[0], 1);

            /* convergence check: w_new <- w_new - w */
            blas::axpy(d, -1, &w[0], 1, &w_new[0], 1);
            const double delta = blas::nrm2(d, &w_new[0], 1);
            if( !PLS_QUIET )
                printf("%g\n", delta);
            if( delta < PLS_EPS )
                break;
            // w <- w + (w_new - w), i.e. accept the candidate
            blas::axpy(d, 1, &w_new[0], 1, &w[0], 1);

            // (2) t = E'w, normalize t
            wbl::gemm<T>(1, true, &E[0], d, n, false, &w[0], d, 1, 0, &t[0]);
            blas::normalize(n, &t[0], 1);

            // (3) c = Ft, normalize c
            wbl::gemm<T>(1, false, &F[0], m, n, false, &t[0], n, 1, 0, &c[0]);
            blas::normalize(m, &c[0], 1);

            // (4) u = F'c
            wbl::gemm<T>(1, true, &F[0], m, n, false, &c[0], m, 1, 0, &u[0]);
        }

        // deflate E, E <- E - (Et)t'
        wbl::gemm<T>(1, false, &E[0], d, n, false, &t[0], n, 1, 0, &v[0]);
        wbl::gemm<T>(1, false, &v[0], d, 1, true, &t[0], n, 1, 0, &vt[0]);
        blas::axpy(dn, -1, &vt[0], 1, &E[0], 1);

        // deflate F, F <- F - (Ft=c)t'
        wbl::gemm<T>(1, false, &F[0], m, n, false, &t[0], n, 1, 0, &c[0]);
        wbl::gemm<T>(1, false, &c[0], m, 1, true, &t[0], n, 1, 0, &ct[0]);
        blas::axpy(mn, -1, &ct[0], 1, &F[0], 1);

        // copy w -> column q of W
        blas::copy(d, &w[0], 1, W + q * d, 1);

        // copy c -> column q of C (only when the caller asked for it)
        if( C != NULL )
            blas::copy(m, &c[0], 1, C + q * m, 1);
    }
}
template <class T>
void compute_weight(
const T* X, const T* y,
const int n, /* number of samples */
const int d, /* number of dimensions of a feature vector */
const int p, /* number of reduced dimensions */
T* W
)
{
T* E = new T[d * n];
T* f = new T[1 * n];
blas::copy(d * n, X, 1, E, 1);
blas::copy(1 * n, y, 1, f, 1);
/* centerize E & f */
centerize(E, d, n);
centerize(f, 1, n);
T* w = new T[d * 1];
T* t = new T[n * 1];
T* v = new T[d * 1];
T* vt = new T[d * n];
for( int q=0; q < p; q++ ){
// (1) w = Ef', normalize w
wbl::gemm<float>(1, false, E, d, n, true, f, 1, n, 0, w);
blas::normalize(d, w, 1);
// (2) t = E'w, normalize t
wbl::gemm<float>(1, true, E, d, n, false, w, d, 1, 0, t);
blas::normalize(n, t, 1);
// (3) deflate E, E <- E - (Et)t'
wbl::gemm<float>(1, false, E, d, n, false, t, n, 1, 0, v);
wbl::gemm<float>(1, false, v, d, 1, true, t, n, 1, 0, vt);
blas::axpy(d * n, -1, vt, 1, E, 1);
// copy w -> W
blas::copy(d, w, 1, &MAT_VAL(W, 0, q, d), 1);
}
delete[] v;
delete[] vt;
delete[] w;
delete[] t;
delete[] E;
delete[] f;
}
/**
 * Centers a column-major matrix in place.
 *
 * E is treated as a rows x cols matrix whose column c starts at E + rows * c.
 * The mean of all columns is computed and subtracted from every column, so
 * that afterwards each row of E sums to zero.
 *
 * Does nothing when rows or cols is not positive: there is nothing to center,
 * and the guard avoids both a division by zero and taking the address of the
 * first element of an empty vector.
 */
template <class T>
void centerize(T* E, const int rows, const int cols)
{
    if( rows <= 0 || cols <= 0 )
        return;

    std::vector<T> mean(rows, T(0));
    T* const mean_ptr = &mean[0];

    // Accumulate the sum of all columns.
    for( int c = 0; c < cols; c++ )
        blas::axpy(rows, 1, E + rows * c, 1, mean_ptr, 1);

    // Turn the sum into the negated mean, so it can simply be added below.
    blas::scal(rows, -1 / static_cast<T>(cols), mean_ptr, 1);

    // Shift every column by the negated mean.
    for( int c = 0; c < cols; c++ )
        blas::axpy(rows, 1, mean_ptr, 1, E + rows * c, 1);
}
};
#undef ZERO
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment