Last active
August 29, 2015 14:08
-
-
Save rdarder/71e0a4f49de26bf39f0e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from itertools import islice, chain | |
import random, sys, collections, bisect | |
def index(words, queue):
    """
    Generate a prefix index for an input word sequence.

    The prefix length is ``queue.maxlen``. The queue is first topped up with
    the leading words of the input, so that every indexed prefix has exactly
    that many words. It then slides over the rest of the input one word at a
    time. The queue is mutated in place and, on return, holds the last
    ``maxlen`` words consumed.

    The output index maps each prefix tuple to an ascending tuple of
    (cum_prob, word) pairs:
    - cum_prob is the cumulative relative frequency of the words that
      followed that prefix. It lies in (0, 1], and the last entry is
      exactly 1.0.
    - If a random number R in [0, 1) is searched in this tuple for the
      leftmost entry with entry.cum_prob >= R, each word is chosen with the
      same frequency with which it followed the prefix in the input.

    :param words: iterable of words (a list or any iterator).
    :param queue: bounded ``collections.deque``; its maxlen is the prefix
        length. Words already in it are used as the start of the first prefix.
    :return: dict {prefix tuple: tuple of (cum_prob, word)}.
    :raises ValueError: if ``queue`` has no maxlen (an unbounded prefix would
        grow with every word and never repeat).
    """
    maxlen = getattr(queue, 'maxlen', None)
    if maxlen is None:
        raise ValueError("queue must be a bounded deque; its maxlen is the "
                         "prefix length")
    # Use one iterator, so the words consumed by islice below are not seen
    # again by the loop (iterating a list twice would restart from the start).
    words = iter(words)
    # Fill the prefix up to maxlen. len(queue) is 0 for a fresh deque and
    # would leave the first indexed prefixes short.
    queue.extend(islice(words, maxlen - len(queue)))

    followers = collections.defaultdict(collections.Counter)
    for word in words:
        followers[tuple(queue)][word] += 1
        queue.append(word)

    ngrams = {}
    for prefix, counts in followers.items():
        total = float(sum(counts.values()))
        cumulative = 0
        entries = []
        for word, count in counts.items():
            # Accumulate integer counts and divide once per entry, so the last
            # probability is exactly total / total == 1.0. Summing float
            # fractions can end just below 1.0, which lets a bisect on a
            # random number run past the end of the tuple.
            cumulative += count
            entries.append((cumulative / total, word))
        ngrams[prefix] = tuple(entries)
    return ngrams
def generate(ngrams, buf, limit=None):
    """ Generates consecutive words existing in the provided ngrams index.

    At each step the current contents of ``buf`` are looked up in ``ngrams``.
    A follower is chosen at random, weighted by the cumulative probabilities
    stored in the index. It is then appended to ``buf`` (so ``buf`` is
    mutated in place) and yielded. Generation stops when the current prefix
    has no (or an empty) entry in the index, or after ``limit`` words.

    :param ngrams: ngram index (as provided by index()).
    :param buf: prefix sized deque, filled with the initial prefix.
    :param limit: optional maximum number of words to yield. None (the
        default) keeps going until a prefix is missing from the index. That
        never happens if every reachable prefix has followers, so pass a
        limit for cyclic indexes.
    :return: generator of words
    """
    produced = 0
    while limit is None or produced < limit:
        choices = ngrams.get(tuple(buf))
        if not choices:
            return
        # Pick the leftmost entry whose cumulative probability is >= R.
        # - Bisecting with the 1-tuple (R,) never compares words: (R,) sorts
        #   before any (prob, word) with prob >= R.
        # - The clamp guards against indexes whose last probability rounds
        #   to just below 1.0.
        pos = bisect.bisect_left(choices, (random.random(),))
        _, word = choices[min(pos, len(choices) - 1)]
        buf.append(word)
        produced += 1
        yield word
def _read_words(paths):
    """Yield the whitespace-separated words of each file in turn, with the
    characters '_()[]*' removed."""
    for path in paths:
        # Close each file as soon as it has been read instead of leaking the
        # handle until interpreter exit.
        with open(path, 'r') as handle:
            text = handle.read()
        # str.replace behaves the same on Python 2 and 3, unlike the
        # Python-2-only str.translate(None, chars) form.
        for junk in '_()[]*':
            text = text.replace(junk, '')
        for word in text.split():
            yield word


if __name__ == '__main__':
    if len(sys.argv) < 4:
        sys.exit("{} <N> <input_files...> <output_file>\n"
                 "N: n-gram size (prefix length + 1). 3 for trigrams."
                 .format(sys.argv[0]))
    n = int(sys.argv[1])
    if n < 2:
        # The prefix length is n - 1. With n == 1 the prefix is always (),
        # which is always in the index, so output never ends. With n < 1 the
        # deque maxlen is invalid.
        sys.exit("N must be at least 2, got {}".format(n))
    # The input is all the files concatenated, sanitized and split on
    # whitespace.
    ngrams = index(_read_words(sys.argv[2:-1]),
                   collections.deque(maxlen=n - 1))
    if not ngrams:
        sys.exit("not enough input words to build {}-grams".format(n))
    start = collections.deque(random.choice(list(ngrams)), n - 1)
    with open(sys.argv[-1], 'w') as output:
        # Write the seed prefix first: generate() mutates `start` as it runs.
        output.write(' '.join(start))
        for token in generate(ngrams, start):
            output.write(' ' + token)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment