Skip to content

Instantly share code, notes, and snippets.

@mpitid
Created February 23, 2012 20:12
Show Gist options
  • Save mpitid/1894835 to your computer and use it in GitHub Desktop.
Save mpitid/1894835 to your computer and use it in GitHub Desktop.
Implementations of a simple feature-frequency training program in Python, Java, and Scala.
// vim: set ts=4 sw=4 et:
import java.io.FileReader;
import java.io.BufferedReader;
import java.io.IOException;
import java.util.Iterator;
import java.util.Scanner;
import java.util.HashMap;
import java.util.Map;
public class JTrain {
    /**
     * Approximate feature frequencies. Each feature maps to a two-element
     * array {count for lowercase-initial tokens, count for uppercase-initial
     * tokens}; both entries are maintained with {@link #roughCount(int)}.
     */
    static HashMap<String, int[]> counter = new HashMap<String, int[]>();

    /**
     * Reads whitespace-separated tokens from the file named by
     * {@code args[0]}, counts the features of every token and prints one
     * line per feature: {@code <feature> <lower> <upper>}.
     *
     * @param args command-line arguments; {@code args[0]} is the input file
     * @throws IOException if the file cannot be opened or an error occurs
     *         while reading it
     */
    public static void main(String[] args) throws IOException {
        if (args.length < 1) {
            System.err.println("usage: JTrain <filename>");
            System.exit(1);
            return;
        }
        Scanner train = null;
        try {
            train = new Scanner(
                new BufferedReader(
                    new FileReader(args[0])));
            for (String token: Iter.string(train)) {
                if (!token.isEmpty()) {
                    boolean upper = Character.isUpperCase(token.charAt(0));
                    for (String feature: features(token.toLowerCase()))
                        count(feature, upper);
                }
            }
            // Scanner swallows IOExceptions from the underlying reader and
            // just reports end of input; surface them rather than printing
            // counts for a silently truncated file.
            IOException readError = train.ioException();
            if (readError != null)
                throw readError;
        } finally {
            if (train != null)
                train.close();
        }
        for (Map.Entry<String, int[]> entry: counter.entrySet())
            System.out.printf("%s %d %d\n",
                entry.getKey(),
                entry.getValue()[0],
                entry.getValue()[1]);
    }

    /**
     * Records one occurrence of {@code key}.
     *
     * @param key   the feature string
     * @param upper true if the originating token started with an uppercase
     *              character (selects slot 1), false for slot 0
     */
    public static void count(String key, boolean upper) {
        int index = upper ? 1 : 0;
        int[] slots = counter.get(key);
        if (slots == null) {
            slots = new int[] {0, 0};
            counter.put(key, slots);
        }
        slots[index] = roughCount(slots[index]);
    }

    /**
     * Returns the features of a word: the word itself, its 2- and
     * 3-character prefixes, and its 2- and 3-character suffixes. Words
     * shorter than the slice length contribute the whole word.
     */
    public static String[] features(String word) {
        String[] fs = {
            word,
            substring(word, 0, 2),
            substring(word, 0, 3),
            substring(word, -2, -1),
            substring(word, -3, -1),
        };
        return fs;
    }

    /**
     * Substring indexing similar to Python slices: negative indices count
     * from the end and out-of-range indices are clamped. Unlike Python, a
     * negative {@code j} of -1 denotes the end of the string (one past the
     * last character), so {@code substring(s, -2, -1)} is the last two
     * characters.
     */
    public static String substring(String s, int i, int j) {
        int n = s.length();
        i = i < 0 ? n + i : i;
        j = j < 0 ? n + 1 + j : j;
        return s.substring(Math.max(i, 0), Math.max(0, Math.min(j, n)));
    }

    /**
     * One step of a probabilistic (Morris) counter: returns
     * {@code value + 1} with probability 2^-value, otherwise {@code value}.
     * Starting from 0, 2^value - 1 is an unbiased estimate of the number of
     * calls, so the stored value grows roughly like log2 of the true count.
     */
    public static int roughCount(int value) {
        // Math.scalb(1.0, -value) is exactly 2^-value for any int. The former
        // 1.0 / (1 << value) overflowed for value >= 31 (Java masks the shift
        // count, so 1 << 32 == 1 and the counter would increment every time).
        return Math.random() < Math.scalb(1.0, -value) ? value + 1 : value;
    }

    /**
     * Iterable wrapper around an Iterator object so it can be used in an
     * enhanced for loop. The wrapped iterator is returned as-is, so the
     * Iterable can only be traversed once.
     */
    static class Iter<T> implements Iterable<T> {
        private Iterator<T> t = null;

        Iter (Iterator<T> t) {
            assert(t != null);
            this.t = t;
        }

        public Iterator<T> iterator() { return t; }

        public static Iterable<String> string(Iterator<String> t) { return new Iter<String>(t); }
    }
}
object STrain {
  /** Random source used by [[roughCount]]. */
  val rnd = new util.Random()

