Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
Markov Chain implementation
import gnu.trove.map.TCharDoubleMap;
import gnu.trove.map.hash.TCharDoubleHashMap;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Pattern;
/**
 * Character-level Markov chain: for every n-gram of {@code nGramSize} characters it records
 * which character follows it (or an end-of-token marker), then scores new input by how likely
 * its transitions are under the learned model.
 *
 * <p>Usage: call {@link #update} any number of times to accumulate counts, then {@link #finish()}
 * once to convert counts into log-probabilities, then {@link #getProbability} to score input.
 * Scoring before {@code finish()} operates on raw counts and is not meaningful.
 */
public class MarkovChain implements Serializable {
    private static final long serialVersionUID = 986958034001823764L;

    /** Recorded as the "next unit" when an n-gram ends its token. */
    private static final char END_OF_TOKEN = '\0';

    private final Pattern splitPattern;
    private final int nGramSize;
    // Before finish(): n-gram -> (next char -> count). After finish(): next char -> log-probability.
    private final Map<String, TCharDoubleMap> chain;
    // NOTE(review): stored via setters but never read by this class.
    private char[] acceptedChars;
    // NOTE(review): stored via setters but never read by this class.
    private char[] filteredChars;
    // Log-probability used for unseen n-gram -> unit pairs; 0.0 until finish() runs.
    private double unknownProbability;

    /**
     * Creates a bigram chain (n-gram size 2) that treats each input string as a single token.
     *
     * <p>An empty split regex is deliberately not used here: it would split input into single
     * characters, which are shorter than the n-gram size, so nothing would ever be learned.
     */
    public MarkovChain() {
        this(null, 2);
    }

    /**
     * Creates a chain that splits input into tokens with {@code splitRegex}.
     *
     * @param splitRegex regex used to split input strings into tokens, or {@code null} to treat
     *     every input string as one token
     * @param nGramSize number of context characters per transition; must be non-negative
     * @throws IllegalArgumentException if {@code nGramSize} is negative
     * @throws java.util.regex.PatternSyntaxException if {@code splitRegex} is not a valid regex
     */
    public MarkovChain(String splitRegex, int nGramSize) {
        if (nGramSize < 0) {
            throw new IllegalArgumentException("nGramSize must be non-negative: " + nGramSize);
        }
        splitPattern = splitRegex == null ? null : Pattern.compile(splitRegex);
        this.nGramSize = nGramSize;
        chain = new HashMap<>();
        acceptedChars = new char[0];
        filteredChars = new char[0];
    }

    /**
     * Creates a chain that treats every input string as a single token.
     *
     * @param nGramSize number of context characters per transition; must be non-negative
     * @throws IllegalArgumentException if {@code nGramSize} is negative
     */
    public MarkovChain(int nGramSize) {
        this(null, nGramSize);
    }

    /**
     * Stores the accepted-character set.
     *
     * <p>NOTE(review): the array is stored as-is, but no method of this class reads it, so it
     * currently has no effect on training or scoring.
     */
    public void setAcceptedChars(char[] acceptedChars) {
        this.acceptedChars = acceptedChars;
    }

    /**
     * Stores the filtered-character set.
     *
     * <p>NOTE(review): the array is stored as-is, but no method of this class reads it, so it
     * currently has no effect on training or scoring.
     */
    public void setFilteredChars(char[] filteredChars) {
        this.filteredChars = filteredChars;
    }

    /**
     * Adds the transitions of {@code input} to the model, splitting it into tokens first when a
     * split regex was configured.
     *
     * @throws IllegalStateException if {@link #finish()} has already been called
     */
    public void update(String input) {
        ensureNotFinished();
        if (splitPattern == null) {
            updateToken(input);
        } else {
            update(splitPattern.split(input));
        }
    }

    /**
     * Adds the transitions of every token to the model.
     *
     * @throws IllegalStateException if {@link #finish()} has already been called
     */
    public void update(String[] tokens) {
        ensureNotFinished();
        for (String token : tokens) {
            updateToken(token);
        }
    }

    /** Increments the count of every (n-gram, next unit) pair found in {@code token}. */
    private void updateToken(String token) {
        for (int i = 0; i <= token.length() - nGramSize; i++) {
            String nGram = token.substring(i, i + nGramSize);
            TCharDoubleMap unitToCount = chain.get(nGram);
            if (unitToCount == null) {
                unitToCount = new TCharDoubleHashMap();
                chain.put(nGram, unitToCount);
            }
            unitToCount.adjustOrPutValue(nextUnitAfter(token, i), 1, 1);
        }
    }

    /**
     * Returns the character that follows the n-gram starting at {@code start}, or
     * {@link #END_OF_TOKEN} when that n-gram is the last one in the token.
     */
    private char nextUnitAfter(String token, int start) {
        int next = start + nGramSize;
        return next < token.length() ? token.charAt(next) : END_OF_TOKEN;
    }

    /**
     * Converts accumulated counts into per-n-gram log-probabilities and fixes the
     * log-probability used for unseen transitions.
     *
     * <p>The unseen value is the smallest learned log-probability, capped at {@code log(0.5)}.
     * Calling this more than once is a no-op; without that guard a second call would take the
     * logarithm of log-probabilities and corrupt the model.
     */
    public void finish() {
        if (isFinished()) {
            return;
        }
        double minLogProbability = Double.MAX_VALUE;
        for (TCharDoubleMap unitToCount : chain.values()) {
            double countSum = 0;
            for (double count : unitToCount.values()) {
                countSum += count;
            }
            for (char unit : unitToCount.keys()) {
                double logProbability = Math.log(unitToCount.get(unit) / countSum);
                minLogProbability = Math.min(minLogProbability, logProbability);
                unitToCount.put(unit, logProbability);
            }
        }
        // Use this probability when we see a new ngram -> unit pair.
        // Without this, the probability might be 0 for an input.
        unknownProbability = Math.min(minLogProbability, Math.log(0.5));
    }

    /**
     * Reports whether {@link #finish()} has run. {@code finish()} always leaves
     * {@code unknownProbability <= log(0.5) < 0}, whereas it is 0.0 before; deriving the state
     * this way keeps the serialized form unchanged and works for deserialized models.
     */
    private boolean isFinished() {
        return unknownProbability < 0;
    }

    /** Rejects updates once counts have been converted to log-probabilities. */
    private void ensureNotFinished() {
        if (isFinished()) {
            throw new IllegalStateException(
                    "update() called after finish(); counts were already converted to log-probabilities");
        }
    }

    /**
     * Scores {@code input} against the model. Call {@link #finish()} first.
     *
     * <p>Without a split regex the whole input is scored as one token. Otherwise the result is
     * the mean of the per-token scores, as in {@link #getProbability(String[])}.
     *
     * @param input text to score
     * @return probability-like score in (0, 1], or NaN if splitting produced no tokens
     */
    public double getProbability(String input) {
        if (splitPattern == null) {
            return getProbabilityToken(input);
        }
        return getProbability(splitPattern.split(input));
    }

    /**
     * Returns the arithmetic mean of the per-token scores. Call {@link #finish()} first.
     *
     * <p>Each token's score is the geometric mean of its transition probabilities.
     *
     * @param tokens tokens to score
     * @return mean score, or NaN if {@code tokens} is empty
     */
    public double getProbability(String[] tokens) {
        double probability = 0;
        for (String token : tokens) {
            probability += getProbabilityToken(token);
        }
        return probability / tokens.length;
    }

    /**
     * Returns {@code exp(mean log-probability)} over the token's transitions, i.e. their geometric
     * mean. Transitions that were never seen contribute {@link #getUnknownProbability()}.
     */
    private double getProbabilityToken(String token) {
        double logProbabilitySum = 0;
        int transitionCount = 0;
        for (int i = 0; i <= token.length() - nGramSize; i++) {
            String nGram = token.substring(i, i + nGramSize);
            char nextUnit = nextUnitAfter(token, i);
            double logProbability = unknownProbability;
            TCharDoubleMap unitToLogProbability = chain.get(nGram);
            if (unitToLogProbability != null && unitToLogProbability.containsKey(nextUnit)) {
                logProbability = unitToLogProbability.get(nextUnit);
            }
            logProbabilitySum += logProbability;
            transitionCount++;
        }
        // Tokens shorter than nGramSize have no transitions; avoid 0/0 and yield exp(0) = 1.
        return Math.exp(logProbabilitySum / Math.max(transitionCount, 1));
    }

    /**
     * Renders the model as one line per n-gram: {@code "ngram":{'c'=value, ...}}, where values
     * are counts before {@link #finish()} and log-probabilities after.
     */
    public String asString() {
        StringBuilder sb = new StringBuilder();
        for (Map.Entry<String, TCharDoubleMap> entry : chain.entrySet()) {
            sb.append('"').append(entry.getKey()).append("\":{");
            TCharDoubleMap unitToValue = entry.getValue();
            for (char unit : unitToValue.keys()) {
                double value = unitToValue.get(unit);
                sb.append('\'').append(unit).append("'=").append(value).append(", ");
            }
            if (unitToValue.size() > 0) {
                // Drop the trailing ", ".
                sb.setLength(sb.length() - 2);
            }
            sb.append("}\n");
        }
        if (sb.length() > 1) {
            // Drop the trailing newline.
            sb.setLength(sb.length() - 1);
        }
        return sb.toString();
    }

    /**
     * Returns the log-probability assigned to unseen transitions: 0.0 before {@link #finish()},
     * at most {@code log(0.5)} afterwards.
     */
    public double getUnknownProbability() {
        return unknownProbability;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment