# LSTM text-gen implementation adapted from Keras Example
# Gist by @jamiesanson, created November 14, 2017 21:11
# Gist ID: 5f1450ec6d48b35992075196400ec177
'''
Script to generate text from Chopra's writings.
At least 20 epochs are required before the generated text
starts sounding coherent.
It is recommended to run this script on GPU, as recurrent
networks are quite computationally intensive.
If you try this script on new data, make sure your corpus
has at least ~100k characters. ~1M is better.
Note: the training corpus is the concatenation of "How To Know God" and
"The Seven Spiritual Laws of Success".
This script was adapted from the LSTM text-generation example in the Keras
repo: https://github.com/fchollet/keras/blob/master/examples/lstm_text_generation.py
'''
from __future__ import print_function
from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.layers import LSTM, CuDNNLSTM
from keras.optimizers import RMSprop
import numpy as np
import random
import sys
# Load the corpus. Lower-casing shrinks the character vocabulary the model
# has to learn.
with open('chopra.txt', encoding="utf8") as corpus_file:
    text = corpus_file.read().lower()
print('corpus length:', len(text))
# Vocabulary: every distinct character. Sorting makes the index assignment
# deterministic for a given corpus (plain set order is not).
chars = sorted(set(text))
print('total chars:', len(chars))
char_indices = {c: i for i, c in enumerate(chars)}
indices_char = {i: c for i, c in enumerate(chars)}
# Cut the text into semi-redundant sequences of maxlen characters. The windows
# overlap because step < maxlen. Each window's training target is the single
# character that follows it.
maxlen = 80
step = 3
if len(text) <= maxlen:
    # No window fits, so there would be nothing to train on and later steps
    # would fail with an unrelated-looking error. Stop with a clear message.
    sys.exit('corpus must contain more than %d characters, got %d'
             % (maxlen, len(text)))
window_starts = range(0, len(text) - maxlen, step)
sentences = [text[i: i + maxlen] for i in window_starts]
next_chars = [text[i + maxlen] for i in window_starts]
print('nb sequences:', len(sentences))
print('Vectorization...')
# One-hot encode the windows:
#   x[i, t, k] is True when character t of window i is chars[k];
#   y[i, k]    is True when the character following window i is chars[k].
# The builtin ``bool`` is used as the dtype: the old ``np.bool`` alias (which
# meant exactly this type) was deprecated in NumPy 1.20 and removed in 1.24.
x = np.zeros((len(sentences), maxlen, len(chars)), dtype=bool)
y = np.zeros((len(sentences), len(chars)), dtype=bool)
for i, sentence in enumerate(sentences):
    for t, char in enumerate(sentence):
        x[i, t, char_indices[char]] = 1
    y[i, char_indices[next_chars[i]]] = 1
# Build the model: a single CuDNN-backed LSTM layer, then a dense layer with a
# softmax over the character vocabulary.
print('Build model...')
# Each input sample is maxlen time steps. Each step is a one-hot vector of
# length len(chars), the number of distinct characters in the corpus.
model = Sequential([
    CuDNNLSTM(256, input_shape=(maxlen, len(chars))),
    Dense(len(chars)),
    Activation('softmax'),
])
# `lr` (not `learning_rate`) is kept for compatibility with the Keras 2.0.x
# releases that provide CuDNNLSTM.
optimizer = RMSprop(lr=0.01)
model.compile(optimizer=optimizer, loss='categorical_crossentropy')
def sample(preds, temperature=1.0):
    """Draw one index from a probability vector re-weighted by a temperature.

    Each weight is effectively raised to the power ``1 / temperature`` and the
    result is renormalised before a single multinomial draw. Temperatures
    below 1 sharpen the distribution toward the most likely index; temperatures
    above 1 flatten it.

    Args:
        preds: 1-D sequence of non-negative weights, e.g. one softmax output
            row. It does not have to sum to 1; it is renormalised here.
        temperature: Strictly positive sampling temperature.

    Returns:
        The sampled index into ``preds``.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError('temperature must be positive, got %r' % (temperature,))
    preds = np.asarray(preds).astype('float64')
    # log(0) -> -inf is intended: such entries end up with probability 0.
    # Suppress NumPy's divide-by-zero warning for that case.
    with np.errstate(divide='ignore'):
        logits = np.log(preds) / temperature
    # Shift so the largest logit is 0. Without this, exp() can underflow every
    # entry to 0 at small temperatures, and the normalisation becomes 0/0.
    logits -= np.max(logits)
    exp_logits = np.exp(logits)
    probs = exp_logits / np.sum(exp_logits)
    draw = np.random.multinomial(1, probs, 1)
    return np.argmax(draw)
# Train one epoch at a time (49 passes in total). After each epoch, print
# 400 generated characters at every sampling temperature ("diversity").
for epoch in range(1, 50):
    print()
    print('-' * 50)
    print('Iteration', epoch)
    # A single epoch per fit() call lets us inspect output after every epoch.
    model.fit(x, y, batch_size=256, epochs=1)

    # One random seed window per epoch, shared by all diversity settings.
    seed_start = random.randint(0, len(text) - maxlen - 1)
    for diversity in (0.2, 0.5, 1.0, 1.2):
        print()
        print('----- diversity:', diversity)
        window = text[seed_start: seed_start + maxlen]
        generated = window
        print('----- Generating with seed: "' + window + '"')
        sys.stdout.write(generated)
        for _ in range(400):
            # One-hot encode the current window as a batch of size 1.
            encoded = np.zeros((1, maxlen, len(chars)))
            for pos, ch in enumerate(window):
                encoded[0, pos, char_indices[ch]] = 1.
            probabilities = model.predict(encoded, verbose=0)[0]
            next_char = indices_char[sample(probabilities, diversity)]
            # Append the prediction and slide the window forward one character.
            generated += next_char
            window = window[1:] + next_char
            sys.stdout.write(next_char)
            sys.stdout.flush()
        print()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment