Skip to content

Instantly share code, notes, and snippets.

@caub
Last active August 29, 2015 13:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save caub/8692363 to your computer and use it in GitHub Desktop.
Save caub/8692363 to your computer and use it in GitHub Desktop.
Kernel Recursive Least Squares (KRLS) regression using JSAT (https://code.google.com/p/java-statistical-analysis-tool/), tested on the Santa Fe laser data.

result

NMSE = 0.039
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.Arrays;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.distributions.kernels.RBFKernel;
import jsat.linear.DenseVector;
import jsat.parameters.DoubleParameter;
import jsat.parameters.GridSearch;
import jsat.parameters.Parameterized;
import jsat.regression.KernelRLS;
import jsat.regression.RegressionDataSet;
import jsat.regression.RegressionModelEvaluation;
import jsat.regression.Regressor;
import jsat.utils.SystemInfo;
import org.math.plot.Plot2DPanel; //https://code.google.com/p/jmathplot/downloads/list
import javax.swing.*;
/**
 * Time-series forecasting demo: fits a kernel recursive least squares (KRLS)
 * regressor on lagged windows of a series, refines it with its own multi-step
 * predictions, and then produces an iterated forecast that is plotted against
 * the held-out continuation of the series.
 *
 * <p>KRLS reference:
 * http://webee.technion.ac.il/people/rmeir/Publications/KrlsReport.pdf
 *
 * <p>The static fields {@link #krls} and {@link #ds} are only initialised by
 * {@link #main(String[])}; {@link #gridSearch()} and
 * {@link #getCrossValidationMeanError(int)} read them and therefore must not
 * be called before they have been set.
 */
public class Pred {

    /** Data file used when no path is given on the command line. */
    private static final String DEFAULT_DATA_FILE = "C:/Users/profit/A.dat";

    /** Thread pool handed to the JSAT training calls; shut down when {@code main} finishes. */
    static ExecutorService ex = Executors.newFixedThreadPool(SystemInfo.LogicalCores);

    /** Window size: number of lagged values that form one input vector. */
    static int k = 40;

    /** Number of iterations in learning (multi-step refinement passes). */
    static int n = 7;

    /** RBF kernel parameter. */
    static double sigma = 0.9;

    /** Error tolerance passed to the {@code KernelRLS} constructor. */
    static double errTol = 0.01;

    /** The regressor; created in {@code main}. */
    static KernelRLS krls;

    /** The training data set; created in {@code main}. */
    static RegressionDataSet ds;

    /**
     * Program entry point.
     *
     * @param args optional; {@code args[0]} is the path of the data file (one
     *             number per line). When absent, the original hard-coded
     *             location is used. The Santa Fe laser data is available from
     *             http://www-psych.stanford.edu/~andreas/Time-Series/SantaFe.html
     *             (A.cont appended to A.dat).
     * @throws IOException if the data file cannot be read
     * @throws InterruptedException declared for compatibility with the original signature
     */
    public static void main(String[] args) throws IOException, InterruptedException {
        String dataFile = (args != null && args.length > 0) ? args[0] : DEFAULT_DATA_FILE;
        try {
            run(dataFile);
        } finally {
            // The pool uses non-daemon threads; without this the JVM cannot exit.
            ex.shutdown();
        }
    }

    /**
     * Runs the full experiment: load, scale, train, refine, evaluate, forecast, plot.
     */
    private static void run(String dataFile) throws IOException, InterruptedException {
        int trainingStart = 0;
        int trainingEnd = 1000;
        int testEnd = 1100;

        // For a synthetic sanity check, ts[i] = Math.sin(i) can be substituted here.
        double[] ts = read(dataFile, testEnd);
        ts = scale(ts, 0, 256, 0, 1);

        double[] training = Arrays.copyOfRange(ts, trainingStart, trainingEnd);
        // Row i holds training[i .. i+k-1]; its target is the next value, training[k + i].
        double[][] data = makeLagMatrix(Arrays.copyOfRange(training, 0, training.length - 1), k);

        ds = new RegressionDataSet(k, new CategoricalData[0]);
        for (int i = 0; i < data.length; i++) {
            ds.addDataPoint(new DenseVector(data[i]), new int[0], training[k + i]);
        }

        krls = new KernelRLS(new RBFKernel(sigma), errTol);
        // Optional steps before training: gridSearch() to tune sigma / error tolerance,
        // or getCrossValidationMeanError(10) to inspect the cross-validation error.
        krls.train(ds, ex);

        refineWithOwnPredictions(data, training);

        // Fit on the original (non-predicted) windows.
        double[] rec = new double[training.length - k];
        for (int i = 0; i < rec.length; i++) {
            rec[i] = krls.regress(new DataPoint(new DenseVector(data[i]), new int[0], null));
        }
        double[] targ = Arrays.copyOfRange(training, k, training.length);
        System.out.println("training nmse: " + nmse(rec, targ));

        // Iterated forecast: each prediction is fed back as the newest window entry.
        double[] ideal = Arrays.copyOfRange(ts, trainingEnd, testEnd);
        double[] forecast = new double[ideal.length];
        double[] w = Arrays.copyOfRange(training, training.length - k, training.length);
        for (int i = 0; i < forecast.length; i++) {
            double r = krls.regress(new DataPoint(new DenseVector(w), new int[0], null));
            System.arraycopy(w, 1, w, 0, k - 1);
            w[k - 1] = r;
            // Only the reported forecast is clamped at zero; the window keeps the raw value.
            forecast[i] = r < 0 ? 0 : r;
        }

        plot(forecast, ideal);
        System.out.println("forecast nmse: " + nmse(forecast, ideal));
    }

    /**
     * Performs {@link #n} refinement passes. In pass {@code j} every window is
     * advanced by one step using the model's own prediction as the newest
     * entry, the advanced window is added to {@link #ds} with the true value
     * {@code training[k + i + j]} as target, and the model is retrained.
     * Each pass has one window fewer than the previous one so that the target
     * index stays inside {@code training}.
     */
    private static void refineWithOwnPredictions(double[][] data, double[] training) {
        double[][] current = data; // rows are only read, never modified in place
        for (int j = 1; j <= n; j++) {
            System.out.println("step " + j + "/" + n);
            double[][] next = new double[current.length - 1][k];
            for (int i = 0; i < next.length; i++) {
                double r = krls.regress(new DataPoint(new DenseVector(current[i]), new int[0], null));
                System.arraycopy(current[i], 1, next[i], 0, k - 1);
                next[i][k - 1] = r;
                ds.addDataPoint(new DenseVector(next[i]), new int[0], training[k + i + j]);
            }
            current = next;
            krls.train(ds, ex);
        }
    }

    /**
     * Normalised mean squared error: the summed squared error divided by
     * {@code actual.length * variance(actual)}.
     */
    private static double nmse(double[] predicted, double[] actual) {
        return getSquaredError(predicted, actual) / (actual.length * variance(actual));
    }

    /**
     * Runs a JSAT grid search over a fixed set of kernel sigmas and error
     * tolerances on {@link #ds}, then copies the selected values into
     * {@link #sigma}, {@link #errTol} and {@link #krls}. Prints the parameter
     * values before and after the search.
     *
     * <p>Requires {@link #krls} and {@link #ds} to be initialised.
     * NOTE(review): the constructor argument 20 is presumably the number of
     * cross-validation folds - confirm against the JSAT documentation.
     */
    public static void gridSearch() {
        double[] sigmas = new double[]{.8, .85, .9, .95, 1.};
        double[] errTols = new double[]{5e-3, 1e-2, 2e-2};
        GridSearch gs = new GridSearch(krls, 20); // default params overridden by what follows
        DoubleParameter paramK = (DoubleParameter) ((Parameterized) gs.getBaseRegressor()).getParameter("RBFKernel_sigma");
        gs.addParameter(paramK, sigmas);
        DoubleParameter paramE = (DoubleParameter) ((Parameterized) gs.getBaseRegressor()).getParameter("Error Tolerance");
        gs.addParameter(paramE, errTols);
        System.out.println("before: " + krls.getParameter("RBFKernel_sigma").getValueString() + " " + krls.getParameter("Error Tolerance").getValueString());
        gs.train(ds, ex);
        sigma = ((DoubleParameter) ((Parameterized) gs.getTrainedRegressor()).getParameter("RBFKernel_sigma")).getValue();
        errTol = ((DoubleParameter) ((Parameterized) gs.getTrainedRegressor()).getParameter("Error Tolerance")).getValue();
        ((RBFKernel) krls.getKernelTrick()).setSigma(sigma);
        krls.setErrorTolerance(errTol);
        System.out.println("after: " + krls.getParameter("RBFKernel_sigma").getValueString() + " " + krls.getParameter("Error Tolerance").getValueString());
    }

    /**
     * Cross-validates {@link #krls} on {@link #ds}.
     *
     * @param folds number of folds passed to the JSAT evaluation
     * @return the mean error reported by the evaluation
     */
    public static double getCrossValidationMeanError(int folds) {
        RegressionModelEvaluation crossEval = new RegressionModelEvaluation(krls, ds);
        crossEval.evaluateCrossValidation(folds);
        return crossEval.getMeanError();
    }

    /**
     * Shows the given series as line plots in a new 1000x600 window, using
     * x-coordinates {@code 1..length} for each series.
     *
     * <p>The window is built on the Swing event dispatch thread and is
     * disposed when closed, so it does not keep the JVM alive afterwards.
     *
     * @param graphs one array of y-values per line
     */
    public static void plot(final double[]... graphs) {
        SwingUtilities.invokeLater(new Runnable() {
            @Override
            public void run() {
                JFrame frame = new JFrame();
                frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
                frame.setSize(1000, 600);
                Plot2DPanel plot = new Plot2DPanel();
                for (int i = 0; i < graphs.length; i++) {
                    // seq's upper bound is exclusive: length + 1 yields one x per y value.
                    plot.addLinePlot("", seq(1, graphs[i].length + 1), graphs[i]);
                }
                frame.setContentPane(plot);
                frame.setVisible(true);
            }
        });
    }

    /**
     * Builds the matrix of all length-{@code k} sliding windows of {@code a}:
     * row {@code i} is a copy of {@code a[i .. i+k-1]}.
     *
     * @param a the series
     * @param k the window length
     * @return a matrix with {@code a.length - k + 1} rows of {@code k} values
     * @throws IllegalArgumentException if {@code k} is negative or larger than {@code a.length + 1}
     */
    public static double[][] makeLagMatrix(double[] a, int k) {
        int rows = a.length - k + 1;
        if (k < 0 || rows < 0) {
            throw new IllegalArgumentException(
                    "window size " + k + " is invalid for a series of length " + a.length);
        }
        // Only the row slots are allocated; each row is filled by copyOfRange.
        double[][] m = new double[rows][];
        for (int i = 0; i < rows; i++) {
            m[i] = Arrays.copyOfRange(a, i, i + k);
        }
        return m;
    }

    /**
     * Returns the integers {@code from, from+1, ..., to-1} as doubles.
     *
     * @param from first value (inclusive)
     * @param to   end value (exclusive)
     */
    public static double[] seq(int from, int to) {
        double[] indexes = new double[to - from];
        for (int i = from; i < to; i++) {
            indexes[i - from] = i;
        }
        return indexes;
    }

    /**
     * Multiplies every element of {@code a} by {@code v} <em>in place</em>.
     *
     * @return the same array instance {@code a}
     */
    public static double[] mult(double[] a, double v) {
        for (int i = 0; i < a.length; i++) {
            a[i] *= v;
        }
        return a;
    }

    /**
     * Linearly maps values from the range {@code [min, max]} to
     * {@code [newmin, newmax]}. The input array is not modified.
     * With {@code max == min} the scale factor is a division by zero and the
     * result contains non-finite values.
     *
     * @return a new array of the same length as {@code v}
     */
    public static double[] scale(double[] v, double min, double max, double newmin, double newmax) {
        double[] r = new double[v.length];
        double K = (newmax - newmin) / (max - min);
        for (int i = 0; i < v.length; i++) {
            r[i] = newmin + K * (v[i] - min);
        }
        return r;
    }

    /**
     * Sum of squared element-wise differences, taken over the indices of
     * {@code vector1}; {@code vector2} must be at least as long.
     */
    public static double getSquaredError(double[] vector1, double[] vector2) {
        double squaredError = 0;
        for (int i = 0; i < vector1.length; i++) {
            double diff = vector1[i] - vector2[i];
            squaredError += diff * diff;
        }
        return squaredError;
    }

    /** Sum of all elements ({@code 0} for an empty array). */
    public static double sum(double[] a) {
        double s = 0;
        for (double d : a) {
            s += d;
        }
        return s;
    }

    /** Arithmetic mean; {@code NaN} for an empty array. */
    public static double mean(double[] a) {
        return sum(a) / a.length;
    }

    /** Sample variance (divisor {@code length - 1}) around the array's own mean. */
    public static double variance(double[] v) {
        return variance(v, mean(v));
    }

    /**
     * Sample variance around a supplied mean: the summed squared deviations
     * divided by {@code v.length - 1}. Not finite-valued for arrays shorter
     * than two elements.
     */
    public static double variance(double[] v, double mean) {
        double r = 0.0;
        for (int i = 0; i < v.length; i++) {
            double dev = v[i] - mean;
            r += dev * dev;
        }
        return r / (v.length - 1);
    }

    /** Square root of {@link #variance(double[])}. */
    public static double sd(double[] v) {
        return Math.sqrt(variance(v));
    }

    /** Square root of {@link #variance(double[], double)}. */
    public static double sd(double[] v, double mean) {
        return Math.sqrt(variance(v, mean));
    }

    /**
     * Formats the array as its elements separated by single spaces.
     *
     * @return the formatted string, or the empty string for an empty array
     */
    public static String str(double[] a) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < a.length; i++) {
            if (i > 0) {
                sb.append(' ');
            }
            sb.append(a[i]);
        }
        return sb.toString();
    }

    /** Prints {@link #str(double[])} of {@code a} followed by a newline to standard output. */
    public static void print(double[] a) {
        System.out.println(str(a));
    }

    /**
     * Reads up to {@code len} numbers, one per line, from a text file.
     *
     * <p>If the file has fewer than {@code len} lines the remaining entries
     * of the result stay {@code 0.0}. The reader is always closed, including
     * when a line fails to parse.
     *
     * @param filename path of the file to read
     * @param len      number of values to read; also the length of the result
     * @return an array of exactly {@code len} values
     * @throws IOException           if the file cannot be opened or read
     * @throws NumberFormatException if a line is not a valid double
     */
    public static double[] read(String filename, int len) throws IOException {
        double[] d = new double[len];
        try (BufferedReader rdr = new BufferedReader(new FileReader(new File(filename)))) {
            String line;
            for (int i = 0; i < len && (line = rdr.readLine()) != null; i++) {
                d[i] = Double.parseDouble(line);
            }
        }
        return d;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment