Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Simple implementation of an LSTM in TensorFlow in 50 lines (+ 130 lines of data generation and comments)
from __future__ import print_function
import numpy as np
import tensorflow as tf
import tensorflow.contrib.layers as layers
def as_bytes(num, final_size):
    """Split ``num`` into its lowest ``final_size`` binary digits.

    Despite the name, every element is a single bit (0 or 1), not a byte.
    Digits are ordered least significant first; any higher-order bits that
    do not fit in ``final_size`` positions are dropped.

    Args:
        num: Integer to decompose (callers in this file pass non-negative
            values).
        final_size: Number of bits to emit.

    Returns:
        list: ``final_size`` bits, index 0 holding the least significant bit.
    """
    bits = []
    for _position in range(final_size):
        # divmod yields (num // 2, num % 2): shift right and peel the low bit.
        num, low_bit = divmod(num, 2)
        bits.append(low_bit)
    return bits
def generate_batch(num_bits, batch_size):
    """Build a batch of random binary-addition examples.

    Each example draws two random operands ``a`` and ``b`` and pairs them with
    their sum. Every number is written as ``num_bits`` bits, least significant
    bit first (via ``as_bytes``), so the time axis walks from the low bit to
    the high bit.

    Args:
        num_bits: Number of bits (timesteps) per number. Operands are drawn
            from ``[0, 2**(num_bits - 1))`` so that ``a + b`` always fits in
            ``num_bits`` bits.
        batch_size: Number of examples in the batch.

    Returns:
        tuple: ``(x, y)`` where ``x`` has shape ``(batch_size, num_bits, 2)``
        with the bits of ``a`` in channel 0 and of ``b`` in channel 1, and
        ``y`` has shape ``(batch_size, num_bits, 1)`` with the bits of
        ``a + b``.
    """
    x = np.empty((batch_size, num_bits, 2))
    y = np.empty((batch_size, num_bits, 1))
    # np.random.randint treats ``high`` as exclusive. Passing 2**(num_bits - 1)
    # makes the largest non-overflowing operand, 2**(num_bits - 1) - 1,
    # reachable: the max sum 2**num_bits - 2 still fits in num_bits bits.
    # This also keeps num_bits=1 valid (randint(0, 1) is well-defined).
    high = 2 ** (num_bits - 1)
    for i in range(batch_size):
        a = np.random.randint(0, high)
        b = np.random.randint(0, high)
        res = a + b
        x[i, :, 0] = as_bytes(a, num_bits)
        x[i, :, 1] = as_bytes(b, num_bits)
        y[i, :, 0] = as_bytes(res, num_bits)
    return x, y
### graph
# All sequence tensors are batch-major, (batch, time, features): this matches
# the arrays built by generate_batch and ``time_major=False`` in dynamic_rnn.
# The time axis runs from the least significant bit upward, so a carry can be
# propagated forward in time by the recurrent state.
INPUT_SIZE = 2  # 2 bits per timestep (one bit from each operand)
RNN_HIDDEN = 20
OUTPUT_SIZE = 1  # 1 bit per timestep (the matching bit of the sum)
inputs = tf.placeholder(tf.float32, (None, None, INPUT_SIZE))  # (batch, time, in)
outputs = tf.placeholder(tf.float32, (None, None, OUTPUT_SIZE))  # (batch, time, out)
cell = tf.contrib.rnn.BasicLSTMCell(RNN_HIDDEN, state_is_tuple=True)
rnn_outputs, rnn_states = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32, time_major=False)
# Linear projection of each timestep's hidden state to a single logit. The
# sigmoid is applied separately so the loss can consume raw logits through
# sigmoid_cross_entropy_with_logits rather than already-squashed probabilities.
raw_outputs = layers.fully_connected(rnn_outputs, OUTPUT_SIZE, activation_fn=None)
predicted_outputs = tf.nn.sigmoid(raw_outputs)
error = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=raw_outputs, labels=outputs))
train_fn = tf.train.AdamOptimizer(learning_rate=0.03).minimize(error)
# Mean absolute gap between target bits and predicted probabilities; a
# human-readable accuracy proxy (0 means every bit is predicted with certainty).
diff = tf.reduce_mean(tf.abs(outputs - predicted_outputs))
### train
NUM_BITS = 10
ITERATIONS_PER_EPOCH = 100
BATCH_SIZE = 16
NUM_EPOCHS = 100
# Validation data is drawn once so that per-epoch metrics are comparable.
valid_x, valid_y = generate_batch(num_bits=NUM_BITS, batch_size=100)
# The context manager closes the session and releases its resources when
# training finishes or is interrupted.
with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    for epoch in range(NUM_EPOCHS):
        epoch_error = 0.0
        for _ in range(ITERATIONS_PER_EPOCH):
            # Fresh random batch each step: the task has effectively unlimited data.
            x, y = generate_batch(num_bits=NUM_BITS, batch_size=BATCH_SIZE)
            batch_error, _ = session.run([error, train_fn], {inputs: x, outputs: y})
            epoch_error += batch_error
        epoch_error /= ITERATIONS_PER_EPOCH
        valid_diff = session.run(diff, {inputs: valid_x, outputs: valid_y})
        print("Epoch {}, train error: {:.6f}, diff: {:.6f}"
              .format(epoch, epoch_error, valid_diff))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.