@berlinbrown
Last active December 15, 2015 14:38
Use a Markov chain model to generate text. Train the system and, based on the data, return generated random text.
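A minimal usage sketch (assuming the training documents under docs/testmark/ exist), mirroring the main method at the bottom of the file:

final MarkovChainTextGenerator generator = new MarkovChainTextGenerator();
generator.train();                            // load documents, build digram statistics
System.out.println(generator.generateText()); // output varies per run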
/* Copyright (c) 2013 Berlin Brown (berlin2research.com)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.berlin.crawl.util.text;
import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NavigableMap;
import java.util.Random;
import java.util.TreeMap;
/**
* Use a Markov chain model to generate text. Train the system and, based on the data,
* return generated random text.
*
* <pre>
* Example Input:
* This is going to be a good day for all the powerful people.
*
* this is going to be a bad
*
* </pre>
*
* @author bbrown (berlin.brown at gmail.com)
*/
public class MarkovChainTextGenerator {
// With a more expressive language, or with more time and a better approach,
// we could key on word pairs (the current sequence) to predict the next word.
// E.g. in Python: ('brown', 'fox'): ['jumps', 'who', 'who']
/** List of string data in each document. */
private List<String> docsData = new ArrayList<String>();
/** Map data structure, position in doc to word map. */
private Map<Integer, Map<Integer, String>> posByWordsData = new TreeMap<Integer, Map<Integer, String>>();
/** Map data structure, position in document with word probability weight. */
private Map<Integer, Map<String, Double>> posWordsFreqData = new TreeMap<Integer, Map<String, Double>>();
/** Map data structure, word and list of next words. */
private Map<String, List<String>> nextWords = new TreeMap<String, List<String>>();
/** Map data structure with a WORD and then the frequency/probability of each possible next word. */
private Map<String, Map<String, Double>> nextWordsFreqPerc = new TreeMap<String, Map<String, Double>>();
/** Use navigable map for prob to word. Access using ascending or descending order. Use 'map.ceilingEntry(rval).getValue' */
private final Map<String, NavigableMap<Double, String>> reverseFreqToWord = new TreeMap<String, NavigableMap<Double, String>>();
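// Hypothetical example: reverseFreqToWord.get("a") = {0.5="good", 1.0="bad"};
// a ceilingEntry(0.3) lookup then returns "good" (actual weights depend on the training data).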
private Map<String, Map<String, Double>> reverseFreqToTotal = new TreeMap<String, Map<String, Double>>();
private int txtGenStartPos = 0;
private int txtGenMaxWords = 6;
private Random random = new Random(System.currentTimeMillis());
/**
* Use java.io routines to load the document based on the filename.
*
* @param filename
* @return
* @throws Exception
*/
public static String loadDocument(final String filename) throws Exception {
final File f = new File(filename);
final FileInputStream fos = new FileInputStream(f);
final BufferedInputStream s = new BufferedInputStream(fos);
final BufferedReader reader = new BufferedReader(new InputStreamReader(s));
final StringBuffer buf = new StringBuffer();
{
String data = "";
String line = "";
do {
data = reader.readLine();
if (data != null) {
line = data.trim();
if (line.length() > 0) {
buf.append(line).append(' '); // separate lines with a space so words are not fused across line breaks
}
} // End of the if //
} while(data != null);
reader.close();
}
String ss = buf.toString().trim().toLowerCase();
ss = ss.replaceAll("\\.", "");
return ss;
} // End of the method //
/**
* Tokenize the document and load it into the position data structure.
*
* Exit method with position of document and words at that position (<code>posByWordsData</code>).
*
* @param doc
* @param docId
*/
public void loadAndTokenizeDocument(final String doc, final int docId) {
// Build a map data structure, position in doc to word-doc map.
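// Hypothetical example: for docId 0 with text "this is going",
// posByWordsData becomes {0:{0:"this"}, 1:{0:"is"}, 2:{0:"going"}}.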
final String [] words = doc.split("\\s+");
if (words.length > 0) {
for (int posi = 0; posi < words.length; posi++) {
if (posByWordsData.get(posi) == null) {
// First time we see this position: create the docId-to-word map.
final Map<Integer, String> docWordMap = new TreeMap<Integer, String>();
posByWordsData.put(posi, docWordMap);
docWordMap.put(docId, words[posi]);
} else {
posByWordsData.get(posi).put(docId, words[posi]);
}
} // End of the for //
}
// Exit method with position of document and words at that position.
} // End of the method //
/**
* Load position/probability data structure.
*
* Exit method with the 'posWordsFreqData' loaded, 'position' in document
* with count/freq at that position.
*/
public void findWordsFreqPos() {
for (final Integer posi : this.posByWordsData.keySet()) {
// Use tmp map of count for each word, used to convert to probability
final Map<String, Integer> mapTmpCtPerPos = new HashMap<String, Integer>();
final Map<Integer, String> docWordMap = this.posByWordsData.get(posi);
for (final Integer docId : docWordMap.keySet()) {
final String word = docWordMap.get(docId);
if (mapTmpCtPerPos.get(word) == null) {
mapTmpCtPerPos.put(word, 1);
} else {
mapTmpCtPerPos.put(word, mapTmpCtPerPos.get(word)+1);
}
} // End of for through the docs
// We should have a count map for words at this position:
// Convert to floats
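// Hypothetical example: if 3 of 4 documents have "good" at this position,
// the weight stored for "good" is 3 / 4 = 0.75.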
final double n = docWordMap.size();
for (final Integer docId : docWordMap.keySet()) {
final String word = docWordMap.get(docId);
final double perc = mapTmpCtPerPos.get(word) / n;
if (posWordsFreqData.get(posi) == null) {
final Map<String, Double> wordByFreq = new HashMap<String, Double>();
posWordsFreqData.put(posi, wordByFreq);
wordByFreq.put(word, perc);
} else {
posWordsFreqData.get(posi).put(word, perc);
}
} // End of the for //
}
// Exit method with the 'posWordsFreqData' loaded, 'position' in document with count/freq at that position.
}
/**
* Load the digram word list, the current word against a list of possible next words.
*
* Exit method with 'nextWords' data structure loaded, the current word against
* a list of possible next words.
*/
public void buildDigramWordList() {
// Loop through each position.
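// Hypothetical example: with inputs sharing the prefix "this is going",
// nextWords ends up with entries like "this" -> ["is", "is", ...] and "is" -> ["going", ...],
// one next-word entry per document in which the pair was observed.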
for (final Integer posi : this.posByWordsData.keySet()) {
final int nxt = posi + 1;
if (this.posByWordsData.get(nxt) != null && this.posByWordsData.get(posi) != null) {
// Next tokens //
final Map<Integer, String> curPosData = this.posByWordsData.get(posi);
final Map<Integer, String> nextPosData = this.posByWordsData.get(nxt);
if (curPosData != null && nextPosData != null) {
// Build a pair, current word to next word, for each document //
for (final Integer docid : curPosData.keySet()) {
final String curword = curPosData.get(docid);
final String nextword = nextPosData.get(docid);
// Only record a digram when this document has both a current and a next word
if (curword != null && nextword != null) {
final Digram digram = new Digram(curword, nextword);
if (this.nextWords.get(digram.cur) == null) {
final List<String> listOfWordsNext = new ArrayList<String>();
listOfWordsNext.add(digram.next);
this.nextWords.put(digram.cur, listOfWordsNext);
} else {
this.nextWords.get(digram.cur).add(digram.next);
}
}
} // End of loop through doc //
} // End of if - check cur and next pos
}
} // End of the for //
// Exit method with 'nextWords' data structure loaded, the current word against a list of possible next words.
}
/**
* Load the digram word list, the current word against a list of possible next words with
* the probability that the word will appear next.
*
* Exit with the nextWordsFreqPerc data loaded.
*/
public void buildDigramStats() {
// Core logic, exit method with a map of words to freq/probability/weights.
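// Hypothetical example: if nextWords.get("a") is ["good", "good", "bad"],
// then nextWordsFreqPerc.get("a") becomes {"good"=0.667, "bad"=0.333}.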
for (final String keyWordCurrent : this.nextWords.keySet()) {
final List<String> nextWordsTmp = this.nextWords.get(keyWordCurrent);
final Map<String, Integer> mapTmpCtPerNext = new HashMap<String, Integer>();
// Keep tmp map count, used to convert to probability
for (final String keyWordNext : nextWordsTmp) {
if (mapTmpCtPerNext.get(keyWordNext) == null) {
mapTmpCtPerNext.put(keyWordNext, 1);
} else {
// Increment count at each 'current' word.
mapTmpCtPerNext.put(keyWordNext, mapTmpCtPerNext.get(keyWordNext)+1);
}
} // End of for, tmp map with count loaded
final double n = nextWordsTmp.size();
Map<String, Double> mapfreq = null;
if (this.nextWordsFreqPerc.get(keyWordCurrent) == null) {
mapfreq = new HashMap<String, Double>();
this.nextWordsFreqPerc.put(keyWordCurrent, mapfreq);
} else {
mapfreq = this.nextWordsFreqPerc.get(keyWordCurrent);
}
for (final String nxt : mapTmpCtPerNext.keySet()) {
final Integer ct = mapTmpCtPerNext.get(nxt);
if (ct != null) {
final double prob = ct.doubleValue() / n;
mapfreq.put(nxt, prob);
}
}
// Also add nav map perc
this.buildProbNavMap(keyWordCurrent);
} // Loop through each current word
// Exit with the nextWordsFreqPerc data loaded.
}
/**
* Generate text.
*/
public String generateText() {
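// Seed the output with the word at position txtGenStartPos in document 0,
// then walk the chain, sampling txtGenMaxWords successor words.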
final Map<Integer, String> simplfirst = this.posByWordsData.get(txtGenStartPos);
final String firstWord = simplfirst.get(0);
final StringBuffer buf = new StringBuffer();
String lastword = firstWord;
buf.append(lastword).append(" ");
for (int i = 0; i < this.txtGenMaxWords; i++) {
lastword = this.next(lastword);
buf.append(lastword).append(" ");
}
return buf.toString();
}
public void printText() {
System.out.println(generateText());
}
/**
* Load the markov chain data structures.
* Load the input documents, invoke <code>loadAndTokenizeDocument</code>
* then findWordsFreqPos, buildDigramWordList, buildDigramStats.
*
* @throws Exception
*/
public void train() throws Exception {
// These simple input text documents are mostly variants of the string:
// 'This is going to be a good day for all the powerful people'.
final String doc1 = loadDocument("docs/testmark/doc1.txt");
final String doc2 = loadDocument("docs/testmark/doc2.txt");
final String doc3 = loadDocument("docs/testmark/doc3.txt");
final String doc4 = loadDocument("docs/testmark/doc4.txt");
this.docsData.add(doc1);
this.docsData.add(doc2);
this.docsData.add(doc3);
this.docsData.add(doc4);
int id = 0;
for (final String doc : this.docsData) {
this.loadAndTokenizeDocument(doc, id);
id++;
}
this.findWordsFreqPos();
this.buildDigramWordList();
this.buildDigramStats();
} // End of the method //
/**
* Digram/bigram is a key value structure with 'current' word and 'next' word.
*
* @author bbrown
*/
private class Digram {
// Or bigram
// Key value pair for 'next'
private String cur;
private String next;
public Digram(final String k, final String v) {
this.cur = k;
this.next = v;
}
}
private String next(final String cur) {
Map<String, Double> revFreqToTotal = null;
NavigableMap<Double, String> revFreqToWord = null;
if (reverseFreqToWord.get(cur) == null) {
revFreqToWord = new TreeMap<Double, String>();
reverseFreqToWord.put(cur, revFreqToWord);
} else {
revFreqToWord = reverseFreqToWord.get(cur);
}
if (reverseFreqToTotal.get(cur) == null) {
revFreqToTotal = new TreeMap<String, Double>();
reverseFreqToTotal.put(cur, revFreqToTotal);
} else {
revFreqToTotal = reverseFreqToTotal.get(cur);
}
return nextnav(revFreqToWord, revFreqToTotal, cur);
}
private void buildProbNavMap(final String cur) {
final Map<String, Double> mapfreq = this.nextWordsFreqPerc.get(cur);
if (mapfreq != null) {
// Could possibly be null, do not continue
// Loop through current words, find next
Map<String, Double> revFreqToTotal = null;
NavigableMap<Double, String> revFreqToWord = null;
if (reverseFreqToWord.get(cur) == null) {
revFreqToWord = new TreeMap<Double, String>();
reverseFreqToWord.put(cur, revFreqToWord);
} else {
revFreqToWord = reverseFreqToWord.get(cur);
}
if (reverseFreqToTotal.get(cur) == null) {
revFreqToTotal = new TreeMap<String, Double>();
reverseFreqToTotal.put(cur, revFreqToTotal);
} else {
revFreqToTotal = reverseFreqToTotal.get(cur);
}
for (final String nxt : mapfreq.keySet()) {
final Double weight = mapfreq.get(nxt);
if (weight != null) {
addnav(revFreqToWord, revFreqToTotal, weight, cur, nxt);
}
} // End of for, add nav
}
}
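// addnav/nextnav implement weighted random selection: addnav keeps a running total of the weights
// and stores each next word under its cumulative key, e.g. hypothetical weights
// {"good"=0.5, "bad"=0.25, "day"=0.25} become keys 0.5, 0.75, 1.0; nextnav then maps a uniform
// draw in [0, total) to a word via ceilingEntry, so each word is chosen in proportion to its weight.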
private static void addnav(final NavigableMap<Double, String> revFreqToWord, Map<String, Double> revFreqToTotal,
final double weight, final String cur, final String resultar) {
double curtotal = 0;
curtotal = revFreqToTotal.get(cur) == null ? 0 : revFreqToTotal.get(cur);
if (curtotal < 0) {
curtotal = 0;
}
curtotal += weight;
revFreqToTotal.put(cur, curtotal);
revFreqToWord.put(curtotal, resultar);
}
private String nextnav(final NavigableMap<Double, String> revFreqToWord, Map<String, Double> revFreqToTotal, final String cur) {
if (revFreqToWord == null) {
return "";
}
if (revFreqToTotal == null) {
return "";
}
double curtotal = 0;
curtotal = revFreqToTotal.get(cur) == null ? 0 : revFreqToTotal.get(cur);
if (curtotal < 0) {
curtotal = 0;
}
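// Draw uniformly in [0, curtotal); the ceiling entry selects a word in proportion to its weight.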
final double value = random.nextDouble() * curtotal;
final Map.Entry<Double, String> entry = revFreqToWord.ceilingEntry(value);
// Guard against a word with no recorded successors (e.g. the last word of the longest document)
return entry == null ? "" : entry.getValue();
}
/**
* Main entry point for program.
*
* @param args
* @throws Exception
*/
public static void main(final String [] args) throws Exception {
final MarkovChainTextGenerator c = new MarkovChainTextGenerator();
c.train();
c.printText();
} // End of the method //
} // End of the class //