Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save dfilimon/4447431 to your computer and use it in GitHub Desktop.
Save dfilimon/4447431 to your computer and use it in GitHub Desktop.
trainActual()
/**
 * Trains an {@link OnlineLogisticRegression} model on distance-to-cluster features and writes it
 * to {@code outBase + "-append.model"} via {@code ModelSerializer.writeBinary}.
 *
 * <p>The input is traversed twice: once by {@code computeActualClusters} (which is handed an empty
 * map to fill) and once by the training loop below, so {@code inputIterable} must support
 * repeated iteration. For every record, feature {@code i} is the squared distance from the i-th
 * centroid (in the iteration order of the cluster map) to the record's vector, and the training
 * label is the id registered for the record's cluster name.
 *
 * @param inputIterable pairs of (cluster name, vector); iterated more than once
 * @param outBase prefix of the output path; {@code "-append.model"} is appended to it
 * @param clusterNamesToIds maps each cluster name found in the input to its integer class id
 * @throws IOException if writing the serialized model fails
 * @throws IllegalArgumentException if there are more clusters than {@code NUM_FEATURES_ACTUAL},
 *     or if a record's cluster name has no entry in {@code clusterNamesToIds}
 */
public static void trainActual(Iterable<Pair<Text, VectorWritable>> inputIterable, String outBase,
    Map<String, Integer> clusterNamesToIds) throws IOException {
  Map<String, Centroid> actualClusters = Maps.newHashMap();
  computeActualClusters(inputIterable, actualClusters);

  // The feature vector has a fixed size; writing one entry per centroid would run past its end
  // if there were more centroids than features, so fail once up front with a clear message.
  if (actualClusters.size() > NUM_FEATURES_ACTUAL) {
    throw new IllegalArgumentException("Found " + actualClusters.size()
        + " clusters but feature vectors only hold " + NUM_FEATURES_ACTUAL + " entries");
  }

  OnlineLogisticRegression learningAlgorithm =
      new OnlineLogisticRegression(NUM_CLASSES, NUM_FEATURES_ACTUAL, new L1());
  try {
    for (Pair<Text, VectorWritable> pair : inputIterable) {
      Vector actualCentroid = pair.getSecond().get();

      // The map is not modified after this point, so values() yields the centroids in the same
      // order for every record and feature index i always refers to the same cluster.
      Vector features = new DenseVector(NUM_FEATURES_ACTUAL);
      int i = 0;
      for (Centroid centroid : actualClusters.values()) {
        features.set(i++, centroid.getDistanceSquared(actualCentroid));
      }

      // Look the label up once; a missing entry would otherwise surface as a bare
      // NullPointerException when the Integer is unboxed for train().
      String clusterName = pair.getFirst().toString();
      Integer clusterId = clusterNamesToIds.get(clusterName);
      if (clusterId == null) {
        throw new IllegalArgumentException(
            "No class id registered for cluster name '" + clusterName + "'");
      }

      System.out.printf("Feature vector of size %d; NUM_FEATURES_ACTUAL: %d; %d\n",
          features.size(), NUM_FEATURES_ACTUAL, clusterId);
      learningAlgorithm.train(clusterId, features);
    }
  } finally {
    // Close the learner even when training aborts; on success this still happens before the
    // model is serialized, exactly as before.
    learningAlgorithm.close();
  }
  ModelSerializer.writeBinary(outBase + "-append.model", learningAlgorithm);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment