Skip to content

Instantly share code, notes, and snippets.

@myui
Last active August 29, 2015 14:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save myui/5bef4e4dc8b89d3819c6 to your computer and use it in GitHub Desktop.
Save myui/5bef4e4dc8b89d3819c6 to your computer and use it in GitHub Desktop.
/**
 * Command-line driver that trains an {@code AdaptiveLogisticRegression} model on a directory
 * of newsgroup documents, writes one of the resulting models to the temp directory and prints
 * a ranked list of word counts.
 *
 * <p>Expected layout of the training directory: one sub-directory per newsgroup, each holding
 * the documents of that group. The sub-directory name is used as the category label.
 */
public final class TrainNewsGroups {

  /**
   * Runs the training.
   *
   * @param args {@code args[0]} is the training directory (required); {@code args[1]} is an
   *     optional integer "leak type" that is passed through to the feature encoder and the
   *     reporting helpers (defaults to 0)
   * @throws IOException if the training directory or one of its sub-directories cannot be
   *     listed, or if reading the documents / writing the model fails
   * @throws IllegalArgumentException if no training directory was given
   * @throws NumberFormatException if {@code args[1]} is present but not an integer
   */
  public static void main(String[] args) throws IOException {
    // Fail with a usable message instead of an ArrayIndexOutOfBoundsException.
    if (args.length < 1) {
      throw new IllegalArgumentException("Usage: TrainNewsGroups <training-directory> [leakType]");
    }
    File base = new File(args[0]);

    Multiset<String> overallCounts = HashMultiset.create();

    int leakType = 0;
    if (args.length > 1) {
      leakType = Integer.parseInt(args[1]);
    }

    Dictionary newsGroups = new Dictionary();

    NewsgroupHelper helper = new NewsgroupHelper();
    helper.getEncoder().setProbes(2);
    // NOTE(review): 20 is presumably the number of newsgroup categories — confirm.
    AdaptiveLogisticRegression learningAlgorithm =
        new AdaptiveLogisticRegression(20, NewsgroupHelper.FEATURES, new L1());
    learningAlgorithm.setInterval(800);
    learningAlgorithm.setAveragingWindow(500);

    // File.listFiles() returns null (not an empty array) when the path is not a directory
    // or an I/O error occurs, so it has to be checked before iterating.
    File[] newsgroupDirs = base.listFiles();
    if (newsgroupDirs == null) {
      throw new IOException("Cannot list training directory: " + base.getAbsolutePath());
    }

    List<File> files = Lists.newArrayList();
    for (File newsgroup : newsgroupDirs) {
      if (newsgroup.isDirectory()) {
        // Register the label first so that category ids follow the directory listing order.
        newsGroups.intern(newsgroup.getName());
        File[] documents = newsgroup.listFiles();
        if (documents == null) {
          throw new IOException("Cannot list newsgroup directory: " + newsgroup.getAbsolutePath());
        }
        files.addAll(Arrays.asList(documents));
      }
    }
    // Present the documents in random order rather than grouped by newsgroup.
    Collections.shuffle(files);
    System.out.println(files.size() + " training files");

    SGDInfo info = new SGDInfo();

    int k = 0;
    for (File file : files) {
      // The parent directory name is the document's newsgroup label.
      String ng = file.getParentFile().getName();
      int actual = newsGroups.intern(ng);

      Vector v = helper.encodeFeatureVector(file, actual, leakType, overallCounts);
      learningAlgorithm.train(actual, v);

      k++;
      // Report progress on the currently best learner after every training example.
      State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();
      SGDHelper.analyzeState(info, leakType, k, best);
    }
    learningAlgorithm.close();
    SGDHelper.dissect(leakType, newsGroups, learningAlgorithm, files, overallCounts);
    System.out.println("exiting main");

    // Persist the first model of the best learner under the JVM temp directory.
    File modelFile = new File(System.getProperty("java.io.tmpdir"), "news-group.model");
    ModelSerializer.writeBinary(modelFile.getAbsolutePath(),
        learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));

    // Collect the occurrence count of every distinct word seen during encoding.
    List<Integer> counts = Lists.newArrayList();
    System.out.println("Word counts");
    for (String word : overallCounts.elementSet()) {
      counts.add(overallCounts.count(word));
    }
    // Largest counts first.
    Collections.sort(counts, Ordering.natural().reverse());

    // Print "rank<TAB>count"; the loop stops once rank 1000 has been printed.
    k = 0;
    for (Integer count : counts) {
      System.out.println(k + "\t" + count);
      k++;
      if (k > 1000) {
        break;
      }
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment