@rdarder · Created October 23, 2014 21:53
Trigram text generator
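The script below builds a trigram language model from a plain-text book and uses it to generate random text: it tokenizes the source with ply, indexes every (w1, w2) -> w3 transition with its observed frequency, remembers capitalized two-word sentence openers, then random-walks the index to write paragraphs until roughly `size` characters (10000 by default) have been emitted. It requires ply (`pip install ply`). A typical invocation, with illustrative file names:

    python trigram_generator.py source_book.txt generated.txt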
import bisect
import collections
import random
import sys
import time

import ply.lex as lex
tokens = (
    'NUMBER',
    'WORD',
    'PUNCT',
    'STOP',
    'NEWLINE',
    'DECORATOR',  # declared but never matched by any rule below
)
t_WORD = r"[a-zA-Z']+"
t_NUMBER = r'[0-9]+'
t_PUNCT = r'[,;:]'
t_STOP = r'[\.!?]+'
t_NEWLINE = r'\r?\n'
# Characters the lexer skips silently (space, tab, quotes, brackets, etc.).
t_ignore = '\t -"_[]()*/#%$@' + "'"


def t_error(t):
    # Skip any character no rule matches instead of aborting.
    t.lexer.skip(1)
def tokenize(source):
    """Yield ply tokens for the whole input stream."""
    data = source.read()
    lexer = lex.lex()
    lexer.input(data)
    while True:
        tok = lexer.token()
        if tok is None:
            return
        yield tok
class SetWeightedMultiMap(object):
    """Maps each key to a multiset of values, later frozen into a
    cumulative probability distribution for weighted sampling."""

    def __init__(self):
        self.map = {}

    def __setitem__(self, key, value):
        # Count how many times each value has been seen under this key.
        values = self.map.get(key)
        if values is None:
            self.map[key] = {value: 1}
        else:
            values[value] = values.get(value, 0) + 1

    def __getitem__(self, key):
        # Return None for unknown keys so callers can bail out gracefully.
        return self.map.get(key)

    def freeze(self):
        # Replace each count dict with a list of (cumulative_weight, value)
        # pairs sorted by weight, ready for bisect-based sampling.
        for key in self.map.keys():
            counts = self.map[key]
            total = float(sum(counts.values()))
            weights_in_order = sorted(
                (count / total, value) for value, count in counts.items())
            cum_weight = 0.
            cum_weighted_values = []
            for weight, value in weights_in_order:
                cum_weighted_values.append((weight + cum_weight, value))
                cum_weight += weight
            # Floating point rounding aside, the weights must sum to 1.
            assert cum_weight > 0.999
            self.map[key] = cum_weighted_values
class TrigramIndex(object):
    """Indexes (w1, w2) -> w3 transitions with their observed frequencies."""

    def __init__(self):
        self.trigrams = SetWeightedMultiMap()

    def put(self, w1, w2, w3):
        self.trigrams[(w1, w2)] = w3

    def get(self, w1, w2):
        # Pick a random successor of (w1, w2), biased by frequency.
        weighted_suffixes = self.trigrams[(w1, w2)]
        if weighted_suffixes is None:
            return None
        roulette = random.random()
        position = bisect.bisect_right(weighted_suffixes, (roulette, ''))
        # Guard against rounding: cumulative weights may top out just below 1.
        position = min(position, len(weighted_suffixes) - 1)
        return weighted_suffixes[position][1]

    def freeze(self):
        self.trigrams.freeze()
class TrigramBuffer(object):
    """Sliding window over the last three tokens, fed into the index."""

    def __init__(self, indexer):
        self.indexer = indexer
        self.deque = collections.deque(maxlen=3)

    def put(self, token):
        self.deque.append(token.value)
        if len(self.deque) == 3:
            self.indexer.put(*self.deque)
class SentenceStarts(object):
    """Collects (w1, w2) pairs that open sentences in the source text."""

    def __init__(self):
        self.starts = set()
        self.start_offset = 0
        self.w1 = None

    def freeze(self):
        # random.choice needs a sequence, not a set.
        self.starts = tuple(self.starts)

    def start(self):
        return random.choice(self.starts)

    def put(self, token):
        if token.type == 'STOP':
            # A sentence just ended; the next token starts a new one.
            self.start_offset = 0
            return
        elif self.start_offset > 1:
            # Already past the first two words of this sentence.
            return
        elif self.start_offset == 0:
            if token.value.istitle():
                self.w1 = token.value
                self.start_offset = 1
            else:
                self.start_offset = 2
        elif self.start_offset == 1:
            # Skip all-caps second words (acronyms and the like).
            if not token.value.isupper():
                self.starts.add((self.w1, token.value))
            self.start_offset = 2
class TokenMultiplexer(object):
    """Forwards each token to every registered sink."""

    def __init__(self, *sinks):
        self.sinks = sinks

    def put(self, token):
        for sink in self.sinks:
            sink.put(token)
class BookSanitizer(object):
    """Buffers one sentence at a time, discarding sentences too short to
    be useful (headings, page numbers and similar noise)."""

    def __init__(self, sink):
        self.buf = []
        self.words = 0
        self.tokens = 0
        self.sink = sink

    def put(self, token):
        self.tokens += 1
        if token.type in {'WORD', 'NUMBER'} and len(token.value) > 1:
            self.words += 1
            self.buf.append(token)
        elif token.type == 'PUNCT':
            self.buf.append(token)
        elif token.type == 'STOP':
            # Only keep sentences with more than two real words.
            if self.words > 2:
                self.buf.append(token)
                self.flush()
            else:
                self.empty()

    def flush(self):
        for token in self.buf:
            self.sink.put(token)
        self.empty()

    def empty(self):
        self.buf = []
        self.words = self.tokens = 0
class SentenceEmitter(object):
    puncts = set("?!.,:;")
    stops = set('?!.')

    def __init__(self, index):
        self.index = index

    def get_sentences(self, start):
        # Random-walk the trigram index from a seed bigram, yielding one
        # sentence per stop mark. Sentences after the first keep a leading
        # space so a paragraph can be joined with ''.
        w1, w2 = start
        buf = [w1, ' ', w2]
        while True:
            w3 = self.index.get(w1, w2)
            if w3 is None:
                # Dead end: this bigram only appeared at the end of the corpus.
                return
            if w3 not in self.puncts:
                buf.append(' ')
            buf.append(w3)
            w1, w2 = w2, w3
            if w3 in self.stops:
                yield ''.join(buf)
                buf = []
class ParagraphEmitter(object):
    def __init__(self, starts, sentence_emitter, mean_sentences=6,
                 stddev_sentences=2):
        self.mean = mean_sentences
        self.stddev = stddev_sentences
        self.sentence_emitter = sentence_emitter
        self.starts = starts

    def get_paragraphs(self):
        while True:
            yield self.get_paragraph()

    def get_paragraph(self):
        # Paragraph length is drawn from a normal distribution, clamped so
        # every paragraph has at least one sentence.
        n_sentences = max(1, int(random.gauss(self.mean, self.stddev)))
        buf = []
        for sentence, _ in zip(
                self.sentence_emitter.get_sentences(self.starts.start()),
                range(n_sentences)):
            buf.append(sentence)
        return ''.join(buf) + '\n\n'
def main(size=10000):
    random.seed(time.time())
    index = TrigramIndex()
    starts = SentenceStarts()
    trigram_buffer = TrigramBuffer(index)
    mplex = TokenMultiplexer(starts, trigram_buffer)
    sanitizer = BookSanitizer(mplex)
    # Pass 1: tokenize the source text and build the trigram model.
    with open(sys.argv[1]) as source:
        for token in tokenize(source):
            sanitizer.put(token)
    index.freeze()
    starts.freeze()
    # Pass 2: emit random paragraphs until roughly `size` characters.
    sentence_emitter = SentenceEmitter(index)
    paragraph_emitter = ParagraphEmitter(starts, sentence_emitter, 6, 2)
    chars_emitted = 0
    with open(sys.argv[2], 'w') as output:
        for paragraph in paragraph_emitter.get_paragraphs():
            chars_emitted += len(paragraph)
            output.write(paragraph)
            if chars_emitted > size:
                break


if __name__ == '__main__':
    main()
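For intuition, the weighted sampling shared by SetWeightedMultiMap.freeze and TrigramIndex.get is roulette-wheel selection over cumulative weights; a minimal standalone sketch, with made-up counts:

    import bisect
    import random

    # Suppose ('the', 'old') was followed by 'man' 3 times and 'sea' once.
    # freeze() turns those counts into cumulative weights, sorted ascending:
    cum = [(0.25, 'sea'), (1.0, 'man')]

    # get() draws a uniform number and bisects into the list, so 'man'
    # comes back about 75% of the time and 'sea' about 25%.
    roulette = random.random()
    print(cum[bisect.bisect_right(cum, (roulette, ''))][1])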