Skip to content

Instantly share code, notes, and snippets.

@aron-bordin
Created May 4, 2015 00:21
Show Gist options
  • Save aron-bordin/0a3c13a508b5246702be to your computer and use it in GitHub Desktop.
Save aron-bordin/0a3c13a508b5246702be to your computer and use it in GitHub Desktop.
Sentiment NLP with Mallet
package com.aronbordin;
import ca.uwo.csd.ai.nlp.kernel.LinearKernel;
import ca.uwo.csd.ai.nlp.mallet.libsvm.SVMClassifierTrainer;
import cc.mallet.classify.*;
import cc.mallet.pipe.*;
import cc.mallet.pipe.iterator.CsvIterator;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.Randoms;
import java.io.*;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.regex.Pattern;
/**
 * Sentiment classification experiment built on MALLET.
 *
 * <p>Labeled rows are read from {@link #FILE_TRAIN_X}, vectorized through the pipe returned by
 * {@link #buildPipe()}, and cached in {@link #FILE_TRAIN_X_BIN}. {@link #train()} evaluates an
 * SVM with a linear kernel on a 70/30 split; {@link #trainAndPredict()} labels the rows of
 * {@link #FILE_TEST} with a (cached) Naive Bayes classifier and writes them to
 * {@link #FILE_TEST_OUT}.
 *
 * <p>Unrecoverable I/O problems are reported as {@link IllegalStateException} carrying the
 * original exception as cause.
 */
public class Sentiment {
    /** Feature-extraction pipeline used when vectorizing the raw training file. */
    protected Pipe pipe;

    /** Vectorized labeled data; {@code null} until {@link #importFile()} or {@link #readFile()} ran. */
    protected InstanceList trainInstances;

    /** Raw labeled training data. */
    protected static final String FILE_TRAIN_X = "data/train_labeled.tsv";
    /** Raw unlabeled data to predict. */
    protected static final String FILE_TEST = "data/test_data.tsv";
    /** Serialized {@link InstanceList} cache of the vectorized training data. */
    protected static final String FILE_TRAIN_X_BIN = "data/train_labeled.bin";
    /** Serialized classifier cache used by {@link #trainAndPredict()}. */
    protected static final String FILE_CLASSIFIER = "data/classifier.bin";
    /** CSV file receiving the predictions. */
    protected static final String FILE_TEST_OUT = "data/test.out";

    /**
     * Creates the experiment and builds its pipe.
     *
     * <p>NOTE(review): this invokes the overridable {@link #buildPipe()} from the constructor, so
     * an override must not rely on subclass fields being initialized.
     */
    Sentiment() {
        pipe = buildPipe();
    }

    /**
     * Reads {@link #FILE_TRAIN_X} and vectorizes every matching row into {@link #trainInstances}.
     *
     * <p>Each row is matched against a quoted id, a single-digit label and the remaining text;
     * the capture groups are handed to {@code CsvIterator} as data (3), target (2) and name (1).
     *
     * @throws IllegalStateException if the training file does not exist
     */
    protected void importFile() {
        CsvIterator rows;
        try {
            rows = new CsvIterator(FILE_TRAIN_X, "\"(\\w+)\"\\s+(\\d)\\s+(.*)", 3, 2, 1);
        } catch (FileNotFoundException e) {
            // Fail here with the real cause instead of passing a null iterator downstream.
            throw new IllegalStateException("Training file not found: " + FILE_TRAIN_X, e);
        }
        InstanceList instances = new InstanceList(pipe);
        instances.addThruPipe(rows);
        trainInstances = instances;
    }

    /** Serializes {@link #trainInstances} to {@link #FILE_TRAIN_X_BIN}. */
    private void saveFile() {
        trainInstances.save(new File(FILE_TRAIN_X_BIN));
    }

    /**
     * Builds the text-to-feature-vector pipeline.
     *
     * <p>Stages, in order: decode input as UTF-8, tokenize on runs of Unicode letters, digits and
     * underscores, lowercase, remove stopwords, map tokens to feature ids, map the target to a
     * label, and collapse the feature sequence into a feature vector.
     *
     * @return the assembled serial pipe
     */
    protected Pipe buildPipe() {
        ArrayList<Pipe> stages = new ArrayList<Pipe>();
        stages.add(new Input2CharSequence("UTF-8"));
        stages.add(new CharSequence2TokenSequence(Pattern.compile("[\\p{L}\\p{N}_]+")));
        stages.add(new TokenSequenceLowercase());
        stages.add(new TokenSequenceRemoveStopwords(false, false));
        stages.add(new TokenSequence2FeatureSequence());
        stages.add(new Target2Label());
        stages.add(new FeatureSequence2FeatureVector());
        return new SerialPipes(stages);
    }

    /**
     * Trains an SVM (linear kernel) on a random 70% of {@link #trainInstances} and prints its
     * accuracy on the remaining 30%.
     *
     * <p>Other MALLET trainers (e.g. {@code MaxEntTrainer}, {@code NaiveBayesTrainer}) can be
     * evaluated the same way by swapping the trainer below.
     *
     * @throws IllegalStateException if no training data has been loaded yet
     */
    protected void train() {
        requireTrainingData();
        InstanceList[] splitData = trainInstances.split(new Randoms(), new double[]{0.7, 0.3});
        InstanceList trainingPart = splitData[0];
        InstanceList heldOutPart = splitData[1];

        System.out.println("Training with SVMClassifierTrainer...");
        Classifier classifier = new SVMClassifierTrainer(new LinearKernel()).train(trainingPart);
        System.out.print("Done! ");
        Trial trial = new Trial(classifier, heldOutPart);
        System.out.println("Accuracy: " + trial.getAccuracy());
    }

    /**
     * Classifies every row of {@link #FILE_TEST} and writes {@code "id","sentiment"} lines to
     * {@link #FILE_TEST_OUT}.
     *
     * <p>The classifier is taken from {@link #FILE_CLASSIFIER} when that file exists; otherwise a
     * Naive Bayes classifier is trained on {@link #trainInstances} and cached there.
     *
     * @throws IllegalStateException if the classifier cannot be loaded or trained, the test file
     *     is missing, or the output file cannot be written
     */
    protected void trainAndPredict() {
        Classifier classifier = loadOrTrainClassifier();

        CsvIterator testRows;
        try {
            // Target group 0: the test rows carry only an id (group 1) and text (group 2).
            testRows = new CsvIterator(FILE_TEST, "\"(\\w+)\"\\s+(.*)", 2, 0, 1);
        } catch (FileNotFoundException e) {
            throw new IllegalStateException("Test file not found: " + FILE_TEST, e);
        }
        // Run the test rows through the classifier's own pipe so they share its alphabets.
        Iterator<Instance> testInstances = classifier.getInstancePipe().newIteratorFrom(testRows);

        // try-with-resources closes the writer even when classification or writing fails.
        try (FileWriter csv = new FileWriter(new File(FILE_TEST_OUT))) {
            csv.append("\"id\",\"sentiment\"\n");
            while (testInstances.hasNext()) {
                Instance instance = testInstances.next();
                System.out.println(classifier.classify(instance));
                csv.append("\"" + instance.getName() + "\",");
                csv.append(classifier.classify(instance).getLabelVector().getBestLabel().toString());
                csv.append("\n");
            }
            csv.flush();
        } catch (IOException e) {
            throw new IllegalStateException("Could not write predictions to " + FILE_TEST_OUT, e);
        }
    }

    /**
     * Returns the cached classifier from {@link #FILE_CLASSIFIER}, or trains a Naive Bayes
     * classifier on {@link #trainInstances} and caches it (caching is best effort).
     */
    private Classifier loadOrTrainClassifier() {
        File cache = new File(FILE_CLASSIFIER);
        if (cache.exists()) {
            // NOTE(review): Java-native deserialization; only safe while the cache file is a
            // trusted artifact produced by this program.
            try (FileInputStream raw = new FileInputStream(cache);
                 ObjectInputStream in = new ObjectInputStream(raw)) {
                return (Classifier) in.readObject();
            } catch (IOException e) {
                throw new IllegalStateException("Could not read cached classifier " + FILE_CLASSIFIER, e);
            } catch (ClassNotFoundException e) {
                throw new IllegalStateException("Cached classifier " + FILE_CLASSIFIER
                        + " refers to a class that is not on the classpath", e);
            }
        }

        requireTrainingData();
        System.out.println("Training with NaiveBayesTrainer...");
        Classifier classifier = new NaiveBayesTrainer().train(trainInstances);

        try (FileOutputStream raw = new FileOutputStream(cache);
             ObjectOutputStream out = new ObjectOutputStream(raw)) {
            out.writeObject(classifier);
        } catch (IOException e) {
            // The freshly trained classifier is still usable, so a failed cache write is only
            // reported. Remove the partial file so a later run does not try to load it.
            System.err.println("Could not cache classifier in " + FILE_CLASSIFIER);
            e.printStackTrace();
            cache.delete();
        }
        return classifier;
    }

    /** Fails fast with a descriptive message when no training data has been loaded. */
    private void requireTrainingData() {
        if (trainInstances == null) {
            throw new IllegalStateException(
                    "No training data loaded; call importFile() or readFile() first");
        }
    }

    /** Loads the vectorized training data from {@link #FILE_TRAIN_X_BIN} into {@link #trainInstances}. */
    protected void readFile() {
        // InstanceList.load is static; call it on the class, not through the (null) field.
        trainInstances = InstanceList.load(new File(FILE_TRAIN_X_BIN));
    }

    /**
     * Entry point: loads the cached training vectors (or builds and caches them from the raw
     * file), then runs the SVM evaluation.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        System.out.println("Sentiment analysis\n\n\n");
        System.out.println("Reading data...");
        Sentiment snt = new Sentiment();
        if (new File(FILE_TRAIN_X_BIN).exists()) {
            snt.readFile();
        } else {
            System.out.println("\tVectoring labeled sentences...");
            snt.importFile();
            snt.saveFile();
            System.out.println("\tDone!");
        }
        snt.train();
        System.out.println("Done!");
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment