Skip to content

Instantly share code, notes, and snippets.

Created Mar 5, 2016
What would you like to do?
#!/usr/bin/env python
# This is a simple script to set up a Twitter 'bot' based on a character-level recurrent neural network. Clone sherjilozair's
# char-rnn-tensorflow (https://github.com/sherjilozair/char-rnn-tensorflow) and train it on the material of your choice.
# Then drop this script into the main directory, create a Twitter account and Twitter app for the bot and enter the
# relevant authentication information at the commented points below. Run this script and whenever somebody
# @mentions the bot it will reply with a sample from your neural network.
# Louis Goddard <>
import numpy as np
import tensorflow as tf
import argparse
import time
import os
import re
import cPickle
from utils import TextLoader
from model import Model
from twython import Twython
# Twitter API credentials for the bot's app. Each value is read from the named
# environment variable when it is set; otherwise the string literal default is
# used. You can paste credentials into the defaults, but exporting them keeps
# secrets out of the source file.
APP_KEY = os.environ.get('TWITTER_APP_KEY', '')  # consumer key
APP_SECRET = os.environ.get('TWITTER_APP_SECRET', '')  # consumer secret
OAUTH_TOKEN = os.environ.get('TWITTER_OAUTH_TOKEN', '')  # access token
OAUTH_TOKEN_SECRET = os.environ.get('TWITTER_OAUTH_TOKEN_SECRET', '')  # access token secret
def main():
    """Parse command-line options and run the reply bot.

    Options:
        --save_dir: directory holding config.pkl, chars_vocab.pkl and the
            TensorFlow checkpoint produced by char-rnn-tensorflow.
        -n: number of characters to sample.
        --prime: prime text.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--save_dir', type=str, default='save',
                        help='model directory to store checkpointed models')
    parser.add_argument('-n', type=int, default=140,
                        help='number of characters to sample')
    parser.add_argument('--prime', type=str, default=' ',
                        help='prime text')
    args = parser.parse_args()
    # Hand off to the bot loop. Without this call the script would parse its
    # arguments and then exit without doing anything.
    sample(args)
def _reply_room(target, limit=140):
    """Return how many generated characters fit in a reply to *target*.

    A reply status is ``'@' + target + ' '`` followed by generated text. The
    whole status must stay within *limit* characters; 140 matches the length
    this script was written against.
    """
    return max(limit - len('@' + target + ' '), 0)


def _compose_reply(target, sampled, prime, limit=140):
    """Build the reply status for *target* from one model sample.

    *sampled* is the text returned by ``Model.sample``. As in the original
    script, it is assumed to begin with *prime*, and that prefix is removed so
    the bot does not echo the mention back. The generated remainder is
    truncated so the complete status, including the ``@target `` prefix, is at
    most *limit* characters long.
    """
    generated = sampled[len(prime):]
    return '@' + target + ' ' + generated[:_reply_room(target, limit)]


def _clean_prime(text, vocab, fallback=' '):
    """Derive the model's priming string from a mention's text.

    Keeps only ASCII letters and digits, as the original filtering did. It then
    drops any character missing from *vocab*; this assumes ``Model.sample``
    looks every prime character up in *vocab* (TODO confirm against
    model.py). Returns *fallback* when nothing is left, so the sampler is
    never primed with an empty string, which it presumably cannot handle
    (TODO confirm).
    """
    kept = ''.join(ch for ch in re.sub('[^A-Za-z0-9]+', '', text) if ch in vocab)
    return kept or fallback


def sample(args, poll_interval=60):
    """Run the reply bot: answer each new @mention with text from the model.

    Loads the saved config and vocabulary pickles from ``args.save_dir`` and
    restores the latest checkpoint found there. It then polls the bot's
    mentions timeline forever, posting one sampled reply per new mention. If
    no checkpoint exists, it prints a message and returns.

    Args:
        args: parsed options. Reads ``save_dir``; ``n``, the maximum number of
            characters sampled per reply; and ``prime``, the fallback prime
            used when a mention yields no usable characters.
        poll_interval: seconds to sleep between polls. Twitter rate-limits the
            mentions timeline, so polling in a tight loop would get throttled.
    """
    # Open the pickles in binary mode, because pickle data is bytes. These
    # files are presumably the trusted output of the training run; never
    # point save_dir at pickles from an untrusted source.
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    model = Model(saved_args, True)
    # Authenticated client used for every API call below.
    twitter = Twython(APP_KEY, APP_SECRET, OAUTH_TOKEN, OAUTH_TOKEN_SECRET)
    with tf.Session() as sess:
        saver = tf.train.Saver(tf.all_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            print('No checkpoint found in ' + args.save_dir)
            return
        saver.restore(sess, ckpt.model_checkpoint_path)

        # Only mentions with an id newer than since_id are answered. Start
        # from the newest tweet on the home timeline. If that is empty, start
        # from the newest existing mention, so an old backlog is not answered.
        latest = (twitter.get_home_timeline(count=1)
                  or twitter.get_mentions_timeline(count=1))
        since_id = latest[0]['id'] if latest else None

        while True:
            params = {'contributor_details': True}
            if since_id is not None:
                params['since_id'] = since_id
            mentions = twitter.get_mentions_timeline(**params)
            # Answer in ascending id order and advance the watermark on each
            # mention. After the loop, since_id is the newest id answered,
            # whatever order the API returned the mentions in.
            for mention in sorted(mentions, key=lambda m: m['id']):
                since_id = mention['id']
                target = mention['user']['screen_name']
                prime = _clean_prime(mention['text'], vocab, args.prime)
                # Sampling more than fits in the reply would be wasted work.
                num = min(args.n, _reply_room(target))
                sampled = str(model.sample(sess, chars, vocab, num, prime))
                twitter.update_status(status=_compose_reply(target, sampled, prime),
                                      in_reply_to_status_id=mention['id'])
            time.sleep(poll_interval)
# Start the bot only when run as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment