Skip to content

Instantly share code, notes, and snippets.

@rdarder
Last active August 29, 2015 14:08
Show Gist options
  • Save rdarder/71e0a4f49de26bf39f0e to your computer and use it in GitHub Desktop.
Save rdarder/71e0a4f49de26bf39f0e to your computer and use it in GitHub Desktop.
from itertools import islice, chain
import random, sys, collections, bisect
def index(words, queue):
    """Build a frequency-weighted prefix index for a sequence of words.

    The prefix length is the capacity (``maxlen``) of ``queue``; when the
    queue has no ``maxlen`` its current length is used instead.

    For every prefix seen in the input, the returned index holds an ascending
    tuple of ``(cum_prob, word)`` pairs, where ``cum_prob`` is the cumulative
    relative frequency in (0, 1] and the last entry is exactly 1.0. Given a
    random number R in [0, 1), picking the leftmost entry whose ``cum_prob``
    is >= R selects each word with the same frequency it had in the input
    after that prefix.

    :param words: iterable of words (any hashable values).
    :param queue: deque used as the sliding prefix window; it is mutated.
    :return: dict mapping prefix tuples to tuples of (cum_prob, word).
    """
    # Use one shared iterator so the pre-fill below consumes the words it
    # takes, even when the caller passes a list rather than an iterator.
    words = iter(words)
    prefix_len = getattr(queue, 'maxlen', None)
    if prefix_len is None:
        prefix_len = len(queue)
    queue.extend(islice(words, prefix_len))

    # First pass: count how often each word follows each prefix.
    counts = {}
    for word in words:
        followers = counts.setdefault(tuple(queue), {})
        followers[word] = followers.get(word, 0) + 1
        queue.append(word)

    # Second pass: turn the counts into cumulative probabilities. The running
    # count is kept as an integer and divided once per entry, so the final
    # entry is exactly total / total == 1.0 (no float accumulation drift).
    table = {}
    for prefix, followers in counts.items():
        total = float(sum(followers.values()))
        running = 0
        entries = []
        for word, count in followers.items():
            running += count
            entries.append((running / total, word))
        table[prefix] = tuple(entries)
    return table
def generate(ngrams, buf):
    """Yield consecutive words drawn from the provided ngram index.

    Each word is chosen among the valid followers of the current prefix,
    weighted by the cumulative frequencies stored in the index. Generation
    stops as soon as the current prefix is not a key of the index, so the
    generator may be infinite if that never happens.

    :param ngrams: ngram index (as provided by index()).
    :param buf: prefix sized deque, filled with the initial prefix; it is
        mutated as words are produced.
    :return: generator of words
    """
    prefix = tuple(buf)
    while prefix in ngrams:
        choices = ngrams[prefix]
        # A 1-tuple key sorts before any (prob, word) pair with an equal
        # prob, so the search never has to compare the words themselves.
        pos = bisect.bisect_left(choices, (random.random(),))
        # The last cumulative probability can fall slightly short of the
        # random draw; clamp so the last entry is picked instead of
        # indexing past the end.
        pos = min(pos, len(choices) - 1)
        _, word = choices[pos]
        buf.append(word)
        yield word
        prefix = tuple(buf)
if __name__ == '__main__':
    if len(sys.argv) < 4:
        sys.exit("{} <N> <input_files...> <output_file>\n"
                 "N: prefix length. 3 for trigrams."
                 .format(sys.argv[0]))
    n = int(sys.argv[1])
    if n < 1:
        sys.exit("N must be a positive integer.")
    # The input is just all the files concatenated, sanitized and split on
    # whitespace. Characters are stripped with str.replace because the
    # two-argument str.translate(None, chars) form only exists on Python 2.
    tokens = []
    for path in sys.argv[2:-1]:
        with open(path, 'r') as source:
            text = source.read()
        for char in '_()[]*':
            text = text.replace(char, '')
        tokens.extend(text.split())
    # Hand index() a one-shot iterator: it takes the initial prefix off the
    # front and then continues with the remaining words.
    tri = index(iter(tokens), collections.deque(maxlen=n - 1))
    if not tri:
        sys.exit("Not enough input words to build an index for N={}."
                 .format(n))
    start = collections.deque(random.choice(list(tri.keys())), n - 1)
    with open(sys.argv[-1], 'w') as output:
        output.write(' '.join(start))
        # NOTE: generate() only stops when it reaches a prefix that is not
        # in the index, so this loop may keep writing until interrupted.
        for token in generate(tri, start):
            output.write(' ' + token)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment