Created
August 11, 2009 22:34
-
-
Save magcius/166164 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import storage | |
import config | |
import data | |
#===============================================================# | |
def write_model(mod, f, level = 0):
    """Write a nested model to ``f`` as a tab-indented outline.

    Each entry is written as ``[count] "key"`` on its own line, followed by
    its children one tab level deeper.

    mod   -- mapping of key -> sequence whose item 0 is the child mapping
             (same shape as ``mod``) and whose item 1 is the value printed
             in brackets.
    f     -- object with a ``write(str)`` method.
    level -- current nesting depth; controls the number of leading tabs.
    """
    # TODO: modify
    tabs = "\t" * level
    # Iterating the mapping directly yields its keys on Python 2 and 3 alike.
    for key in mod:
        entry = mod[key]
        f.write(tabs + "[" + str(entry[1]) + "] \"" + key + "\"\n")
        write_model(entry[0], f, level + 1)
#===============================================================# | |
class Brain(object):
    """Learns symbol sequences into ``storage`` and generates replies.

    Two model representations are used in this class:

    * ``addsymbols`` / ``learn`` write nodes through the ``storage`` module.
    * ``child`` / ``seed`` / ``babble`` / ``reply`` still walk the older
      tree shape in which a node is a pair ``(children, count)``:
      ``node[0]`` maps symbol -> child node and ``node[1]`` is a count.

    NOTE(review): ``reply`` reads ``self.forward`` and ``self.backward``,
    which nothing in this class assigns -- they must be provided before
    ``reply`` can work.
    """

    def __init__(self, conn_str, order = 4):
        """Initialise storage and remember the model order.

        conn_str -- passed unchanged to ``storage.init_storage``.
        order    -- context length; ``addsymbols`` stores chunks of
                    ``order + 1`` consecutive symbols.
        """
        self.order = order
        storage.init_storage(conn_str)
        # Temporary data: set of keywords for the reply being generated
        # (assigned in ``reply``, read in ``child``).
        self.cur_keys = None

    def addsymbols(self, symbols):
        """ Add the symbols to the brain """
        subroots = []
        # Sliding window over the input: for a list [1, 2, 3, 4, 5] and a
        # window of 3 this produces (1, 2, 3), (2, 3, 4), (3, 4, 5).
        window = self.order + 1
        # zip() is materialised because reversed() needs a sequence on
        # Python 3, where zip() returns an iterator.
        chunks = list(zip(*(symbols[n:] for n in range(window))))
        # NOTE(review): the nesting below was reconstructed from the
        # statements and comments; the original indentation was lost.
        for chunk in reversed(chunks):
            node = storage.get_root()
            last = len(chunk) - 1
            for i, symbol in enumerate(chunk):
                if symbol in node:
                    print("Found node for symbol %r" % symbol)
                    node = node[symbol]
                    node.count += 1
                elif i == last:  # leaf node
                    print("Creating leaf node for symbol %r" % symbol)
                    # We already added a subroot node this round, so the
                    # second from last should give us what we want to link to.
                    link = subroots[-2] if len(subroots) >= 2 else None
                    node = storage.LeafNode(symbol, parent=node, link=link)
                else:
                    print("Creating parent node for symbol %r" % symbol)
                    node = storage.ParentNode(symbol, parent=node)
                if i == 0:  # subroot node
                    subroots.append(node)
        # Link every subroot to the one recorded immediately before it.
        for subroot1, subroot2 in zip(subroots[1:], subroots):
            subroot1.link = subroot2

    def learn(self, symbols):
        """ Add the input to the model. """
        # Inputs no longer than the order cannot fill a single chunk.
        if len(symbols) <= self.order:
            return
        self.addsymbols(symbols)

    def get_keywords(self, symbols):
        """Return the symbols that qualify as keywords, in input order.

        A symbol is a keyword if it is not in the banned keywords list and
        does not start with a non-word (punctuation) character.
        """
        # NOTE(review): ``banned_keywords`` is not defined in the visible
        # source; confirm where it is meant to come from.
        wo = []
        for symbol in symbols:
            if not symbol:
                # An empty symbol has no first character to classify.
                continue
            if symbol not in banned_keywords \
                    and symbol[0] not in config.boundaries["nonwords"]:
                wo.append(symbol)
        return wo

    def choose_keyword(self, keywords):
        """
        Choose a keyword at random that is suitable to use. A keyword is
        suitable when ``storage.symbol_exists`` accepts it. Returns None if
        no keyword qualifies.
        """
        # Make sure we can use the keyword with the models.
        acceptable = [k for k in keywords if storage.symbol_exists(k)]
        if not acceptable:
            return None
        return random.choice(acceptable)

    def child(self, tree):
        """Pick one child symbol of ``tree``.

        A child that is one of the current keywords is returned with
        probability 0.9 as soon as it is encountered; otherwise children
        are chosen by accumulating their counts (``tree[0][k][1]``) until a
        random threshold drawn from ``1..tree[1]`` is reached.

        Returns None if the loop ends without reaching the threshold.
        """
        # TODO: update
        count = random.randint(1, tree[1])
        tally = 0
        children = tree[0]
        for k in children.keys():
            # random.random() draws the float; ``random`` itself is the
            # module and is not callable.
            if k in self.cur_keys and random.random() > 0.1:
                return k
            tally += children[k][1]
            if tally >= count:
                return k

    def seed(self, model, keyword):
        """ seed the sentence with words that branch off from the keyword

        Generator yielding ``self.order`` symbols, each a randomly chosen
        child of the previous context, starting at ``model[keyword]``.
        """
        # TODO: update
        ctx = model[keyword]
        for _ in range(self.order):
            # choose a random child and set it as the new context
            newk = self.child(ctx)
            ctx = ctx[0][newk]
            # Yield the symbol, not the node: callers splice the result
            # into symbol lists and use the entries as child keys.
            yield newk

    def babble(self, model, symbols):
        """Append one generated symbol to ``symbols`` (modified in place).

        The last ``self.order`` symbols select the context: the first of
        them indexes ``model`` and the rest walk down the tree.
        """
        # TODO: update
        # Explicit start index: ``symbols[-self.order + 1:]`` would select
        # the whole list when order == 1.
        use = symbols[len(symbols) - self.order + 1:]
        ctx = model[symbols[-self.order]]
        # walk the tree up to the end node - to get the next word.
        for symb in use:
            ctx = ctx[0][symb]
        symbols.append(self.child(ctx))

    def reply(self, symbols):
        """Generate a reply to ``symbols`` and return it as a symbol list.

        Returns an empty list when no usable keyword can be found.
        """
        # Step 1. Choose a keyword from the Sentence.
        keywords = self.get_keywords(symbols)
        # change keywords to make them make sense [see config.swap]
        keywords = [config.swap[k] if k in config.swap else k
                    for k in keywords]
        use_keyword = self.choose_keyword(keywords)
        if use_keyword is None:
            # try 20 times to get a random topic, or give up
            candidates = list(self.forward.keys())
            for _ in range(20):
                if not candidates:
                    break
                randword = random.choice(candidates)
                if randword in self.backward:
                    use_keyword = randword
                    break
        if use_keyword is None:
            # Gave up: there is nothing to seed the models with.
            return []
        self.cur_keys = set(keywords)

        # Step 2. Use the keyword to seed the forward and backward models.
        # seed() is a generator; materialise it so it can be searched and
        # concatenated.
        seed_forw = list(self.seed(self.forward, use_keyword))
        seed_back = list(self.seed(self.backward, use_keyword))
        symbols_forw = [use_keyword] + seed_forw
        symbols_back = [use_keyword] + seed_back

        # Keep generating until each half ends with the sentence terminator.
        if config.sentence_term not in seed_forw:
            while symbols_forw[-1] != config.sentence_term:
                self.babble(self.forward, symbols_forw)
        if config.sentence_term not in seed_back:
            while symbols_back[-1] != config.sentence_term:
                self.babble(self.backward, symbols_back)

        # Drop the keyword (first) and the last symbol from each half; the
        # keyword is re-inserted once between the two halves.
        symbols_forw = symbols_forw[1:-1]
        symbols_back = symbols_back[1:-1]
        symbols_back.reverse()
        return symbols_back + [use_keyword] + symbols_forw
#===============================================================# | |
class GigaHal(object):
    """Front end that tokenises text, feeds a ``Brain`` and formats replies."""

    def __init__(self, conn_str, order = 4):
        """Create the brain and train it when storage is nearly empty.

        conn_str -- passed through to ``Brain``.
        order    -- passed through to ``Brain``.
        """
        self.brain = Brain(conn_str, order)
        self.version = config.version
        # Fewer than 5 known items: learn the bundled training data.
        if storage.known_items() < 5:
            for i in data.train_data:
                self.brain.learn(self.separate_words(i))
            storage.session.commit()

    def input_no_reply(self, string):
        """ learn from input but don't generate a reply """
        words = self.separate_words(string)
        self.brain.learn(words)
        storage.session.commit()

    def input_with_reply(self, string):
        """use the input to generate a reply, then learn from the input """
        words = self.separate_words(string)
        out = self.brain.reply(words)
        self.brain.learn(words)
        # Commit what was learned, as the other learning paths do.
        storage.session.commit()
        return self.make_pretty(out)

    def make_pretty(self, out):
        """ Make the output prettier, by capitalizing, etc.

        out -- list of string tokens; it is not modified.

        Returns the joined, stripped string with the pronoun "i"
        upper-cased, the token following any token that contains a period
        capitalised, and the first character of the result upper-cased.
        """
        pretty = []
        after_period = False
        for token in out:
            if token == "i" or token.startswith("i'"):
                # Pronoun "I", alone or as the head of a contraction.
                token = "I" + token[1:]
            elif after_period:
                # capitalize word after a period
                token = token.capitalize()
            after_period = "." in token
            pretty.append(token)
        string = ''.join(pretty).strip()
        # Upper-case only the first character: str.capitalize() would
        # lower-case the rest and undo the work done above.
        return string[:1].upper() + string[1:]

    def separate_words(self, input):
        """ separate input into words and non-words

        The lower-cased input is split into alternating tokens according to
        ``config.boundaries["words"]`` and ``config.boundaries["nonwords"]``.
        A "." token is appended when the last token starts with a word
        character. Empty input yields an empty list.
        """
        input = input.lower()
        if not input:
            return []
        words = []
        bounds = config.boundaries
        # cur_boundary is opposite of the current character set: it names
        # the set whose characters end the token being built.
        cur_boundary = "words" if input[0] in bounds["nonwords"] else "nonwords"
        word = input[0]
        for char in input[1:]:
            if char in bounds[cur_boundary]:
                # Crossed a boundary: flip sets and start a new token.
                cur_boundary = "words" if cur_boundary == "nonwords" else "nonwords"
                words.append(word)
                word = char
            else:
                word += char
        words.append(word)
        # If the last symbol isn't punctuation, append a period
        if word[0] in bounds["words"]:
            words.append(".")
        return words
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment