Skip to content

Instantly share code, notes, and snippets.

@magcius
Created August 11, 2009 22:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save magcius/166164 to your computer and use it in GitHub Desktop.
Save magcius/166164 to your computer and use it in GitHub Desktop.
import random
import storage
import config
import data
#===============================================================#
def write_model(mod, f, level=0):
    """Write a nested model mapping to the file-like object ``f``.

    Each value of ``mod`` is indexed as ``[0]`` (a child mapping of the same
    shape, written recursively one tab deeper) and ``[1]`` (written in
    brackets before the quoted symbol), producing lines such as::

        [3] "word"
        \t[1] "next"

    ``level`` is the current nesting depth and controls the tab indent.
    """
    # TODO: modify
    tabs = "\t" * level
    # Iterating the mapping directly works on Python 2 and 3 (iterkeys is 2-only).
    for symbol in mod:
        entry = mod[symbol]
        f.write(tabs + "[" + str(entry[1]) + "] \"" + symbol + "\"\n")
        write_model(entry[0], f, level + 1)
#===============================================================#
class Brain(object):
    """Symbol model: learns (order + 1)-grams and generates replies.

    Learning (``learn``/``addsymbols``) writes into the ``storage`` module.
    Reply generation (``reply``/``seed``/``babble``/``child``) walks nested
    ``(children, total)`` tuples, where ``children`` maps a symbol to its own
    ``(children, count)`` tuple.
    """

    def __init__(self, conn_str, order=4):
        self.order = order
        storage.init_storage(conn_str)
        # Temporary per-reply state: keywords of the input currently being
        # answered; child() prefers these symbols when they appear.
        self.cur_keys = set()

    def addsymbols(self, symbols):
        """Add every (order + 1)-symbol window of ``symbols`` to storage.

        Windows are processed last-to-first. The first node of each window
        (its "subroot") is recorded; a newly created leaf links to the
        previous round's subroot, and once all windows are stored each
        subroot is linked to the one recorded before it.
        """
        width = self.order + 1
        # Sliding windows: for [1, 2, 3, 4, 5] and width 3 this produces
        # (1, 2, 3), (2, 3, 4), (3, 4, 5). list() is needed because
        # reversed() requires a sequence and zip() may return an iterator.
        chunks = list(zip(*(symbols[n:] for n in range(width))))
        subroots = []
        for chunk in reversed(chunks):
            node = storage.get_root()
            last = len(chunk) - 1
            for i, symbol in enumerate(chunk):
                if symbol in node:
                    print("Found node for symbol %r" % symbol)
                    node = node[symbol]
                    node.count += 1
                elif i == last:  # leaf node
                    print("Creating leaf node for symbol %r" % symbol)
                    # This round's subroot was already appended at i == 0, so
                    # the second-from-last entry is the previous round's one.
                    link = subroots[-2] if len(subroots) >= 2 else None
                    node = storage.LeafNode(symbol, parent=node, link=link)
                else:
                    print("Creating parent node for symbol %r" % symbol)
                    node = storage.ParentNode(symbol, parent=node)
                if i == 0:  # subroot node (found or created)
                    subroots.append(node)
        for later, earlier in zip(subroots[1:], subroots):
            later.link = earlier

    def learn(self, symbols):
        """Add ``symbols`` to the model.

        Inputs with ``order`` or fewer symbols cannot form a single
        (order + 1)-symbol window and are ignored.
        """
        if len(symbols) > self.order:
            self.addsymbols(symbols)

    def get_keywords(self, symbols):
        """Return the symbols that qualify as keywords, in input order.

        A symbol qualifies when it is non-empty, not banned, and its first
        character is not in ``config.boundaries["nonwords"]`` (so it is not
        punctuation).
        """
        # NOTE(review): the original referenced an undefined global
        # ``banned_keywords``; assumed to live in ``config`` -- confirm.
        banned = getattr(config, "banned_keywords", ())
        nonwords = config.boundaries["nonwords"]
        return [symbol for symbol in symbols
                if symbol and symbol not in banned and symbol[0] not in nonwords]

    def choose_keyword(self, keywords):
        """Return a random keyword for which ``storage.symbol_exists`` is
        true, or None when no keyword qualifies."""
        acceptable = [keyword for keyword in keywords
                      if storage.symbol_exists(keyword)]
        if not acceptable:
            return None
        return random.choice(acceptable)

    def child(self, tree):
        """Pick a child symbol of ``tree`` by count-weighted random choice.

        ``tree`` is ``(children, total)``: ``children`` maps each symbol to a
        ``(subtree, count)`` pair and ``total`` bounds the random draw.
        A child that is one of ``self.cur_keys`` is returned immediately with
        ~90% probability. Returns None when ``tree`` has no children or the
        counts never reach the drawn value.
        """
        children, total = tree[0], tree[1]
        if not children:
            return None
        target = random.randint(1, total)
        tally = 0
        for symbol in children:
            if self.cur_keys and symbol in self.cur_keys \
                    and random.random() > 0.1:
                return symbol
            tally += children[symbol][1]
            if tally >= target:
                return symbol
        return None

    def seed(self, model, keyword):
        """Yield up to ``order`` symbols that branch off from ``keyword``.

        Each yielded symbol is a random child of the previous context; the
        walk stops early when a context has no children.
        """
        ctx = model[keyword]
        for _ in range(self.order):
            word = self.child(ctx)
            if word is None:
                return
            ctx = ctx[0][word]
            yield word

    def babble(self, model, symbols):
        """Append the next symbol predicted by the last ``order`` symbols.

        ``symbols`` is extended in place. Returns the appended symbol, or
        None (appending nothing) when the context has no continuation.
        """
        # Context is the last `order` symbols (fewer when the list is short);
        # computing the start index avoids symbols[-0:] when order == 1.
        start = max(len(symbols) - self.order, 0)
        ctx = model[symbols[start]]
        # Walk the tree down the remaining context to reach the next word.
        for symbol in symbols[start + 1:]:
            ctx = ctx[0][symbol]
        word = self.child(ctx)
        if word is not None:
            symbols.append(word)
        return word

    def _extend(self, model, keyword):
        """Return symbols generated after ``keyword`` in ``model``, up to but
        not including the first ``config.sentence_term``."""
        term = config.sentence_term
        generated = list(self.seed(model, keyword))
        if term not in generated:
            sentence = [keyword] + generated
            # Stop on a dead end instead of looping on an unchanged list.
            while sentence[-1] != term:
                if self.babble(model, sentence) is None:
                    break
            generated = sentence[1:]
        if term in generated:
            generated = generated[:generated.index(term)]
        return generated

    def reply(self, symbols):
        """Generate a reply to ``symbols`` as a list of symbols.

        Step 1: extract keywords (applying ``config.swap`` substitutions) and
        pick one; otherwise try up to 20 random topics. Step 2: grow the
        sentence forward and backward from that keyword. Returns an empty
        list when no keyword can be found.

        NOTE(review): ``self.forward`` / ``self.backward`` are never assigned
        in this class; they must be provided elsewhere or this will raise
        AttributeError.
        """
        swap = config.swap
        keywords = [swap[k] if k in swap else k
                    for k in self.get_keywords(symbols)]
        use_keyword = self.choose_keyword(keywords)
        if use_keyword is None:
            # Try 20 times to get a random topic known in both directions.
            topics = list(self.forward.keys())
            for _ in range(20):
                if not topics:
                    break
                candidate = random.choice(topics)
                if candidate in self.backward:
                    use_keyword = candidate
                    break
        self.cur_keys = set(keywords)
        if use_keyword is None:
            return []
        forward = self._extend(self.forward, use_keyword)
        backward = self._extend(self.backward, use_keyword)
        backward.reverse()
        return backward + [use_keyword] + forward
#===============================================================#
class GigaHal(object):
    """Front end that splits text into symbols and drives a ``Brain``."""

    def __init__(self, conn_str, order=4):
        self.brain = Brain(conn_str, order)
        self.version = config.version
        # Bootstrap a nearly empty store from the bundled training data.
        if storage.known_items() < 5:
            for i in data.train_data:
                self.brain.learn(self.separate_words(i))
            storage.session.commit()

    def input_no_reply(self, string):
        """ learn from input but don't generate a reply """
        words = self.separate_words(string)
        self.brain.learn(words)
        storage.session.commit()

    def input_with_reply(self, string):
        """Generate a reply from the input, then learn from the input.

        The reply is produced before learning so it reflects prior knowledge.
        Learned data is committed, as in ``input_no_reply``.
        """
        words = self.separate_words(string)
        out = self.brain.reply(words)
        self.brain.learn(words)
        storage.session.commit()
        return self.make_pretty(out)

    def make_pretty(self, out):
        """Join the symbols in ``out`` into a display string.

        Turns "i" and "i'..." symbols into "I"/"I'...". Upper-cases the
        first letter of the symbol following any symbol that contains ".".
        Upper-cases the first letter of the stripped result. Other
        characters keep their case, and ``out`` is not modified.
        """
        words = list(out)
        for i, word in enumerate(words):
            if word == "i" or word.startswith("i'"):
                words[i] = "I" + word[1:]
            # capitalize word after a period (guarding the last symbol)
            if "." in word and i + 1 < len(words):
                following = words[i + 1]
                words[i + 1] = following[:1].upper() + following[1:]
        text = "".join(words).strip()
        # Unlike str.capitalize(), this leaves the remaining letters alone.
        return text[:1].upper() + text[1:]

    def separate_words(self, input):
        """Split ``input`` (lowercased) into alternating word / non-word runs.

        Runs are delimited by the character sets in ``config.boundaries``
        ("words" and "nonwords"); a character in neither set extends the
        current run. A "." symbol is appended when the last run starts with
        a word character. Empty input yields an empty list.
        """
        text = input.lower()
        if not text:
            return []
        bounds = config.boundaries
        # The boundary is the set that ends the current run: a run starting
        # with a non-word char ends at the next word char, and vice versa.
        boundary = "words" if text[0] in bounds["nonwords"] else "nonwords"
        words = []
        word = text[0]
        for char in text[1:]:
            if char in bounds[boundary]:
                boundary = "words" if boundary == "nonwords" else "nonwords"
                words.append(word)
                word = char
            else:
                word += char
        words.append(word)
        # If the last symbol isn't punctuation, append a period
        if word[0] in bounds["words"]:
            words.append(".")
        return words
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment