Created
April 19, 2018 05:05
-
-
Save yi-jiayu/6fe425b8077be3841a60efbdcb310467 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import re | |
def parse_text(text):
    """Return the normalized sentences found in text, one per non-empty line.

    Each line is treated as one sentence. Punctuation (any character that is
    neither a word character nor whitespace) is removed, the line is
    lower-cased, and its words are re-joined with single spaces so that the
    result can be split on ' ' without producing empty words. Lines that
    contain no words after cleaning are dropped.

    Args:
        text: Raw input string with one sentence per line.

    Returns:
        A list of cleaned sentence strings, in the order they appear in text.
    """
    # Raw string: '\w' and '\s' are regex escapes, not string escapes.
    # Compiled once here instead of being looked up again for every line.
    punctuation = re.compile(r'[^\w\s]')

    sentences = []
    for line in text.splitlines():
        # Splitting on arbitrary whitespace discards leading/trailing blanks
        # and the gaps left behind where punctuation was removed.
        words = punctuation.sub('', line).lower().split()
        if words:
            sentences.append(' '.join(words))
    return sentences
def calculate_transition_probabilities(sentences):
    """Return a dictionary of transition probabilities between adjacent words.

    Every sentence is split into words and wrapped in the sentinel states
    'START' and 'END'. transition_probabilities[bef][aft] is the fraction of
    occurrences of bef that were immediately followed by aft.

    Args:
        sentences: Iterable of sentence strings whose words are separated by
            whitespace.

    Returns:
        A dict mapping each state to a dict of {next_state: probability}.
        The probabilities for each state sum to 1. Sentences that contain no
        words are ignored, so an empty input yields an empty dict.
    """
    # Count how often each state is followed by each other state.
    transition_counts = {}
    for sentence in sentences:
        # split() with no argument splits on any run of whitespace, so repeated
        # spaces or tabs never produce an empty-string "word".
        words = sentence.split()
        if not words:
            # An empty sentence would only add a meaningless START -> END edge.
            continue
        chain = ['START'] + words + ['END']
        # Zipping a sequence with itself offset by one yields the overlapping
        # pairs of adjacent items.
        for bef, aft in zip(chain, chain[1:]):
            followers = transition_counts.setdefault(bef, {})
            followers[aft] = followers.get(aft, 0) + 1

    # Normalize the counts of each state into probabilities.
    transition_probabilities = {}
    for bef, followers in transition_counts.items():
        total = sum(followers.values())
        transition_probabilities[bef] = {
            aft: count / total for aft, count in followers.items()
        }
    return transition_probabilities
def next_state(current_state, transition_probabilities):
    """Pick a random successor of current_state.

    The successor is drawn from transition_probabilities[current_state],
    with each candidate weighted by its transition probability.
    """
    outgoing = transition_probabilities[current_state]
    candidates = list(outgoing)
    weights = list(outgoing.values())
    # choices() returns a one-element list when k=1; unpack that element.
    (picked,) = random.choices(candidates, weights=weights, k=1)
    return picked
def generate_chain(transition_probabilities):
    """Walk the chain from 'START' until 'END' is reached.

    Returns the list of visited states, including both the leading 'START'
    and the trailing 'END'. Each step is chosen by next_state().
    """
    state = 'START'
    path = [state]
    while state != 'END':
        state = next_state(state, transition_probabilities)
        path.append(state)
    return path
def main():
    """Build a model from a small sample text and print five generated chains."""
    text = '''What dessert do you like?
I like fruit tarts.
I like fruit tarts too.
What kind of fruit tart do you like?
I like mixed fruit tarts.'''

    model = calculate_transition_probabilities(parse_text(text))
    print(model)

    # Each generated chain is printed as one space-separated line.
    for _ in range(5):
        generated = generate_chain(model)
        print(' '.join(generated))
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment