Skip to content

Instantly share code, notes, and snippets.

@rossgoodwin
Created November 20, 2018 14:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rossgoodwin/8df620efa0edc010a426b9987c933dd0 to your computer and use it in GitHub Desktop.
Save rossgoodwin/8df620efa0edc010a426b9987c933dd0 to your computer and use it in GitHub Desktop.
from collections import defaultdict, Counter
from random import random
import sys
def train_char_lm(filename, order=4):
    """Train a character-level n-gram language model from a text file.

    The text is prefixed with ``order`` copies of ``"~"`` so that the first
    characters of the file also get a full-length context.

    Args:
        filename: path of the training text file.
        order: number of preceding characters used as context.

    Returns:
        A dict mapping each ``order``-character history string that occurs in
        the padded text and is followed by at least one character to a list
        of ``(char, probability)`` pairs. The probabilities for one history
        sum to 1, up to floating-point rounding.
    """
    with open(filename, 'r') as infile:
        raw_text = infile.read()
    counts = defaultdict(Counter)
    pad = "~" * order
    text = pad + raw_text
    # range / .items() work on both Python 2 and 3 (xrange / iteritems are 2-only).
    for i in range(len(text) - order):
        history = text[i:i + order]
        char = text[i + order]
        counts[history][char] += 1

    def normalize(c):
        """Turn a Counter of next-character counts into (char, probability) pairs."""
        total = float(sum(c.values()))
        return [(char, cnt / total) for char, cnt in c.items()]

    return {hist: normalize(chars) for hist, chars in counts.items()}
def generate_letter(lm, history, order):
    """Sample the next character given the text generated so far.

    Args:
        lm: mapping from an ``order``-character history string to a list of
            ``(char, probability)`` pairs, as returned by ``train_char_lm``.
        history: the text generated so far; only its last ``order``
            characters are used as context.
        order: context length the model was trained with.

    Returns:
        One character drawn from the distribution stored for the context.

    Raises:
        KeyError: if the context is not a key of ``lm``.
    """
    # history[-order:] would return the WHOLE string when order == 0, so
    # compute the slice start explicitly (clamped at 0 for short histories).
    context = history[max(0, len(history) - order):]
    dist = lm[context]
    x = random()
    for c, v in dist:
        x -= v
        if x <= 0:
            return c
    # The normalized probabilities can sum to slightly less than 1.0 because
    # of floating-point rounding, leaving x > 0 here. Fall back to the last
    # candidate instead of implicitly returning None.
    return dist[-1][0]
def generate_text(lm, order, nletters):
    """Generate ``nletters`` characters by sampling from a character LM.

    Generation starts from the ``"~" * order`` padding context used by
    ``train_char_lm``.

    Args:
        lm: mapping from history strings to ``(char, probability)`` lists,
            as returned by ``train_char_lm``.
        order: context length the model was trained with.
        nletters: number of characters to generate.

    Returns:
        The generated string.
    """
    start = '~' * order
    history = start
    out = []
    for _ in range(nletters):
        # history[-order:] keeps everything when order == 0, so slice from an
        # explicit index to keep the history bounded.
        context = history[max(0, len(history) - order):]
        if context not in lm:
            # Dead end: this context only appeared at the very end of the
            # training text and has no recorded successor. Restart from the
            # start-of-text context instead of raising KeyError.
            history = start
        c = generate_letter(lm, history, order)
        history = history[max(0, len(history) - order):] + c
        out.append(c)
    return "".join(out)
# Command-line entry point: <script> ORDER NLETTERS
# Trains on alice_in_wonderland.txt and prints NLETTERS generated characters.
if len(sys.argv) < 3:
    sys.exit("usage: %s ORDER NLETTERS" % sys.argv[0])
order = int(sys.argv[1])
nletters = int(sys.argv[2])
lm = train_char_lm("alice_in_wonderland.txt", order=order)
# print(...) with a single argument behaves the same on Python 2 and 3.
print(generate_text(lm, order, nletters))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment