Simple RNN implementation in TensorFlow
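
A minimal single-layer RNN language model trained on the Penn Treebank (PTB) training text with one-hot word inputs, reporting training perplexity after each epoch. The PTB data files are assumed to sit in a local ptb/ directory.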
import tensorflow as tf
import numpy as np
import collections
from tqdm import tqdm
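
# Model and training hyperparameters.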
num_steps = 10
hidden_size = 20
num_epochs = 20
minibatch_size = 10
vocab_size = 500
max_num_words = 10000
learning_rate = 0.1 * minibatch_size
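
# Read a whitespace-tokenized text file, replacing newlines with an <eos> token.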
def _read_words(filename):
    with tf.gfile.GFile(filename, "r") as f:
        return f.read().replace("\n", "<eos>").split()
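
# Build a word -> id mapping from the training file, ordered by descending frequency.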
def _build_vocab(filename):
    data = _read_words(filename)
    data = data[0:max_num_words]
    counter = collections.Counter(data)
    count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
    words, _ = list(zip(*count_pairs))
    word_to_id = dict(zip(words, range(len(words))))
    return word_to_id
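
# Convert a file to a list of word ids, folding every word ranked below
# vocab_size into the last id.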
def _file_to_word_ids(filename, word_to_id):
    data = _read_words(filename)
    data = data[0:max_num_words]
    return [min(word_to_id[word], vocab_size - 1) for word in data if word in word_to_id]
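
# Load the PTB training set and convert it to word ids.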
train_path = 'ptb/ptb.train.txt'
word_to_id = _build_vocab(train_path)
train_data = _file_to_word_ids(train_path, word_to_id)
vocabulary = len(word_to_id)
num_words = len(train_data)
num_sequences = int(num_words / num_steps) - 1
word_counter_id = 0
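
# Inputs are one-hot word vectors; targets are the integer ids of the next word
# at each time step.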
train_inputs_node = tf.placeholder(tf.float32, shape=(minibatch_size, num_steps, vocab_size))
train_targets_node = tf.placeholder(tf.int32, shape=(minibatch_size, num_steps))
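
# Unroll a single-layer vanilla RNN for num_steps, reusing the cell weights at
# every step, and stack the per-step outputs into a [batch * steps, hidden] matrix.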
def build_rnn_graph(inputs, is_training):
    cell = tf.contrib.rnn.BasicRNNCell(
        hidden_size, reuse=not is_training)
    state = cell.zero_state(minibatch_size, tf.float32)
    outputs = []
    with tf.variable_scope("RNN"):
        for time_step in range(num_steps):
            if time_step > 0:
                tf.get_variable_scope().reuse_variables()
            (cell_output, state) = cell(inputs[:, time_step, :], state)
            outputs.append(cell_output)
    output = tf.reshape(tf.concat(outputs, 1), [-1, hidden_size])
    return output
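
# Project the RNN outputs to vocabulary-sized logits with a softmax layer.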
def model(data, train=False):
    output = build_rnn_graph(data, train)
    softmax_w = tf.get_variable(
        "softmax_w", [hidden_size, vocab_size], dtype=tf.float32)
    softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=tf.float32)
    logits = tf.nn.xw_plus_b(output, softmax_w, softmax_b)
    logits = tf.reshape(logits, [minibatch_size * num_steps, vocab_size])
    return logits
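
# Build the training graph: mean cross-entropy over all (batch, step) positions,
# minimized with plain SGD.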
with tf.variable_scope("Model", reuse=None):
    logits = model(train_inputs_node, True)
train_targets_unrolled = tf.reshape(
    train_targets_node, shape=(minibatch_size * num_steps,))
cost_node = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=logits,
    labels=train_targets_unrolled)
cost_node = tf.reduce_mean(cost_node)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(cost_node)
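
# The Supervisor initializes variables and manages the session.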
sv = tf.train.Supervisor()
config_proto = tf.ConfigProto()
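
# Training loop: split the data into minibatch_size parallel streams and step
# through them num_steps words at a time.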
with sv.managed_session(config=config_proto) as sess:
    train_size = len(train_data)
    batch_len = train_size // minibatch_size
    num_batches = (batch_len - 1) // num_steps
    num_sequences = minibatch_size * num_batches
    # Arrange the word-id stream as minibatch_size rows, read left to right.
    # This only needs to happen once, so it is done before the epoch loop.
    train_data = np.array(train_data[0 : minibatch_size * batch_len])
    train_data = np.reshape(train_data, (minibatch_size, batch_len))
    for epoch in range(num_epochs):
        print("Training - epoch #{}".format(epoch + 1))
        costs = 0
        iters = 0
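        # Each batch: one-hot encode num_steps words per row; the targets are the
        # same words shifted one position ahead.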
        for batch in tqdm(range(num_batches)):
            batch_inputs = np.zeros((minibatch_size, num_steps, vocab_size), np.float32)
            for i in range(minibatch_size):
                for j in range(num_steps):
                    x = train_data[i, batch * num_steps + j]
                    batch_inputs[i, j, x] = 1
            batch_targets = train_data[0:minibatch_size,
                                       (batch * num_steps + 1):((batch + 1) * num_steps + 1)]
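            # Run one SGD step and accumulate the mean cross-entropy for perplexity.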
            feed_dict = {train_inputs_node: batch_inputs,
                         train_targets_node: batch_targets}
            fetches_dict = {"optimizer": train_op,
                            "cost": cost_node}
            vals = sess.run(fetches_dict, feed_dict=feed_dict)
            cost = vals["cost"]
            costs += cost
            iters += 1
        perplexity = np.exp(costs / iters)
        print("Train Perplexity: %.3f" % perplexity)