@iganichev
Last active March 3, 2019 16:09
Small example of carrying the previous LSTM state into the next training batch

#!/usr/bin/env python
""" This example trains an LSTM to predict the next number
in a sequence 0, 0, 1, 1, 0, 0, 1, 1, ...
Training is done on very short sequences of length 3.
Loss is computed on all but the last predicted value for simplicity.
When REMEMBER_STATE is True, the previous LSTM state is transfered
to the next training step and the network reaches almost zero error.
When REMEMBER STATE is False, the network has no way to predict the
second element from the first, so it outputs 0.5. It is still able to
predict the third element.
"""
import random
import numpy as np
import tensorflow as tf
from tensorflow.contrib import rnn

REMEMBER_STATE = True
NUM_HIDDEN = 10
BATCH_SIZE = 100
LENGTH = 3
# Variables that persist the final LSTM state between session.run calls so
# one batch's final state can seed the next batch.
saved_c = tf.get_variable("saved_c", shape=[BATCH_SIZE, NUM_HIDDEN], dtype=tf.float32)
saved_h = tf.get_variable("saved_h", shape=[BATCH_SIZE, NUM_HIDDEN], dtype=tf.float32)
# Maps each LSTM output to a scalar next-value prediction.
mlp = tf.layers.Dense(1)
x = tf.placeholder(tf.float32, [BATCH_SIZE, LENGTH, 1])
xs = tf.unstack(x, LENGTH, axis=1)
initial_c = tf.placeholder(tf.float32, [BATCH_SIZE, NUM_HIDDEN])
initial_h = tf.placeholder(tf.float32, [BATCH_SIZE, NUM_HIDDEN])
initial_state = rnn.LSTMStateTuple(c=initial_c, h=initial_h)
cell = rnn.BasicLSTMCell(NUM_HIDDEN)
# outputs - a list of length LENGTH, each element a tensor of shape [BATCH_SIZE, NUM_HIDDEN]
# states - LSTMStateTuple with both c and h having shape [BATCH_SIZE, NUM_HIDDEN]
outputs, states = rnn.static_rnn(cell, inputs=xs, initial_state=initial_state)
# Copy this batch's final state into the saved_* variables; running assign_op
# triggers both copies via the control dependency.
assign_c = tf.assign(saved_c, states.c)
assign_h = tf.assign(saved_h, states.h)
with tf.control_dependencies([assign_c, assign_h]):
    assign_op = tf.no_op()
def loss(inputs, outputs):
    loss = 0
    # Predictions for the first example in the batch (collected for printing)
    predictions = []
    # Output t is trained to predict input t+1, so the labels are the inputs
    # shifted by one step; the last output has no label and is dropped.
    for output, labels in zip(outputs, tf.unstack(inputs[:, 1:, :], axis=1)):
        # prediction shape = [BATCH_SIZE, 1]
        prediction = mlp(output)
        predictions.append(prediction[0, 0])
        loss += tf.sqrt(tf.losses.mean_squared_error(labels=labels,
                                                     predictions=prediction))
    return loss, tf.stack(predictions)
floss, predictions = loss(x, outputs)
train_op = tf.train.AdamOptimizer().minimize(floss)
def input_gen():
    repeats = 2
    nums = 2
    cache = {}
    # Periodic template 0, 0, 1, 1, 0, 0, 1, 1, ... long enough to slice a
    # LENGTH-sized window from any starting offset.
    template = np.repeat(range(nums) * 2 * LENGTH, repeats=repeats)
    def numpy_cache(pos):
        key = pos % (nums * repeats)
        if key not in cache:
            cache[key] = template[key:(key + LENGTH)]
        return cache[key]
    # Each batch row walks through the sequence from its own random offset.
    positions = [random.randint(0, nums * repeats) for _ in xrange(BATCH_SIZE)]
    def get_input():
        result = np.zeros([BATCH_SIZE, LENGTH, 1])
        for i in xrange(BATCH_SIZE):
            result[i, :, 0] = numpy_cache(positions[i])
            positions[i] += LENGTH
        return result
    return get_input
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    c_val = np.zeros([BATCH_SIZE, NUM_HIDDEN])
    h_val = np.zeros([BATCH_SIZE, NUM_HIDDEN])
    generator = input_gen()
    losses = []
    for i in xrange(10000):
        x_val = generator()
        if REMEMBER_STATE:
            # Fetching saved_c/saved_h after assign_op has run returns this
            # batch's final state, which is fed back in on the next iteration.
            c_val, h_val, loss_val, _, _, first_pred = sess.run(
                [saved_c, saved_h, floss, train_op, assign_op, predictions],
                feed_dict={x: x_val,
                           initial_c: c_val,
                           initial_h: h_val})
        else:
            # c_val/h_val stay zero, so every batch starts from a blank state.
            loss_val, _, first_pred = sess.run(
                [floss, train_op, predictions],
                feed_dict={x: x_val,
                           initial_c: c_val,
                           initial_h: h_val})
        losses.append(loss_val)
        if i % 101 == 0 and len(losses) >= 100:
            print "iteration:", i
            print "loss:", sum(losses[-100:]) / 100.0
            print "predictions on first example:", first_pred
            print "input:", x_val[0, :, 0]
            print
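
For readers on TensorFlow 2, where tf.contrib, tf.placeholder, and sessions are gone, the same state-carrying idea can be expressed with a stateful Keras LSTM. The sketch below is illustrative and not part of the original gist; it assumes TF 2.x with the bundled tf.keras. With stateful=True, Keras keeps each batch row's final state and reuses it as the initial state of the corresponding row in the next batch, which is what the saved_c/saved_h variables do above.

import numpy as np
import tensorflow as tf

BATCH_SIZE = 100
LENGTH = 3
NUM_HIDDEN = 10

# stateful=True requires a fixed batch size; row i of every batch is treated
# as the continuation of row i of the previous batch.
model = tf.keras.Sequential([
    tf.keras.layers.LSTM(NUM_HIDDEN, stateful=True, return_sequences=True,
                         batch_input_shape=(BATCH_SIZE, LENGTH, 1)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

period = np.array([0, 0, 1, 1], dtype=np.float32)
positions = np.random.randint(0, len(period), size=BATCH_SIZE)

def next_batch():
    # Window of LENGTH + 1 consecutive steps per row: LENGTH inputs plus the
    # one-step-ahead targets.
    global positions
    idx = positions[:, None] + np.arange(LENGTH + 1)
    window = period[idx % len(period)]
    positions += LENGTH
    return window[:, :LENGTH, None], window[:, 1:, None]

for step in range(10000):
    x_batch, y_batch = next_batch()
    loss_val = model.train_on_batch(x_batch, y_batch)
    if step % 100 == 0:
        print("step", step, "loss", float(loss_val))

Calling model.reset_states() before each batch would reproduce the REMEMBER_STATE = False behavior of starting every batch from a zero state.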
@domebianchi

An error is raised on the line
template = np.repeat(range(nums) * 2 * LENGTH, repeats=repeats)
TypeError: unsupported operand type(s) for *: 'range' and 'int'
Is this right? How can it be fixed?
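
A likely cause, for anyone hitting the same traceback: the gist is written for Python 2, where range() returns a list that supports repetition with *. Under Python 3, range() returns a lazy range object, so the list has to be materialized first. A minimal fix, assuming Python 3:

# Python 3 equivalent of the Python 2 expression range(nums) * 2 * LENGTH
template = np.repeat(list(range(nums)) * 2 * LENGTH, repeats=repeats)

The xrange calls and print statements elsewhere in the script need the same Python 3 treatment (range and print(...)).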