  /**
   * Approximate feature frequencies. Each feature maps to
   * `Array(lowercaseInitialCount, uppercaseInitialCount)`, both maintained
   * with [[roughCount]].
   */
  val map = new collection.mutable.HashMap[String, Array[Int]]

  /**
   * Reads whitespace-separated tokens from the file named by `args(0)`,
   * counts the features of every token and prints one line per feature:
   * `<feature> <lower> <upper>`. Prints a usage message when no file is given.
   */
  def main(args: Array[String]): Unit = {
    if (args.length > 0) {
      val source = io.Source.fromFile(args(0))
      try {
        // Build the map of feature frequencies. Splitting on any whitespace
        // matches the Java (Scanner) and Python (str.split) versions.
        for {
          line <- source.getLines()
          token <- line.split("\\s+") if !token.isEmpty
        } {
          val upper = token(0).isUpper
          for (feature <- features(token.toLowerCase))
            count(feature, upper)
        }
      } finally {
        // Release the file handle even if reading fails.
        source.close()
      }
      // Print the resulting map.
      for ((key, value) <- map)
        println("%s %d %d".format(key, value(0), value(1)))
    } else
      Console.err.println("usage: STrain <filename>")
  }

  /**
   * Records one occurrence of `key`; `upper` selects slot 1 (token started
   * with an uppercase character) or slot 0.
   */
  def count(key: String, upper: Boolean): Unit = {
    val idx = if (upper) 1 else 0
    val counts = map.getOrElseUpdate(key, Array(0, 0))
    counts(idx) = roughCount(counts(idx))
  }

  /**
   * One step of a probabilistic (Morris) counter: returns `value + 1` with
   * probability 2^-value, otherwise `value`.
   */
  def roughCount(value: Int): Int = {
    // scalb(1.0, -value) is exactly 2^-value; 1.0 / (1 << value) overflowed
    // for value >= 31 because Int shifts mask the shift count.
    if (rnd.nextDouble() < java.lang.Math.scalb(1.0, -value)) value + 1 else value
  }

  /** The word itself plus its 2- and 3-character prefixes and suffixes. */
  def features(word: String): Array[String] = Array(
    word,
    slice(word, 0, 2),
    slice(word, 0, 3),
    slice(word, -2, -1),
    slice(word, -3, -1)
  )

  /**
   * Python-like slicing with clamping. Negative indices count from the end;
   * for `upto`, -1 denotes the end of the string (one past the last char).
   */
  def slice(s: String, from: Int, upto: Int): String = {
    val n = s.length
    val i = from + (if (from < 0) n else 0)
    val j = upto + (if (upto < 0) n + 1 else 0)
    s.slice(0 max i, 0 max (n min j))
  }
}
#!/usr/bin/env python
import sys, random
from collections import defaultdict
def main(args):
    """Count approximate feature frequencies of the tokens in a file.

    Reads the file named by ``args[1]``, splits it into whitespace-separated
    tokens, and for every feature of the lower-cased token updates either the
    "lowercase initial" or the "uppercase initial" counter. Prints one line
    per feature: ``<feature> <lower> <upper>``.

    Args:
        args: argv-style list; ``args[0]`` is the program name and
            ``args[1]`` the input file.

    Returns:
        Process exit status: 1 on a usage error, 0 otherwise.
    """
    if len(args) < 2:
        sys.stderr.write("usage: %s <filename>\n" % args[0])
        return 1
    # feature -> [lowercase-initial count, uppercase-initial count]
    counts = defaultdict(lambda: [0, 0])
    with open(args[1]) as f:
        for line in f:
            # str.split() never yields empty strings, so token[0] is safe.
            for token in line.split():
                index = 1 if token[0].isupper() else 0
                for feature in features(token.lower()):
                    counts[feature][index] = rough_count(counts[feature][index])
    # print() with one argument behaves identically on Python 2 and 3.
    for k, (l, u) in counts.items():
        print("%s %d %d" % (k, l, u))
    return 0
def features(word):
    """Return the features of ``word`` as a 5-tuple.

    Order: the word itself, its 3-char suffix, 3-char prefix, 2-char suffix
    and 2-char prefix. Words shorter than a slice contribute the whole word.
    """
    head2, head3 = word[:2], word[:3]
    tail2, tail3 = word[-2:], word[-3:]
    return (word, tail3, head3, tail2, head2)
def rough_count(val):
    """Advance a probabilistic (Morris-style) counter by one step.

    Returns ``val + 1`` with probability ``1 / 2**val`` and ``val`` otherwise,
    so the stored value grows roughly logarithmically with the true count.
    """
    increment_chance = 1.0 / (1 << val)
    if random.random() >= increment_chance:
        return val
    return val + 1
if __name__ == '__main__':
    # Use main()'s return value as the process exit status
    # (equivalent to sys.exit, which raises SystemExit).
    raise SystemExit(main(sys.argv))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment