Skip to content

Instantly share code, notes, and snippets.

@toshi-k
Created September 6, 2015 11:48
Show Gist options
  • Save toshi-k/8257b58303177a5c6b21 to your computer and use it in GitHub Desktop.
Learned Iterative Shrinkage-Thresholding Algorithm (Rcpp)
#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]
using namespace Rcpp;
using namespace arma;
using namespace std;
// Element-wise soft-thresholding operator: sign(x) * max(|x| - theta, 0).
// Entries whose magnitude does not exceed theta are shrunk to exactly zero.
vec soft(vec x, double theta){
  vec shrunk = abs(x) - theta;       // signed shrink amount per element
  uvec active = uvec(shrunk > 0);    // 1 where |x| > theta, else 0
  vec signs = sign(x);               // restore the original sign (0 at x == 0)
  return (shrunk % active) % signs;
}
// Run the learned ISTA encoder for T unrolled iterations:
//   x <- soft(S x + We y, theta), starting from x = 0.
// y: input signal; S: mutual-inhibition matrix; We: encoding filter;
// theta: shrinkage threshold; T: number of iterations.
// Returns the final sparse code as an R numeric vector.
NumericVector lista(NumericVector y, NumericMatrix S, NumericMatrix We,
double theta, int T){
  vec signal = as<vec>(y);
  mat filt = as<mat>(We);
  mat mutual = as<mat>(S);
  vec drive = filt * signal;              // We * y is loop-invariant
  vec code = zeros<vec>(We.nrow());
  for(int t = 0; t < T; ++t){
    code = soft(mutual * code + drive, theta);
  }
  return wrap(code);
}
// Mean reconstruction objective over the training columns:
//   (1/N) * sum_i 0.5 * || Z(,i) - lista(Y(,i), ...) ||^2
// where N = Y.ncol() - i_start is the number of columns actually visited.
// Y: inputs (one signal per column); Z: target sparse codes;
// S, We, theta, T: current LISTA parameters (see lista()).
double calc_obj(NumericMatrix Y, NumericMatrix Z, NumericMatrix S, NumericMatrix We,
double theta, int T){
  double value = 0;
  NumericVector dif(Z.nrow());
  int i_start = 0;
  // Guard: no columns to evaluate -> objective is 0 (avoids divide-by-zero).
  if(Y.ncol() <= i_start) return 0.0;
  for(int i=i_start; i<Y.ncol(); i++){
    dif = Z(_,i) - lista(Y(_,i), S, We, theta, T);
    value += 0.5 * inner_product(dif.begin(), dif.end(), dif.begin(), 0.0);
  }
  // Average over the number of summed terms. The original divided by
  // (Y.ncol() - i_start + 1), one more than the number of columns in the
  // loop — an off-by-one that biased the reported objective low.
  value = value / ( Y.ncol() - i_start );
  return value;
}
// Diagonal 0/1 Jacobian of the soft-threshold at x: entry (j,j) is 1
// exactly when unit j is active after shrinkage (|x_j| > theta), else 0.
umat hdash(vec x, double theta){
  mat activation = diagmat(soft(x, theta));
  umat mask = umat(activation != 0);
  return mask;
}
// Train the LISTA parameters (We, S, theta) by stochastic gradient descent
// through the unrolled T-step encoder (backprop-through-time), following [1].
// Y: training inputs (one per column); Z: target sparse codes;
// Pf: list(We, S, theta) with the initial parameters — updated and returned;
// T: encoder iterations; itr_times: training epochs;
// k: base learning rate, decayed per epoch as k / sqrt(epoch).
// Prints the objective after each epoch.
// [[Rcpp::export]]
List lista_train(NumericMatrix Y, NumericMatrix Z, List Pf, int T,
int itr_times, double k){
NumericMatrix We = Pf[0];
NumericMatrix S = Pf[1];
double theta = Pf[2], k2;
mat Z_ = as<mat>(Z);
mat S_ = as<mat>(S);
mat We_ = as<mat>(We);
double obj_value = calc_obj(Y, Z, S, We, theta, T);
printf("Objective Value: %.5lf \n", obj_value);
for(int itr=1; itr<=itr_times; itr++){
k2 = k / sqrt(itr); // learning-rate decay over epochs
for(int i=0; i<Y.ncol(); i++){
NumericVector y = Y(_,i);
vec X = as<vec>(y);
mat Zs = mat(We_.n_rows, T + 1); // activations z_0 ... z_T
mat Cs = mat(We_.n_rows, T); // pre-activations c_1 ... c_T
// - - - - - fprop - - - - -
vec B = We_ * X;
Zs.col(0) = soft(B, theta);
for(int t=0; t<T; t++){
Cs.col(t) = B + S_ * Zs.col(t);
Zs.col(t+1) = soft(Cs.col(t), theta);
}
// - - - - - bprop - - - - -
vec deltaB = zeros<vec>(We_.n_rows);
mat deltaS = zeros<mat>(S_.n_rows, S_.n_cols);
vec deltaZ = Zs.col(T) - Z_.col(i); // d(0.5*||z_T - z*||^2)/dz_T
double deltatheta = 0;
for(int t=T-1; t>=0; t--){
// vec deltaC = hdash(Cs.col(t), theta) * deltaZ;
// soft'(c) = 1{|c| > theta}. The original omitted abs() here,
// which zeroed the gradient for negative active units; abs() makes
// this consistent with hdash() and with the B-gradient below.
vec deltaC = uvec( abs(Cs.col(t)) - theta > 0 ) % deltaZ;
deltatheta -= as_scalar( sign(Cs.col(t)).t() * deltaC );
deltaB += deltaC;
deltaS += deltaC * Zs.col(t).t();
deltaZ = S_.t() * deltaC; // propagate through S to the previous step
}
// deltaB += hdash(B, theta) * deltaZ;
uvec hdashvec = uvec(abs(B) - theta > 0); // soft'(B) for the input stage
deltaB += hdashvec % deltaZ;
// deltatheta -= as_scalar(sign(B).t() * hdash(B, theta) * deltaZ);
deltatheta -= as_scalar(sign(B).t() * (hdashvec % deltaZ));
mat deltaWe = deltaB * X.t();
// - - - - - parameter update - - - - -
We_ -= k2 * deltaWe;
S_ -= k2 * deltaS;
theta -= k2 / 10 * deltatheta; // smaller step for the scalar threshold
}
obj_value = calc_obj(Y, Z, wrap(S_), wrap(We_), theta, T);
printf("Objective Value: %.5lf \n", obj_value);
}
Pf[0] = wrap(We_);
Pf[1] = wrap(S_);
Pf[2] = theta;
return Pf;
}
/**
* <<References>>
* [1] Learning Fast Approximations of Sparse Coding
* Karol Gregor and Yann LeCun
* http://yann.lecun.com/exdb/publis/pdf/gregor-icml-10.pdf
*/
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment