Skip to content

Instantly share code, notes, and snippets.

@mountcedar
Created October 20, 2013 08:00
Show Gist options
  • Save mountcedar/7066243 to your computer and use it in GitHub Desktop.
Save mountcedar/7066243 to your computer and use it in GitHub Desktop.
PSVMのサンプルコード@Processing2 ref: http://qiita.com/mountcedar/items/a55889b1ce5fa0d4e20e
// SVM helper used by the sketch; instantiated in setup().
SVMComponent svm = null;

/**
 * Processing entry point: imports the training points from "points.csv" in
 * the sketch folder, trains the SVM on them and writes the trained model to
 * "trained.txt" in the same folder.
 *
 * Every step reports success as a boolean. The pipeline stops at the first
 * failing step and reports it on stderr, so that a model is never trained on
 * (or saved from) data that failed to load.
 */
void setup() {
    svm = new SVMComponent(this);

    // The second argument is the index of the column that holds the class label.
    if (!svm.importDataFromFile(sketchPath + "/points.csv", 2)) {
        System.err.println("Failed to import training data from " + sketchPath + "/points.csv");
        return;
    }
    if (!svm.train()) {
        System.err.println("Failed to train the SVM model");
        return;
    }
    if (!svm.save(sketchPath + "/trained.txt")) {
        System.err.println("Failed to save the trained model to " + sketchPath + "/trained.txt");
    }
}
/**
 * Processing draw callback. This sketch has nothing to render, so the
 * callback asks the sketch to quit as soon as it is invoked.
 */
void draw() {
    exit();
}
import psvm.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.List;
import java.util.ArrayList;
import java.util.Set;
import java.util.HashSet;
import java.util.Arrays;
import java.lang.Class;
import java.lang.reflect.Field;
import processing.data.Table;
import processing.data.TableRow;
import processing.core.PApplet;
import java.io.File;
/**
 * Small convenience wrapper around the PSVM classes {@code SVM} and
 * {@code SVMProblem}: it loads labelled samples from a table file, rescales
 * every feature column to the range [0, 1], trains the model and saves or
 * loads it.
 *
 * All operations report success through their boolean return value and log
 * failures instead of throwing.
 */
class SVMComponent {
    protected static Logger logger = LoggerFactory.getLogger(SVMComponent.class);

    /** Sketch this component belongs to; handed to the PSVM model. */
    protected PApplet applet = null;
    /** Class label of each training sample; {@code null} until data is imported. */
    protected int[] labels = null;
    /** Normalized feature vectors, one row per sample; {@code null} until data is imported. */
    protected float[][] trainingPoints = null;
    protected SVM model = null;
    protected SVMProblem problem = null;

    /**
     * Creates the component together with an empty model and problem.
     *
     * @param applet the sketch passed on to the {@code SVM} constructor
     */
    public SVMComponent(PApplet applet) {
        this.applet = applet;
        model = new SVM(applet);
        problem = new SVMProblem();
    }

    /**
     * Writes the current model to a file.
     *
     * @param modelPath destination path of the model file
     * @return {@code true} on success, {@code false} if saving threw an exception
     */
    public boolean save(String modelPath) {
        try {
            this.model.saveModel(modelPath);
            return true;
        } catch (Exception e) {
            // Pass the exception as the throwable argument so the stack trace is kept.
            logger.error("Error: failed to save model to " + modelPath, e);
            return false;
        }
    }

    /**
     * Loads a previously saved model from a file.
     *
     * The feature count given to the model is the number of public fields of
     * {@code clazz} minus one (presumably one field holds the label; confirm
     * against the sample class in use).
     *
     * @param modelPath path of the model file to read
     * @param clazz     class whose public field count determines the feature count
     * @return {@code true} on success, {@code false} if loading threw an exception
     */
    public boolean load(String modelPath, Class<?> clazz) {
        try {
            Field[] fields = clazz.getFields();
            this.model.loadModel(modelPath, fields.length - 1);
            return true;
        } catch (Exception e) {
            logger.error("Error: failed to load model from " + modelPath, e);
            return false;
        }
    }

    /**
     * Reads labelled samples from a table file and stores them, min-max
     * normalized per feature column, as the training data of this component.
     *
     * One column holds the integer class label; every other column is a
     * feature. The stored labels and points are only replaced when the whole
     * import succeeds, so a failed import never leaves half-filled data behind.
     *
     * @param filepath   path of the table file to read
     * @param labelIndex zero-based index of the label column
     * @return {@code true} on success, {@code false} if the label index is out
     *         of range or reading/parsing threw an exception
     */
    public boolean importDataFromFile(String filepath, int labelIndex) {
        try {
            Table data = new Table(new File(filepath));
            int columnCount = data.getColumnCount();
            int rowCount = data.getRowCount();
            if (labelIndex < 0 || labelIndex >= columnCount) {
                logger.error("Error: label index " + labelIndex + " is outside the "
                        + columnCount + " column(s) of " + filepath);
                return false;
            }

            int featureNum = columnCount - 1;
            float[][] points = new float[rowCount][featureNum];
            int[] rowLabels = new int[rowCount];

            // Per-feature extremes for normalization. The maximum starts at the lowest
            // finite float so that all-negative columns are handled correctly.
            float[] maximums = new float[featureNum];
            float[] minimums = new float[featureNum];
            Arrays.fill(maximums, -Float.MAX_VALUE);
            Arrays.fill(minimums, Float.MAX_VALUE);

            int rowIndex = 0;
            for (TableRow row : data.rows()) {
                int feature = 0; // position inside the feature vector (label column skipped)
                for (int column = 0; column < columnCount; column++) {
                    if (column == labelIndex) {
                        rowLabels[rowIndex] = row.getInt(column);
                        continue;
                    }
                    float value = row.getFloat(column);
                    if (value > maximums[feature]) {
                        maximums[feature] = value;
                    }
                    if (value < minimums[feature]) {
                        minimums[feature] = value;
                    }
                    points[rowIndex][feature] = value;
                    feature++;
                }
                rowIndex++;
            }

            // Min-max normalization over the feature columns only.
            for (float[] point : points) {
                for (int j = 0; j < featureNum; j++) {
                    float range = maximums[j] - minimums[j];
                    // A constant column has zero range; map it to 0 instead of producing NaN.
                    point[j] = range > 0 ? (point[j] - minimums[j]) / range : 0f;
                }
            }

            this.trainingPoints = points;
            this.labels = rowLabels;
            this.problem.setNumFeatures(featureNum);
            return true;
        } catch (Exception e) {
            logger.error("Error: failed to import training data from " + filepath, e);
            return false;
        }
    }

    /**
     * Trains the model on the data loaded by {@link #importDataFromFile}.
     *
     * @return {@code true} on success, {@code false} if no data has been
     *         imported yet or training threw an exception
     */
    protected boolean train() {
        if (this.labels == null || this.trainingPoints == null) {
            logger.error("Error: no training data loaded; call importDataFromFile first");
            return false;
        }
        try {
            problem.setSampleData(this.labels, this.trainingPoints);
            this.model.train(problem);
            return true;
        } catch (Exception e) {
            logger.error("Error: training failed", e);
            return false;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment