Skip to content

Instantly share code, notes, and snippets.

@thomasjungblut
Created April 5, 2013 11:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save thomasjungblut/5318761 to your computer and use it in GitHub Desktop.
Save thomasjungblut/5318761 to your computer and use it in GitHub Desktop.
Simple part-of-speech (POS) tagger using a hidden Markov model (HMM), reaching ~91.82% accuracy with a small training set of 70k words and 10k test words.
package de.jungblut.ml;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.Lists;

import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.sparse.SparseDoubleVector;
import de.jungblut.math.tuple.Tuple;
import de.jungblut.ner.SequenceFeatureExtractor;
import de.jungblut.ner.SparseFeatureExtractorHelper;
import de.jungblut.nlp.HMM;
/**
 * Simple part-of-speech (POS) tagger: trains an {@link HMM} in supervised mode on a
 * tab-separated training corpus and prints its tagging accuracy on a development
 * corpus.
 *
 * <p>Corpus files contain one token per line as {@code word<TAB>tag}; blank lines are
 * recorded as a synthetic {@code SENTENCE_BEGIN} word/tag pair. The tag-to-index
 * mapping lives in unsynchronized mutable static state, so this class is meant to be
 * run once as a program and is not thread-safe.
 */
public class POSTagger {

  /** Pseudo word and tag recorded for every blank (sentence-boundary) line. */
  private static final String SENTENCE_BEGIN = "SENTENCE_BEGIN";
  private static final String DEFAULT_TRAINING_FILE = "files/pos/training.pos";
  private static final String DEFAULT_DEVELOPMENT_FILE = "files/pos/development.pos";

  /** Maps every tag seen so far to its index (the index the predictions are compared to). */
  static BiMap<String, Integer> indexTag = HashBiMap.create();
  /** Next free tag index. */
  static int id = 0;

  static {
    // SENTENCE_BEGIN is registered first, so it always receives index 0.
    indexTag.put(SENTENCE_BEGIN, id++);
  }

  /** When true, every wrongly tagged development word is printed to stdout. */
  static boolean printMisprediction = false;

  /**
   * Trains the tagger on the training corpus and prints its accuracy on the
   * development corpus.
   *
   * @param args optional overrides: {@code args[0]} is the training file and
   *     {@code args[1]} the development file; they default to
   *     {@code files/pos/training.pos} and {@code files/pos/development.pos}.
   * @throws IOException if a corpus file cannot be read or contains a non-blank line
   *     without a tag column.
   */
  public static void main(String[] args) throws IOException {
    String inputFile = args.length > 0 ? args[0] : DEFAULT_TRAINING_FILE;
    String inputDevelopmentFile = args.length > 1 ? args[1] : DEFAULT_DEVELOPMENT_FILE;

    List<String> words = new ArrayList<>();
    List<Integer> labels = new ArrayList<>();
    read(inputFile, words, labels);

    // Both corpora are read before the HMM is built, so indexTag.size() below also
    // counts tags that only occur in the development corpus.
    List<String> developmentWords = new ArrayList<>();
    List<Integer> developmentLabels = new ArrayList<>();
    read(inputDevelopmentFile, developmentWords, developmentLabels);

    SparseFeatureExtractorHelper<String> extractorHelper = new SparseFeatureExtractorHelper<>(
        words, labels, new POSExtractor());
    Tuple<DoubleVector[], DenseDoubleVector[]> vectorize = extractorHelper.vectorize();
    String[] dictionary = extractorHelper.getDictionary();
    DoubleVector[] features = vectorize.getFirst();
    DenseDoubleVector[] state = vectorize.getSecond();

    HMM hmm = new HMM(dictionary.length, indexTag.size());
    hmm.trainSupervised(features, state);

    Tuple<DoubleVector[], DenseDoubleVector[]> vectorizeAdditionals = extractorHelper
        .vectorizeAdditionals(developmentWords, developmentLabels);
    DoubleVector[] testFeatures = vectorizeAdditionals.getFirst();
    DenseDoubleVector[] testLabels = vectorizeAdditionals.getSecond();

    // Guard against 0/0 (NaN) when the development corpus yields no tokens.
    if (testFeatures.length == 0) {
      System.out.println("No development tokens found in " + inputDevelopmentFile
          + "; accuracy is undefined.");
      return;
    }

    int correct = evaluate(hmm, testFeatures, testLabels, developmentWords);
    System.out.println(correct + "/" + testFeatures.length + "= "
        + (correct / (double) testFeatures.length * 100d) + "% Accuracy.");
  }

  /**
   * Tags the test sequence with {@code hmm} and counts the tokens whose predicted
   * index equals the expected one. Each prediction is fed back as the
   * previous-prediction argument for the next token; the first token receives a
   * one-hot vector on the {@code SENTENCE_BEGIN} index.
   *
   * @param hmm the trained model.
   * @param testFeatures feature vectors of the test tokens, in sequence order.
   * @param testLabels expected label vectors, parallel to {@code testFeatures}.
   * @param testWords the raw test words, used only for misprediction output.
   * @return the number of correctly tagged tokens.
   */
  private static int evaluate(HMM hmm, DoubleVector[] testFeatures,
      DenseDoubleVector[] testLabels, List<String> testWords) {
    int correct = 0;
    DoubleVector lastPrediction = new SparseDoubleVector(indexTag.size());
    lastPrediction.set(indexTag.get(SENTENCE_BEGIN), 1d);
    for (int i = 0; i < testFeatures.length; i++) {
      DoubleVector predicted = hmm.predict(testFeatures[i], lastPrediction);
      int predictedHiddenState = predicted.maxIndex();
      int expectedHiddenState = testLabels[i].maxIndex();
      if (predictedHiddenState == expectedHiddenState) {
        correct++;
      } else if (printMisprediction) {
        System.out.println("\"" + testWords.get(i) + "\" -> Predicted: \""
            + indexTag.inverse().get(predictedHiddenState)
            + "\" But should be: "
            + indexTag.inverse().get(expectedHiddenState));
      }
      lastPrediction = predicted;
    }
    return correct;
  }

  /**
   * Reads a {@code word<TAB>tag} corpus. Each word is appended to {@code words}, and
   * its tag index is appended to {@code labels}. Unseen tags are registered in
   * {@link #indexTag} with the next free index. Blank or whitespace-only lines are
   * recorded as a {@code SENTENCE_BEGIN} word/label pair. The file is decoded as
   * UTF-8.
   *
   * @param inputFile path of the corpus file.
   * @param words output list receiving the words.
   * @param labels output list receiving the tag indices, parallel to {@code words}.
   * @throws FileNotFoundException if {@code inputFile} cannot be opened.
   * @throws IOException if reading fails or a non-blank line has no tag column.
   */
  private static void read(String inputFile, List<String> words,
      List<Integer> labels) throws IOException, FileNotFoundException {
    try (BufferedReader br = new BufferedReader(new InputStreamReader(
        new FileInputStream(inputFile), StandardCharsets.UTF_8))) {
      String line;
      int lineNumber = 0;
      while ((line = br.readLine()) != null) {
        lineNumber++;
        // Must be checked before splitting: "".split("\t") yields [""] (length 1),
        // so a length-0 check would miss blank lines and split[1] would throw.
        if (line.trim().isEmpty()) {
          words.add(SENTENCE_BEGIN);
          labels.add(indexTag.get(SENTENCE_BEGIN));
          continue;
        }
        String[] split = line.split("\t");
        if (split.length < 2) {
          throw new IOException(inputFile + ":" + lineNumber
              + ": expected \"word<TAB>tag\" but got \"" + line + "\"");
        }
        words.add(split[0]);
        String tag = split[1];
        if (!indexTag.containsKey(tag)) {
          indexTag.put(tag, id++);
        }
        labels.add(indexTag.get(tag));
      }
    }
  }

  /**
   * Produces string features for a token: its identity, the previous label, the
   * neighbouring words, character-class flags, and its prefixes and suffixes.
   */
  static class POSExtractor implements SequenceFeatureExtractor<String> {

    private static final int PREFIX_LENGTH = 4;
    private static final int SUFFIX_LENGTH = 4;
    private static final Pattern PUNCT = Pattern.compile("[!#%*+;,/<=>?@^_`{|}~]");

    /**
     * Computes the features of {@code words.get(position)}.
     *
     * @param words the token sequence.
     * @param prevLabel label of the preceding token, as supplied by the caller.
     * @param position index of the token to describe.
     * @return the feature names, in a fixed order of addition.
     */
    @Override
    public List<String> computeFeatures(List<String> words, int prevLabel,
        int position) {
      List<String> features = Lists.newArrayList();
      String word = words.get(position);
      features.add("current=" + word);
      features.add("prevlabel=" + prevLabel);
      if (position > 0) {
        features.add("prev=" + words.get(position - 1));
      }
      if (position < words.size() - 1) {
        features.add("next=" + words.get(position + 1));
      }
      if (word.indexOf('-') != -1) {
        features.add("hyphen");
      }
      // "..." is flagged on its own; any other word is checked against PUNCT.
      if (word.equals("...")) {
        features.add("threedots");
      } else if (PUNCT.matcher(word).find()) {
        features.add("punct");
      }
      if (word.indexOf('&') != -1) {
        features.add("amp");
      }
      if (word.indexOf('$') != -1) {
        features.add("curr");
      }
      if (word.indexOf('(') != -1) {
        features.add("leftbrace");
      }
      if (word.indexOf(')') != -1) {
        features.add("rightbrace");
      }
      if (word.indexOf('\'') != -1) {
        features.add("singlequote");
      }
      if (word.indexOf('"') != -1) {
        features.add("doublequote");
      }
      for (String prefix : getPrefixes(word)) {
        features.add("pre=" + prefix);
      }
      for (String suffix : getSuffixes(word)) {
        features.add("suf=" + suffix);
      }
      return features;
    }

    /**
     * Returns the suffixes of {@code lex} with lengths 1 to {@link #SUFFIX_LENGTH}.
     * When the word is shorter than a requested length, that entry is the whole word.
     */
    protected static String[] getSuffixes(String lex) {
      String[] suffs = new String[SUFFIX_LENGTH];
      for (int li = 0; li < SUFFIX_LENGTH; li++) {
        suffs[li] = lex.substring(Math.max(lex.length() - li - 1, 0));
      }
      return suffs;
    }

    /**
     * Returns the prefixes of {@code lex} with lengths 1 to {@link #PREFIX_LENGTH}.
     * When the word is shorter than a requested length, that entry is the whole word.
     */
    protected static String[] getPrefixes(String lex) {
      String[] prefs = new String[PREFIX_LENGTH];
      for (int li = 0; li < PREFIX_LENGTH; li++) {
        prefs[li] = lex.substring(0, Math.min(li + 1, lex.length()));
      }
      return prefs;
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment