Created
January 3, 2013 21:25
-
-
Save dfilimon/4447431 to your computer and use it in GitHub Desktop.
trainActual()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public static void trainActual(Iterable<Pair<Text, VectorWritable>> inputIterable, String outBase, | |
Map<String, Integer> clusterNamesToIds) throws IOException { | |
Map<String, Centroid> actualClusters = Maps.newHashMap(); | |
computeActualClusters(inputIterable, actualClusters); | |
OnlineLogisticRegression learningAlgorithm = | |
new OnlineLogisticRegression(NUM_CLASSES, NUM_FEATURES_ACTUAL, new L1()); | |
for (Pair<Text, VectorWritable> pair : inputIterable) { | |
Vector actualCentroid = pair.getSecond().get(); | |
Vector features = new DenseVector(NUM_FEATURES_ACTUAL); | |
int i = 0; | |
for (Centroid centroid : actualClusters.values()) { | |
features.set(i++, centroid.getDistanceSquared(actualCentroid)); | |
} | |
String clusterName = pair.getFirst().toString(); | |
System.out.printf("Feature vector of size %d; NUM_FEATURES_ACTUAL: %d; %d\n", | |
features.size(), NUM_FEATURES_ACTUAL, clusterNamesToIds.get(clusterName)); | |
learningAlgorithm.train(clusterNamesToIds.get(clusterName), features); | |
} | |
learningAlgorithm.close(); | |
ModelSerializer.writeBinary(outBase + "-append.model", learningAlgorithm); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment