Created
October 2, 2014 04:30
-
-
Save Uberi/eefa542a1224511c48ac to your computer and use it in GitHub Desktop.
Run the previous `normalize.py` on FB chat data, then run this script on its output to get sentences generated using a Markov chain.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Input produced by the companion `normalize.py` script: a JSON list of
# entries where entry[1] is the sender name and entry[2] is the message
# text (entry[0] is presumably a timestamp — TODO confirm against normalize.py).
path = "normalized_data.json"
user = "Keri Warr"

import json, sys, re
import random
from collections import defaultdict

# Load the normalized chat log; `with` closes the file handle promptly
# (the original left it to the garbage collector).
with open(path, "r") as f:
    data = json.load(f)

# Sentinel tokens marking message start/end in the Markov chain.
# Ints so they can never collide with word tokens (which are strings).
START, END = 0, 1

# get word token matcher
PUNCTUATION = r"[`~@#$%_\\'+\-/]"  # punctuation that is a part of text
STANDALONE = r"(?:[!.,;()^&\[\]{}|*=<>?]|[dDpP][:8]|:\S)"  # standalone characters or emoticons that wouldn't otherwise be captured
PATTERN = STANDALONE + r"\S*|https?://\S+|(?:\w|" + PUNCTUATION + r")+"  # token pattern
matcher = re.compile(PATTERN, re.IGNORECASE)
def train(data, target_user=None, token_matcher=None, start=None, end=None):
    """Build Markov-chain transition counts from one user's messages.

    data: list of entries where entry[1] is the sender name and entry[2]
    is the message text (as produced by normalize.py).
    target_user / token_matcher / start / end: optional overrides that
    default to the module-level `user`, `matcher`, `START`, and `END`,
    so existing `train(data)` callers are unaffected.

    Returns a dict-of-dicts: word_chains[w1][w2] is the number of times
    token w2 immediately followed token w1; the start/end sentinels mark
    message boundaries.
    """
    if target_user is None: target_user = user
    if token_matcher is None: token_matcher = matcher
    if start is None: start = START
    if end is None: end = END

    # Flatten the target user's messages into one token stream, wrapping
    # each message in start/end sentinels.
    tokens = []
    for entry in data:
        if entry[1] == target_user:
            tokens.append(start)
            # lowercase all the messages for better results; sentinels
            # are ints, so only the string tokens are lowercased
            tokens.extend(t.lower() for t in token_matcher.findall(entry[2]))
            tokens.append(end)

    # Count adjacent-pair transitions (the original walked indices with a
    # while loop and shadowed the builtin `next`; zip pairing is clearer).
    word_chains = defaultdict(lambda: defaultdict(int))
    for current, following in zip(tokens, tokens[1:]):
        word_chains[current][following] += 1
    return word_chains
def speak(word_chains, start=None, end=None):
    """Generate one sentence by randomly walking the Markov chain.

    word_chains: transition counts as returned by train().
    start / end: boundary sentinels; default to the module-level
    START/END, so existing `speak(chains)` callers are unaffected.

    Each step picks the next token with probability proportional to how
    often it followed the current token in the training data, until the
    end sentinel is drawn.

    Returns the sentence as a human-readable string.
    Raises Exception("Bad choice!") if the weighted draw fails to land
    on any token (e.g. the current token has no recorded successors).
    """
    if start is None: start = START
    if end is None: end = END

    phrase_list = []
    choices = word_chains[start]
    while True:
        # Weighted random pick: drop a point in [0, total weight) and find
        # which token's weight interval it lands in.
        threshold = random.random() * sum(choices.values())
        for token, weight in choices.items():
            threshold -= weight
            if threshold < 0:
                break
        else:  # couldn't find the choice somehow
            raise Exception("Bad choice!")
        if token == end:
            break
        phrase_list.append(token)
        choices = word_chains[token]

    # Format the sentence: join tokens with spaces, but attach closing
    # punctuation directly to the preceding word. Built with a list +
    # join instead of the original quadratic string concatenation.
    close_matcher = re.compile(r"[!.,;)\]}?]", re.IGNORECASE)
    parts = []
    for i, token in enumerate(phrase_list):
        if i == 0 or close_matcher.match(token):
            parts.append(token)
        else:
            parts.append(" " + token)
    return "".join(parts)
# Debug helper (disabled): dump START-transition counts sorted by frequency.
#result = "\n".join([str(word) + ":\t" + str(probability) for word, probability in sorted(word_chains[START].items(), key=lambda x: -x[1])])
#print(result.encode(sys.stdout.encoding, errors="replace").decode(sys.stdout.encoding))

# Train on the loaded chat data, then emit 5000 generated sentences.
# The encode/decode round-trip replaces any character the console
# encoding can't represent, so printing never crashes on exotic Unicode.
chains = train(data)
for _ in range(5000):
    sentence = speak(chains)
    printable = sentence.encode(sys.stdout.encoding, errors="replace")
    print(printable.decode(sys.stdout.encoding))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.