Skip to content

Instantly share code, notes, and snippets.

@kazazes
Created February 3, 2016 16:26
Show Gist options
  • Save kazazes/16a7825221b3802e89a9 to your computer and use it in GitHub Desktop.
Save kazazes/16a7825221b3802e89a9 to your computer and use it in GitHub Desktop.
import com.datumbox.common.dataobjects.AssociativeArray;
import com.datumbox.common.dataobjects.Dataset;
import com.datumbox.common.dataobjects.Record;
import com.datumbox.common.persistentstorage.ConfigurationFactory;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.framework.machinelearning.regression.MatrixLinearRegression;
import com.datumbox.framework.machinelearning.regression.StepwiseRegression;
import com.SV.aggregate.DailyAggregate;
import org.joda.time.DateTime;
import java.util.ArrayList;
import java.util.Map;
/**
 * {@link Regression} implementation backed by Datumbox's {@link StepwiseRegression},
 * configured to use {@link MatrixLinearRegression} as the wrapped regression class.
 *
 * <p>All Datumbox objects created here share the in-memory database configuration
 * obtained from {@code ConfigurationFactory.INMEMORY}.
 */
public class Stepwise implements Regression {

    private final DatabaseConfiguration dbConf = ConfigurationFactory.INMEMORY.getConfiguration();

    private final StepwiseRegression sw;
    private final Dataset trainingData;
    private final StepwiseRegression.TrainingParameters trainingParameters;

    /** Time supplied by the caller; stored but not read anywhere in this class. */
    private final DateTime simulatedTime;

    /**
     * Builds the training dataset from the given aggregates and prepares the
     * training parameters. No fitting happens here; call {@link #generateModel()}.
     *
     * @param aggregates  observations to train on; each one contributes one record whose
     *                    features come from {@code toMapWithAllowedKeys(allowedKeys)} and
     *                    whose response is {@code getValue()}
     * @param allowedKeys keys passed to {@code toMapWithAllowedKeys} to select the features
     * @param time        time stored as the simulated time of this instance
     */
    public Stepwise (ArrayList<DailyAggregate> aggregates, ArrayList<String> allowedKeys, DateTime time) {
        this.sw = new StepwiseRegression("DW", dbConf);
        this.simulatedTime = time;
        this.trainingData = new Dataset(dbConf);
        for (DailyAggregate a : aggregates) {
            Map<Object, Object> map = a.toMapWithAllowedKeys(allowedKeys);
            AssociativeArray arr = new AssociativeArray(map);
            Record r = new Record(arr, a.getValue());
            this.trainingData.add(r);
        }
        this.trainingParameters = new StepwiseRegression.TrainingParameters();
        this.trainingParameters.setRegressionClass(MatrixLinearRegression.class);
    }

    /**
     * Fits the stepwise regression on the training data collected by the constructor.
     *
     * @return always {@code null}; the fitted state is kept inside this instance and
     *         used by {@link #predict(DailyAggregate, ArrayList)}
     */
    public RegressionModel generateModel() {
        sw.fit(trainingData, trainingParameters);
        return null;
    }

    /**
     * Predicts the response for a single aggregate.
     *
     * @param aggregate   observation to predict for
     * @param allowedKeys keys passed to {@code toMapWithAllowedKeys} to select the features
     * @return the predicted value as a {@code double}
     * @throws NullPointerException  if no predicted value is present on the record
     * @throws NumberFormatException if the predicted value is neither a {@link Number}
     *                               nor a parsable numeric string
     */
    public double predict(DailyAggregate aggregate, ArrayList<String> allowedKeys) {
        Map<Object, Object> map = aggregate.toMapWithAllowedKeys(allowedKeys);
        AssociativeArray arr = new AssociativeArray(map);
        // The response is what we are predicting, so 0d is only a placeholder here.
        // NOTE(review): presumably the real value is ignored during prediction - confirm.
        Record r = new Record(arr, 0d);
        Dataset d = new Dataset(dbConf);
        d.add(r);
        sw.predict(d);
        // NOTE(review): if this Datumbox version stores predictions on a copy of the
        // record held by the Dataset, the value must be read back from `d` instead of
        // from `r` - confirm against the library version in use.
        return toDouble(r.getYPredicted());
    }

    /**
     * Converts a predicted value to {@code double} without assuming it is a String:
     * numbers are converted directly, anything else is parsed from its text form.
     */
    private static double toDouble(Object value) {
        if (value == null) {
            throw new NullPointerException("record has no predicted value");
        }
        if (value instanceof Number) {
            return ((Number) value).doubleValue();
        }
        return Double.parseDouble(value.toString());
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment