@ben-yu
Created April 7, 2015 20:48
# Markov-chain tweet generator trained on game reviews (Python 2).
import nltk
import pickle
import random
import string


class MarkovModel(object):

    def __init__(self, model_path=None):
        # Load a previously trained model from disk, or start with an empty one.
        if model_path:
            pkl_file = open(model_path, 'rb')
            self.model = pickle.load(pkl_file)
            pkl_file.close()
        else:
            self.model = dict()

    def train(self, model_path=None):
        """Builds a second-order Markov chain from the reviews corpus."""
        review_file = open('reviews.csv', 'rb')
        # Split the text into words and runs of punctuation.
        tokenizer = nltk.tokenize.RegexpTokenizer(r'\w+|[^\w\s]+')
        tokenized_content = tokenizer.tokenize(review_file.read())
        review_file.close()
        # Map each pair of consecutive tokens to the tokens that follow it.
        for w1, w2, w3 in self.triplets(tokenized_content):
            key = (w1, w2)
            if key in self.model:
                self.model[key].append(w3)
            else:
                self.model[key] = [w3]
        # Optionally persist the trained model so it can be reloaded later.
        if model_path:
            output = open(model_path, 'wb')
            pickle.dump(self.model, output)
            output.close()

    def triplets(self, words):
        """Generates overlapping triplets from a list of words."""
        for i in range(len(words) - 2):
            yield (words[i], words[i + 1], words[i + 2])

    def join_tokens(self, a, b):
        """Joins two tokens, omitting the space before punctuation and
        after an opening parenthesis or apostrophe."""
        if b in string.punctuation or a.endswith('(') or a.endswith('\''):
            return a + b
        else:
            return a + " " + b

    def generate_tweet(self):
        """Walks the chain from a random starting pair until the text passes
        100 characters, then appends hashtags."""
        w1, w2 = random.choice(self.model.keys())
        gen_words = []
        tweet_length = 0
        while tweet_length <= 100:
            gen_words.append(w1)
            tweet_length += len(w1) + 1
            # Advance the chain: pick the next word from the followers of the current pair.
            w1, w2 = w2, random.choice(self.model[(w1, w2)])
        # Turn a randomly chosen noun from the generated text into a hashtag.
        nouns = filter(lambda x: x[1] == 'NN', nltk.pos_tag(gen_words))
        if nouns:
            gen_words.append('#' + random.choice(nouns)[0])
        gen_words.append('#GameReview')
        return reduce(self.join_tokens, gen_words)


if __name__ == '__main__':
    # Train once to produce reviews.pkl, then generate tweets from the saved model.
    # MarkovModel().train(model_path='reviews.pkl')
    print MarkovModel(model_path='reviews.pkl').generate_tweet()
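
Not part of the original gist: a rough sketch of the chain that train() builds, using the made-up toy corpus "the game is fun . the game ." so the shape of self.model is easy to see. Each consecutive token pair maps to every token observed immediately after that pair.

# Illustrative only: the dict train() would build for the toy corpus above.
toy_model = {
    ('the', 'game'): ['is', '.'],   # "the game" is followed by "is" once and "." once
    ('game', 'is'): ['fun'],
    ('is', 'fun'): ['.'],
    ('fun', '.'): ['the'],
    ('.', 'the'): ['game'],
}
# generate_tweet() starts from a random key and repeatedly samples a follower,
# e.g. ('the', 'game') -> 'is' -> ('game', 'is') -> 'fun' -> ...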