"""A character-level LSTM for text generation, based on Lasagne's text generation recipe."""
import numpy as np
import lasagne as L
import theano as T
import random
import sys
def info(info_string):
    """Print ``info_string`` to standard output, prefixed with ``[INFO] ``.

    Args:
        info_string: The message to display; it is interpolated with
            ``str.format``, so any object with a string form is accepted.
    """
    # The parenthesised single-argument form produces the same output under
    # the Python 2 print statement and the Python 3 print function.
    print('[INFO] {}'.format(info_string))
def build_network(batch_size, seq_length, vocabulary_size, num_hidden, learning_r):
    """Build the LSTM -> softmax network and compile its Theano functions.

    Args:
        batch_size: Not used by the network itself; the batch dimension is
            read symbolically from the input variable instead.
        seq_length: Number of time steps in every input sequence.
        vocabulary_size: Size of the one-hot character encoding; also the
            number of output units.
        num_hidden: Number of units in the LSTM layer.
        learning_r: Learning rate handed to the Adam update rule.

    Returns:
        A ``(train, get_output)`` pair of compiled Theano functions.
        ``train(inputs, targets)`` applies one Adam update and returns the
        mean categorical cross-entropy loss. ``get_output(inputs)`` returns
        the last row of the flattened softmax output, i.e. the prediction
        that follows the final time step of the last sequence in the batch.
    """
    layers = L.layers

    # Input is (batch size, sequence length, number of characters). Leaving
    # the batch dimension as `None` lets the network take any batch size > 0.
    l_in = layers.InputLayer(shape=(None, seq_length, vocabulary_size))
    l_lstm = layers.LSTMLayer(l_in, num_hidden)

    # The actual batch size is only known symbolically, so it is taken from
    # the input variable's shape to compute the flattened dimension below.
    symbolic_batch, _, _ = l_in.input_var.shape

    # Merge the batch and time axes so the dense layer sees a 2D matrix with
    # one row per (sequence, time step) pair.
    flat_rows = symbolic_batch * seq_length
    l_shape = layers.ReshapeLayer(l_lstm, (flat_rows, num_hidden))
    l_out = layers.DenseLayer(
        l_shape,
        num_units=vocabulary_size,
        nonlinearity=L.nonlinearities.softmax,
    )

    target_values = T.tensor.imatrix('target_output')
    network_output = layers.get_output(l_out)
    loss = T.tensor.nnet.categorical_crossentropy(network_output, target_values).mean()

    all_params = layers.get_all_params(l_out, trainable=True)
    updates = L.updates.adam(loss, all_params, learning_rate=learning_r)

    train = T.function(
        [l_in.input_var, target_values],
        loss,
        updates=updates,
        allow_input_downcast=True,
    )
    get_output = T.function(
        [l_in.input_var],
        network_output[-1, :],
        allow_input_downcast=True,
    )
    return train, get_output
def get_batch(data, start, batch_size, seq_length, vocabulary_size, char_to_ix, calc_Y=True):
    """Encode one batch of overlapping character sequences as one-hot arrays.

    Sequence ``s`` of the batch covers
    ``data[start + s : start + s + seq_length]``, so neighbouring sequences
    are offset from each other by a single character.

    Args:
        data: Indexable sequence of characters (e.g. a string).
        start: Index in ``data`` of the first character of the first sequence.
        batch_size: Number of sequences in the batch.
        seq_length: Number of characters per sequence.
        vocabulary_size: Length of the one-hot encoding axis.
        char_to_ix: Mapping from character to its one-hot index.
        calc_Y: When false (e.g. while sampling) the targets are not
            computed and ``Y`` is returned filled with zeros.

    Returns:
        ``(X, Y)`` where ``X`` has shape
        ``(batch_size, seq_length, vocabulary_size)`` and ``Y`` has shape
        ``(batch_size * seq_length, vocabulary_size)``. Each row of ``Y`` is
        the one-hot encoding of the character following the matching
        position in ``X``.

    Raises:
        IndexError: If the requested window runs past the end of ``data``.
        KeyError: If a character is missing from ``char_to_ix``.
    """
    # X contains a training batch
    X = np.zeros((batch_size, seq_length, vocabulary_size))
    # Y contains the targets for each time step in X (i.e. it is X shifted by one)
    Y = np.zeros((batch_size, seq_length, vocabulary_size))

    # `range` is used for both loops: `xrange` only exists on Python 2.
    for sequence in range(batch_size):
        for character in range(seq_length):
            index = start + sequence + character
            X[sequence, character, char_to_ix[data[index]]] = 1.0
            # When sampling, we don't need Y (and data[index + 1] may not exist)
            if calc_Y:
                Y[sequence, character, char_to_ix[data[index + 1]]] = 1.0

    # Y must be reshaped to match the flattened output of the network
    return X, Y.reshape((batch_size * seq_length, vocabulary_size))
def sample(seed, num_characters, network_output_function, seq_length, vocabulary_size, char_to_ix, ix_to_char):
    """Generate ``num_characters`` characters continuing ``seed`` and print them.

    The last ``seq_length`` characters of ``seed`` form the initial input
    window. At every step a character index is drawn from the distribution
    returned by ``network_output_function`` and the window is advanced by
    one position so that the new character becomes its final time step.

    Args:
        seed: Text used to prime the network; needs at least ``seq_length``
            characters, all present in ``char_to_ix``.
        num_characters: How many characters to generate.
        network_output_function: Callable taking the one-hot input window of
            shape ``(1, seq_length, vocabulary_size)`` and returning
            ``vocabulary_size`` probabilities for the next character.
        seq_length: Length of the input window.
        vocabulary_size: Number of distinct characters.
        char_to_ix: Mapping from character to one-hot index.
        ix_to_char: Mapping from one-hot index back to character.
    """
    assert len(seed) >= seq_length, 'seed must be at least seq_length characters long'
    sample_ix = []

    # Encode the tail of the seed as the initial input window
    X, _ = get_batch(seed, len(seed) - seq_length, 1, seq_length, vocabulary_size, char_to_ix, calc_Y=False)

    for _step in range(num_characters):
        probabilities = network_output_function(X).ravel()
        ix = np.random.choice(vocabulary_size, p=probabilities)
        # Record the draw: it is needed both for the final text and for the
        # next input window (the original never stored it, so reading
        # sample_ix[-1] failed on the first iteration).
        sample_ix.append(ix)

        # Drop the oldest time step by shifting the window one position
        # towards the start, then one-hot encode the new character at the end.
        X[0, :-1, :] = X[0, 1:, :]
        X[0, -1, :] = 0.0
        X[0, -1, ix] = 1.0

    generated = '[' + seed + ']' + ''.join(ix_to_char[ix] for ix in sample_ix)
    print('----\n{}\n----'.format(generated))
def main():
    """Train the character-level LSTM on ``shakespeare.txt`` until interrupted.

    Reads the corpus, builds the character vocabulary, compiles the network
    and then trains epoch after epoch on randomly positioned batches. After
    every epoch the mean loss and a 150-character sample are printed.
    Training only stops on CTRL+C (``KeyboardInterrupt``).

    Raises:
        ValueError: If the corpus is too short to draw a single batch from.
    """
    batch_size = 10
    seq_length = 25
    num_hidden = 50
    learning_r = 0.01

    info('Loading text')
    # The context manager closes the file even if reading or decoding fails.
    with open('shakespeare.txt', 'r') as text_file:
        # Decoding as utf-8-sig drops a leading byte order mark, if any; the
        # text is then re-encoded to a UTF-8 byte string.
        in_text = text_file.read().decode("utf-8-sig").encode("utf-8")
    seed = "That, poor contempt, or claim'd thou slept so "

    in_text_size = len(in_text)
    # Largest start position that is drawn for a batch (see randint below).
    max_start = in_text_size - (batch_size * seq_length) - 2
    if max_start < 0:
        raise ValueError(
            'Training text is too short: need at least {} characters, got {}'.format(
                batch_size * seq_length + 2, in_text_size))

    # Sorting gives the same character <-> index mapping on every run instead
    # of depending on set iteration order.
    vocabulary = sorted(set(in_text))
    vocabulary_size = len(vocabulary)
    char_to_ix = {ch: i for i, ch in enumerate(vocabulary)}
    ix_to_char = {i: ch for i, ch in enumerate(vocabulary)}

    # NOTE: no RNG seed is set here, so batch selection, weight
    # initialisation and sampling differ from run to run.
    info('Building network')
    train, get_output = build_network(batch_size, seq_length, vocabulary_size, num_hidden, learning_r)

    # Floor division keeps this an integer on both Python 2 and Python 3.
    batches_per_epoch = in_text_size // batch_size
    epoch = 0
    try:  # except KeyboardInterrupt: press CTRL+C to stop
        while True:
            iterations = 0
            loss = 0.0
            while iterations < batches_per_epoch:
                # Select a random start position for the next batch
                start = random.randint(0, max_start)
                # Encode the next batch of sequences
                X, Y = get_batch(in_text, start, batch_size, seq_length, vocabulary_size, char_to_ix)
                loss += train(X, Y)
                iterations += 1
            # The length check above guarantees at least one batch per epoch.
            mean_loss = loss / iterations
            epoch += 1
            print('Epoch {} loss = {}'.format(epoch, mean_loss))
            sample(seed, 150, get_output, seq_length, vocabulary_size, char_to_ix, ix_to_char)
    except KeyboardInterrupt:
        print('CTRL+C detected: exiting...')
if __name__ == '__main__':
