@scturtle
Created April 13, 2017 06:40
Character prediction with LSTM in Tensorflow
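Trains a single-layer character-level LSTM on the text8 corpus, printing the average cross-entropy each epoch and sampling a short random string from the model as training progresses.
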
from __future__ import print_function

import string
import zipfile

import numpy as np
import tensorflow as tf
import tensorflow.contrib.layers as layers

class BatchGenerator:
    def __init__(self, text, vocabulary, batch_size, num_unrollings):
        self.text = text
        self.batch_size = batch_size
        self.vocabulary = vocabulary
        self.num_unrollings = num_unrollings
        # one read cursor per batch row, spread evenly across the text
        segment_size = len(text) // batch_size
        self.cursor = [i * segment_size for i in range(batch_size)]
        self.last_batch = self.next_batch()

    def next_batch(self):
        # one-hot encode the character under each cursor, then advance
        batch = np.zeros((self.batch_size, len(self.vocabulary)))
        for b in range(self.batch_size):
            c = self.text[self.cursor[b]]
            batch[b, self.vocabulary.index(c)] = 1.0
            self.cursor[b] = (self.cursor[b] + 1) % len(self.text)
        return batch

    def next(self):
        # num_unrollings + 1 consecutive batches, so the caller can use
        # them shifted by one step as inputs and labels
        batches = [self.last_batch]
        for step in range(self.num_unrollings):
            batches.append(self.next_batch())
        self.last_batch = batches[-1]
        return np.asarray(batches)

class LSTM:
    def __init__(self, NUM_VOC):
        # time-major tensors of shape (num_steps, batch_size, NUM_VOC)
        self.inputs = tf.placeholder(tf.float32, (None, None, NUM_VOC))
        self.outputs = tf.placeholder(tf.float32, (None, None, NUM_VOC))
        cell = tf.contrib.rnn.BasicLSTMCell(512, state_is_tuple=True)
        rnn_outputs, rnn_states = tf.nn.dynamic_rnn(
            cell, self.inputs, dtype=tf.float32, time_major=True)
        # project the hidden states to per-character logits
        raw_outputs = layers.fully_connected(
            rnn_outputs, NUM_VOC, activation_fn=None)
        self.prediction = tf.nn.softmax(raw_outputs)
        # categorical cross-entropy against the one-hot next-character
        # labels, matching the softmax used for prediction
        self.error = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
            logits=raw_outputs, labels=self.outputs))
        self.train = tf.train.AdamOptimizer(learning_rate=0.01).minimize(self.error)

def main():
    text = tf.compat.as_str(zipfile.ZipFile('text8.zip').read('text8'))
    # text8 contains only spaces and lowercase ascii letters
    VOCABULARY = ' ' + string.ascii_lowercase
    NUM_VOC = len(VOCABULARY)
    train_batches = BatchGenerator(text, VOCABULARY, 64, 10)
    sess = tf.Session()
    lstm = LSTM(NUM_VOC)
    sess.run(tf.global_variables_initializer())
    NUM_EPOCH = 300
    NUM_ITER = 100
    errs = []
    for step in range(1, NUM_EPOCH * NUM_ITER + 1):
        batches = train_batches.next()
        # inputs are the first num_unrollings steps, labels are the
        # same steps shifted forward by one character
        inputs = batches[:-1, :, :]
        outputs = batches[1:, :, :]
        err, _ = sess.run([lstm.error, lstm.train],
                          feed_dict={lstm.inputs: inputs,
                                     lstm.outputs: outputs})
        errs.append(err)
        if step % NUM_ITER == 0:
            print('epoch:', step // NUM_ITER, 'avg err:', sum(errs) / len(errs))
            errs = []
            # random generation: seed with one random character, then
            # repeatedly sample the next character from the model
            inputs = np.zeros((1, 1, NUM_VOC))
            idx = np.random.randint(0, NUM_VOC)
            inputs[0, 0, idx] = 1.0
            for i in range(20):
                # rerun the network on (at most) the last 10 characters
                pred = sess.run(lstm.prediction,
                                feed_dict={lstm.inputs: inputs[-10:, :, :]})
                idx = np.random.choice(NUM_VOC, p=pred[-1, 0])
                newcol = np.zeros((1, 1, NUM_VOC))
                newcol[0, 0, idx] = 1.0
                inputs = np.concatenate([inputs, newcol], axis=0)
            s = ''.join(VOCABULARY[np.argmax(inputs[i, 0])]
                        for i in range(inputs.shape[0]))
            print('gen:', repr(s))

if __name__ == '__main__':
    main()
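
The script expects text8.zip in the working directory. A minimal way to fetch it, assuming the usual mirror at mattmahoney.net is still available:

import os
try:
    from urllib.request import urlretrieve  # Python 3
except ImportError:
    from urllib import urlretrieve  # Python 2
if not os.path.exists('text8.zip'):
    urlretrieve('http://mattmahoney.net/dc/text8.zip', 'text8.zip')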