Text generator in Python
import numpy as np
import tensorflow as tf
from tensorflow.contrib import seq2seq


def get_inputs():
    """
    Create TF placeholders for input, targets, and learning rate.
    """
    return (tf.placeholder(tf.int32, shape=(None, None), name="input"),
            tf.placeholder(tf.int32, shape=(None, None), name="targets"),
            tf.placeholder(tf.float32, name="learning_rate"))


def get_init_cell(batch_size, rnn_size):
    """
    Create an RNN cell and initialize it.
    """
    lstm_layers = 1
    lstm = tf.contrib.rnn.BasicLSTMCell(rnn_size)
    cell = tf.contrib.rnn.MultiRNNCell([lstm] * lstm_layers)
    initial_state = tf.identity(cell.zero_state(batch_size, tf.float32), name='initial_state')
    return cell, initial_state


def build_nn(cell, rnn_size, input_data, vocab_size):
    """
    Build part of the neural network
    """
    embed = tf.contrib.layers.embed_sequence(input_data, vocab_size, 300)
    outputs, final_state = tf.nn.dynamic_rnn(cell, embed, dtype=tf.float32)
    final_state = tf.identity(final_state, name='final_state')
    logits = tf.contrib.layers.fully_connected(outputs, num_outputs=vocab_size, activation_fn=None)
    return logits, final_state
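
# Shape sketch for build_nn (illustrative, using the parameters chosen below):
# input_data is [batch_size, seq_length] word ids, embed is
# [batch_size, seq_length, 300], outputs is [batch_size, seq_length, rnn_size],
# and logits is [batch_size, seq_length, vocab_size].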


def get_batches(int_text, batch_size, seq_length):
    """
    Return batches of input and target
    :param int_text: Text with the words replaced by their ids
    :param batch_size: The size of batch
    :param seq_length: The length of sequence
    :return: Batches as a Numpy array
    """
    print("len = {}, batch_size = {}, seq length = {}".format(len(int_text), batch_size, seq_length))
    n_batches = len(int_text) // (batch_size * seq_length)
    print("n batches = {}".format(n_batches))
    batches = np.zeros((n_batches, 2, batch_size, seq_length), dtype=np.int32)
    for b in range(n_batches):
        for i in range(batch_size):
            ii = b * batch_size * seq_length + i * seq_length
            # Inputs are seq_length word ids; targets are the same ids shifted
            # one step ahead. Wrap the final target word around to the start
            # of the text so the last slice is always full length.
            target = np.copy(int_text[ii + 1:ii + seq_length + 1])
            if len(target) < seq_length:
                target = np.append(target, int_text[0])
            batches[b][0][i] = np.copy(int_text[ii:ii + seq_length])
            batches[b][1][i] = target
    return batches
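
# Illustrative sketch (not part of the original gist): for a toy id sequence
# of 13 words with batch_size=2 and seq_length=3, get_batches returns an array
# of shape (2, 2, 2, 3): two batches, each holding an input block and a target
# block of shape (batch_size, seq_length), with targets shifted one step ahead:
#   get_batches(list(range(13)), 2, 3)[0]
#   -> inputs  [[0, 1, 2], [3, 4, 5]]
#      targets [[1, 2, 3], [4, 5, 6]]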


# Number of Epochs
num_epochs = 256
# Batch Size
batch_size = 100
# RNN Size
rnn_size = 512
# Sequence Length
seq_length = 200
# Learning Rate
learning_rate = 0.1
# Show stats every n batches
show_every_n_batches = 10


# int_text, int_to_vocab (and save_dir below) are assumed to be provided by
# the preprocessing step that precedes this file.
train_graph = tf.Graph()
with train_graph.as_default():
    vocab_size = len(int_to_vocab)
    input_text, targets, lr = get_inputs()
    input_data_shape = tf.shape(input_text)
    cell, initial_state = get_init_cell(input_data_shape[0], rnn_size)
    logits, final_state = build_nn(cell, rnn_size, input_text, vocab_size)

    # Probabilities for generating words
    probs = tf.nn.softmax(logits, name='probs')

    # Loss function
    cost = seq2seq.sequence_loss(
        logits,
        targets,
        tf.ones([input_data_shape[0], input_data_shape[1]]))

    # Optimizer
    optimizer = tf.train.AdamOptimizer(lr)

    # Gradient Clipping
    gradients = optimizer.compute_gradients(cost)
    capped_gradients = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in gradients if grad is not None]
    train_op = optimizer.apply_gradients(capped_gradients)

batches = get_batches(int_text, batch_size, seq_length)

with tf.Session(graph=train_graph) as sess:
    sess.run(tf.global_variables_initializer())

    for epoch_i in range(num_epochs):
        state = sess.run(initial_state, {input_text: batches[0][0]})

        for batch_i, (x, y) in enumerate(batches):
            feed = {
                input_text: x,
                targets: y,
                initial_state: state,
                lr: learning_rate}
            train_loss, state, _ = sess.run([cost, final_state, train_op], feed)

            # Show every <show_every_n_batches> batches
            if (epoch_i * len(batches) + batch_i) % show_every_n_batches == 0:
                print('Epoch {:>3} Batch {:>4}/{} train_loss = {:.3f}'.format(
                    epoch_i, batch_i, len(batches), train_loss))

    # Save Model
    saver = tf.train.Saver()
    saver.save(sess, save_dir)
    print('Model Trained and Saved')
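

# ---------------------------------------------------------------------------
# A minimal generation sketch, not part of the original gist. It assumes the
# preprocessing objects used above (int_to_vocab and its hypothetical inverse
# vocab_to_int) plus the save_dir checkpoint written by the training loop, and
# looks tensors up by the names given to them above ('input', 'probs').
# ---------------------------------------------------------------------------
gen_length = 100      # assumed number of words to generate
prime_word = 'the'    # hypothetical seed word; any word in the vocabulary works

loaded_graph = tf.Graph()
with tf.Session(graph=loaded_graph) as sess:
    # Restore the trained graph structure and weights
    loader = tf.train.import_meta_graph(save_dir + '.meta')
    loader.restore(sess, save_dir)

    input_text = loaded_graph.get_tensor_by_name('input:0')
    probs = loaded_graph.get_tensor_by_name('probs:0')

    gen_sentences = [prime_word]
    for _ in range(gen_length):
        # Feed the sentence generated so far (trimmed to the training sequence
        # length) and take the most likely word from the last time step.
        dyn_input = [[vocab_to_int[word] for word in gen_sentences[-seq_length:]]]
        probabilities = sess.run(probs, {input_text: dyn_input})
        next_word = int_to_vocab[np.argmax(probabilities[0][len(dyn_input[0]) - 1])]
        gen_sentences.append(next_word)

    print(' '.join(gen_sentences))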