Created: October 20, 2013 08:00
-
-
Save mountcedar/7066243 to your computer and use it in GitHub Desktop.
PSVMのサンプルコード@Processing2 (PSVM sample code for Processing 2). Reference: http://qiita.com/mountcedar/items/a55889b1ce5fa0d4e20e
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
SVMComponent svm = null; | |
void setup () { | |
svm = new SVMComponent(this); | |
svm.importDataFromFile(sketchPath + "/points.csv", 2); | |
svm.train(); | |
svm.save(sketchPath + "/trained.txt"); | |
} | |
void draw () { | |
exit(); | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import psvm.*; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import java.util.List; | |
import java.util.ArrayList; | |
import java.util.Set; | |
import java.util.HashSet; | |
import java.util.Arrays; | |
import java.lang.Class; | |
import java.lang.reflect.Field; | |
import processing.data.Table; | |
import processing.data.TableRow; | |
import processing.core.PApplet; | |
import java.io.File; | |
class SVMComponent { | |
protected static Logger logger = LoggerFactory.getLogger(SVMComponent.class); | |
protected PApplet applet = null; | |
protected int [] labels = null; | |
protected float [][] trainingPoints = null; | |
protected SVM model = null; | |
protected SVMProblem problem = null; | |
public SVMComponent (PApplet applet) { | |
this.applet = applet; | |
model = new SVM(applet); | |
problem = new SVMProblem(); | |
} | |
public boolean save (String modelPath) { | |
try { | |
this.model.saveModel(modelPath); | |
return true; | |
} catch (Exception e) { | |
logger.error("Error: {}", e); | |
return false; | |
} | |
} | |
public boolean load (String modelPath, Class<?> clazz) { | |
try { | |
Field [] fields = clazz.getFields(); | |
this.model.loadModel(modelPath, fields.length - 1); | |
return true; | |
} catch (Exception e) { | |
logger.error("Error: {}", e); | |
return false; | |
} | |
} | |
public boolean importDataFromFile (String filepath, int labelIndex) { | |
try { | |
Table data = new Table(new File(filepath)); | |
int featureNum = data.getColumnCount() - 1; | |
this.trainingPoints = new float[data.getRowCount()][featureNum]; | |
this.labels = new int[data.getRowCount()]; | |
float [] maximums = new float [featureNum]; // for normalization | |
float [] minimums = new float [featureNum]; // for normalization | |
for (int i = 0; i < featureNum; i++) minimums[i] = Float.MAX_VALUE; | |
int i = 0; | |
for (TableRow row : data.rows()) { | |
float[] p = new float[featureNum]; | |
boolean aboveLabel = false; | |
for (int j = 0; j < data.getColumnCount(); j++) { | |
if (j == labelIndex) { | |
labels[i] = row.getInt(j); | |
aboveLabel = true; | |
} | |
int index = aboveLabel ? j - 1: j; | |
float value = row.getFloat(index); | |
logger.error("index: {}", index); | |
if (value > maximums[index]) maximums[index] = value; | |
if (value < minimums[index]) minimums[index] = value; | |
p[index] = value; | |
} | |
this.trainingPoints[i] = p; | |
i++; | |
} | |
// normalization | |
for (i = 0; i < data.getRowCount(); i++) { | |
for (int j = 0; j < data.getColumnCount(); j++) { | |
this.trainingPoints[i][j] -= minimums[j]; | |
this.trainingPoints[i][j] /= maximums[j] - minimums[j]; | |
} | |
} | |
this.problem.setNumFeatures(featureNum); | |
return true; | |
} catch (Exception e) { | |
logger.error ("Error: {}", e); | |
return false; | |
} | |
} | |
protected boolean train () { | |
try { | |
problem.setSampleData(this.labels, this.trainingPoints); | |
this.model.train(problem); | |
return true; | |
} catch (Exception e) { | |
logger.error("Error: {}", e.getMessage()); | |
return false; | |
} | |
} | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment