Skip to content

Instantly share code, notes, and snippets.

@ltrgoddard
Created March 5, 2016 16:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ltrgoddard/78b4ad7bb8df8b16b00d to your computer and use it in GitHub Desktop.
Save ltrgoddard/78b4ad7bb8df8b16b00d to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# This is a simple script to set up a Twitter 'bot' based on a character-level recurrent neural network. Clone sherjilozair's
# char-rnn-tensorflow (https://github.com/sherjilozair/char-rnn-tensorflow) and train it on the material of your choice.
# Then drop this script into the main directory, create a Twitter account and Twitter app for the bot and enter the
# relevant authentication information at the commented points below. Run this script and whenever somebody
# @mentions the bot it will reply with a sample from your neural network.
# Louis Goddard <louisgoddard@gmail.com>
import numpy as np
import tensorflow as tf
import argparse
import time
import os
import re
import cPickle
from utils import TextLoader
from model import Model
from twython import Twython
# Twitter app credentials for the bot. They are read from environment
# variables so the secrets never have to be written into (and accidentally
# committed with) this script; the '' fallbacks keep the old behaviour, and
# you may still paste values there if you prefer.
APP_KEY = os.environ.get('TWITTER_APP_KEY', '')  # consumer key
APP_SECRET = os.environ.get('TWITTER_APP_SECRET', '')  # consumer secret
OAUTH_TOKEN = os.environ.get('TWITTER_OAUTH_TOKEN', '')  # access token
OAUTH_TOKEN_SECRET = os.environ.get('TWITTER_OAUTH_TOKEN_SECRET', '')  # access token secret
# Module-level client shared by the reply loop.
twitter = Twython(APP_KEY, APP_SECRET, OAUTH_TOKEN, OAUTH_TOKEN_SECRET)
def main():
    """Read the command-line options and pass them to sample() to run the bot.

    Recognised options: --save_dir (default 'save'), -n (default 140) and
    --prime (default ' ').
    """
    option_specs = [
        (('--save_dir',), {'type': str, 'default': 'save',
                           'help': 'model directory to store checkpointed models'}),
        (('-n',), {'type': int, 'default': 140,
                   'help': 'number of characters to sample'}),
        (('--prime',), {'type': str, 'default': ' ',
                        'help': 'prime text'}),
    ]
    cli = argparse.ArgumentParser()
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)
    sample(cli.parse_args())
def sample(args):
    """Load the trained char-rnn model and answer @mentions until killed.

    Unpickles the training config and the (chars, vocab) pair from
    ``args.save_dir``, restores the latest checkpoint found there, then
    hands off to ``_reply_forever``, which never returns.

    Args:
        args: namespace from ``main``. ``save_dir`` locates the model files;
            ``prime`` (if present) is the fallback prime used when a mention
            yields no usable characters.
    """
    saved_args = _load_pickle(os.path.join(args.save_dir, 'config.pkl'))
    chars, vocab = _load_pickle(os.path.join(args.save_dir, 'chars_vocab.pkl'))
    model = Model(saved_args, True)
    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            # Carry on as before, but say so: replies will be sampled from
            # the freshly initialised (untrained) variables.
            print('warning: no checkpoint in %s; sampling from untrained weights'
                  % args.save_dir)
        _reply_forever(sess, model, chars, vocab, getattr(args, 'prime', ' '))


def _load_pickle(path):
    """Return the object unpickled from the file at *path*."""
    # Pickles are binary data: open in 'rb' so newline translation can never
    # corrupt them (the original opened these files in text mode).
    with open(path, 'rb') as f:
        return cPickle.load(f)


def _make_prime(text, vocab, fallback):
    """Turn a tweet's text into a prime string for the model.

    Every run of non-alphanumeric characters is deleted, then any character
    not present in ``vocab`` is dropped; if nothing remains, *fallback* is
    returned instead.
    """
    # NOTE(review): this also deletes spaces, so words run together in the
    # prime; kept to preserve the original priming, but replacing runs with
    # ' ' may have been the intent.
    cleaned = re.sub(r'[^A-Za-z0-9]+', '', text)
    # Assumption (char-rnn-tensorflow's model.py): Model.sample looks each
    # prime character up in vocab and needs a non-empty prime, so unknown
    # characters and empty primes would crash it — confirm against model.py.
    cleaned = ''.join(c for c in cleaned if c in vocab)
    return cleaned or fallback


def _reply_forever(sess, model, chars, vocab, fallback_prime,
                   poll_seconds=60, tweet_limit=140):
    """Poll the bot's mentions and reply to each new one; never returns.

    Each reply is '@<sender> ' followed by text sampled from *model*, primed
    with the cleaned mention text and truncated so the whole status is at
    most *tweet_limit* characters. Twython errors while polling or posting
    are printed and skipped so one bad request (rate limit, rejected status,
    network blip) does not stop the bot; a failure of the initial
    home-timeline fetch still raises.
    """
    from twython import TwythonError  # twython is already this script's dependency

    # Only answer mentions newer than the newest tweet on the bot's home
    # timeline. A brand-new account can have an empty home timeline (the
    # original crashed with IndexError there); then poll without since_id.
    latest = twitter.get_home_timeline(count=1)
    since_id = latest[0]['id'] if latest else None
    while True:
        params = {'contributor_details': True}
        if since_id is not None:
            params['since_id'] = since_id
        try:
            mentions = twitter.get_mentions_timeline(**params)
        except TwythonError as err:
            print('could not fetch mentions: %s' % err)
            mentions = []
        # The timeline comes back newest first; answer in arrival order.
        for mention in reversed(mentions):
            mention_id = mention['id']
            # Track the NEWEST id seen. The original kept the last id of the
            # newest-first list (the oldest), so the other mentions in the
            # batch were fetched and answered again on the next poll.
            if since_id is None or mention_id > since_id:
                since_id = mention_id
            prefix = '@' + mention['user']['screen_name'] + ' '
            # Room left for generated text so the whole status fits. The
            # original overran 140 chars whenever the cleaned mention text
            # was not longer than the sender's screen name.
            budget = max(tweet_limit - len(prefix), 0)
            prime = _make_prime(mention['text'], vocab, fallback_prime)
            output = str(model.sample(sess, chars, vocab, budget, prime))
            # Assumes model.sample returns the prime followed by the sampled
            # characters (the original sliced its output the same way).
            generated = output[len(prime):len(prime) + budget]
            try:
                twitter.update_status(status=prefix + generated,
                                      in_reply_to_status_id=mention_id)
            except TwythonError as err:
                print('could not reply to %s: %s' % (mention_id, err))
        time.sleep(poll_seconds)
# Start the bot only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment