{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# display ipython notebook full width\n",
"from IPython.core.display import display, HTML\n",
"display(HTML(\"<style>.container { width:100% !important; }</style>\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tensorflow Dynamic RNN unrolled\n",
"\n",
"This is an unrolled example of TensorFlow's dynamic RNN. I made this example, because the TensorFlow internal code is hard to read and understand, because it considers many different cases and optimizations . \n",
"\n",
"The original code of the dynamic_rnn can be found here: https://github.com/tensorflow/tensorflow/blob/v1.0.0/tensorflow/python/ops/rnn.py#L380\n",
"\n",
"I have stripped away everything that is unnecessary for understanding the code. In order to increase the readability I have fixed a some parameters, for example that the cell has to be an LSTM cell or that data type of the network is set globally. I have also violated a great many coding conventions and some of the TensorFlow best practices to increase the understandability of the code - this is by no means production ready code. \n",
"\n",
"## Structure of this notebook\n",
"\n",
"A deep leanrning problem can be coarsely divided into the following three components: A task, data, and an objective. In Section 0, I briefly introduce the learning task. In secion 1, I define hyperparameters and python imports for this approach. In Section 2 I introduce the dynamic_rnn function that was taken from the original code and stripped away from all unnecessary details. Section 3 defines the tensorflow graph, which uses the dynamic_rnn function. Section 2 and 3 together describe the learning task. Section 4 describes the learning objective, i.e. in our case whether the correct number of ones was predicted or not. Section 5 describes our data - since it is an artificial task, we can randomly generate it. In Section 6, we execute the whole learning."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 0 The Task\n",
"\n",
"To demonstrate the dynamic rnn, I have created a simple toy example. The task that the dynamic RNN is supposed to learn is to count the number of ones in an input sequence correctly. In other words, this is a many-to-one sequence classification problem, where you have a sequence of input vectors, which are passed through an RNN and the final output vector is used for classification. The class is the correct number of ones. \n",
"\n",
"If sequences of different length are put together in a batch, they are dynamically zero padded to fit the longest sequence in the batch.\n",
"\n",
"If the longest sequence in a batch is 4, a trainings batch could look like this:\n",
"```\n",
"x=[\n",
" [1,1,0,0] # =>2 \n",
" [1,1,1,1] # =>4\n",
"]\n",
"```\n",
"\n",
"If the longest sequence in a batch is 3, a trainings batch could look like this:\n",
"```\n",
"x=[\n",
" [1,0,0] # =>1 \n",
" [1,1,1] # =>3\n",
"]\n",
"```"
]
},
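{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the dynamic zero padding concrete, here is a minimal NumPy sketch (my own illustration, not part of the dynamic_rnn code) that pads a batch of variable-length sequences to the longest sequence in the batch:\n",
"```python\n",
"import numpy as np\n",
"\n",
"def pad_batch(sequences):\n",
"    # pad each sequence with zeros up to the longest sequence in the batch\n",
"    max_len = max(len(s) for s in sequences)\n",
"    return np.array([s + [0] * (max_len - len(s)) for s in sequences])\n",
"\n",
"pad_batch([[1, 1], [1, 1, 1, 1]]) # => [[1,1,0,0], [1,1,1,1]]\n",
"```"
]
},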
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 1 Hyper parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"assert tf.__version__ == '1.0.0-rc0'# this can be changed, this is just a safeguard for having a version > than this version\n",
"\n",
"from tensorflow.contrib import rnn as contrib_rnn\n",
"from tensorflow.python.util import nest # for converting list to LSTM array tuple "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# trainingsdata\n",
"num_examples = 50000\n",
"\n",
"# graph\n",
"state_size = 64\n",
"target_classes = 15 # maximum sequence length\n",
"batch_size = 50\n",
"dtype = tf.float32\n",
"parallel_iterations = 32 # how many loops to process in parallel\n",
"\n",
"# cost function\n",
"learning_rate = 0.01\n",
"\n",
"# training\n",
"checkpoint_path = \"dynamic_rnn\"\n",
"max_steps = 1000 # how many batches to train\n",
"save_checkpoint_after_each_step = 200\n",
"add_summary_after_each_step = 20"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2 Simplified Dynamic RNN\n",
"\n",
"In pseudocode, what the calculations of a dynamic rnn loop in tensorflow is fairly simple. Let's assumpe an RNN that is unrolled for 4 timesteps. Let's first assume we train on one trainingsexample with a given sequence length = 3. The maximum possible sequence length is 4. \n",
"\n",
"## Pseudo Code for batch_size = 1\n",
"\n",
"```python\n",
"def rnn_cell_function(input, state)\n",
" # the rnn function that we want to learn\n",
" # rnn internal calculations\n",
" return output, new_state\n",
" \n",
"max_sequence_length = 4\n",
"trainings_example_sequence_length = 2\n",
"inputs = [ input_time_0, input_time_1 , input_time_2 , input_time_3 ] # the input at each timestep has the shape ( batch_size, input_vector size ) \n",
"outputs = [ _ , _ , _, ] # the output at each timestep has the shape ( batch_size, output_vector_size ) \n",
"\n",
"current_time_step=0\n",
"current_state = initial_state # we got this from our rnn cell, for example a zero state. \n",
"while current_time_step<=max_sequence_length:\n",
" input_at_time_i = inputs[current_time_step] # get input for timestep i\n",
" \n",
" if current_time_step<=trainings_example_sequence_length: # only calculate outputs if our trainingsexample is long enough\n",
" output_at_time_i, new_state = rnn_cell_function(input_at_time_i, current_state ) # calculate output and new state i. output has shape ( batch_size, output_vector_size ) \n",
" else: # in this case our trainingsexample is not long enough, output zeros. \n",
" output_at_time_i = zeros # shape = ( batch_size, output_vector_size ) \n",
" \n",
" outputs[current_time_step] = output_at_time_i # store output of our function\n",
" \n",
" # update loop variables\n",
" current_time_step+=1 # \n",
" current_state = new_state\n",
" \n",
"```\n",
"\n",
"## Pseudo Code for batch_size = 2\n",
"\n",
"The previous example has only one example in to train on. If you had more examples, there would be one sequence length per training example. TensorFlow elegantly hides the gory implementation details of the required matrix multiplication for a batch with more than one example. The matrix multiplications are hard to imagine, but the concept can be explained by using another for loop for each element of the batch instead. \n",
"\n",
"```python\n",
"trainings_example_sequence_lengths = [2,4]\n",
"batch_size = 2\n",
"\n",
"# same setup as before\n",
"current_state = initial_state # shape of state = (batch_size, state_size)\n",
"while current_time_step<=max_sequence_length: # \n",
" \n",
" for example_id in xrange(batch_size): #calculate updates for each element of the batch\n",
" # get inputs and sequence lengths\n",
" current_example_sequence_length = trainings_example_sequence_lengths[example_id]\n",
" current_example_cell_input = inputs[current_time_step][example_id] # cell_input shape = (1, cell_input_size ) \n",
" \n",
" if current_time_step<=trainings_example_sequence_length: # only calculate outputs if this trainingsexample is long enough\n",
" output_current_example, new_state = rnn_cell_function(current_example_cell_input , current_state[example_id] ) # calculate output and new state i. output has shape ( 1 , output_vector_size ) \n",
" else: # in this case this trainingsexample is not long enough, output zeros. \n",
" output_current_example = zeros # shape = (1, output_vector_size ) \n",
" \n",
" outputs[current_time_step][example_id] = output_current_example # store output of our function for THIS trainings example\n",
" current_state[example_id] = new_state # update state for THIS trainings example. only maintain one state per example\n",
" \n",
" # after we have processed the batch for one timestep, repeat for the next timestep\n",
" current_time_step+=1 \n",
"```\n",
"\n",
"## Note about abstraction\n",
"\n",
"Apart from that TensorFlow hides more details than previoiusly mentioned. It hides all the gory details of memory management, backpropagation, parallel computation and computations on different devices. Pretty impressive. \n",
"\n",
"## TF functions Documentation:\n",
"* tf.stack https://www.tensorflow.org/api_docs/python/tf/stack\n",
"* tf.unstack https://www.tensorflow.org/api_docs/python/tf/unstack\n",
"* tf.variable_scope https://www.tensorflow.org/api_docs/python/tf/variable_scope\n",
"* tf.while_loop https://www.tensorflow.org/api_docs/python/tf/while_loop\n",
"* tf.zeros https://www.tensorflow.org/api_docs/python/tf/zeros\n",
"* ff.cond https://www.tensorflow.org/api_docs/python/tf/cond\n",
"* tf.transpose https://www.tensorflow.org/api_docs/python/tf/transpose\n",
"* tf.TensorArray https://www.tensorflow.org/api_docs/python/tf/TensorArray\n",
"* tf.TensorArray#read https://www.tensorflow.org/api_docs/python/tf/TensorArray#read\n",
"* tf.TensorArray#write https://www.tensorflow.org/api_docs/python/tf/TensorArray#write\n",
"* tf.constant https://www.tensorflow.org/api_docs/python/tf/constant\n",
"* tf.nest.pack_sequence_as https://github.com/tensorflow/tensorflow/blob/v1.0.0/tensorflow/python/util/nest.py#L228\n"
]
},
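{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before reading the stripped-down implementation below, it may help to see `tf.while_loop` and `tf.TensorArray` in isolation, since `dynamic_rnn` is built around exactly this pattern: write one result per iteration into a `TensorArray`, then stack the array into a single tensor. This is a minimal sketch assuming the TF 1.0 API; the toy computation (squaring 0..4) is my own, not from the gist:\n",
"```python\n",
"n = tf.constant(5)\n",
"ta = tf.TensorArray(dtype=tf.int32, size=n)\n",
"\n",
"def body(i, ta):\n",
"    return i + 1, ta.write(i, i * i) # write one element per iteration\n",
"\n",
"_, final_ta = tf.while_loop(\n",
"    cond=lambda i, _: i < n,\n",
"    body=body,\n",
"    loop_vars=(tf.constant(0), ta))\n",
"\n",
"squares = final_ta.stack() # session.run(squares) => [0, 1, 4, 9, 16]\n",
"```"
]
},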
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def dynamic_rnn(\n",
" cell, # the rnn cell function\n",
" inputs, # time major inputs [max_time_steps , batch_size, cell_input_size ]\n",
" sequence_length=None, # [ batch_size ] the sequence lengths of each trainings example\n",
" initial_state=None): # zero state\n",
"\n",
" with tf.variable_scope(\"Dynamic_RNN\") as scope:\n",
"\n",
" state = initial_state\n",
" state_size = cell.state_size\n",
" # sequence length\n",
" max_sequence_length = tf.reduce_max(sequence_length) # max number of time_steps for this batch\n",
"\n",
" # save shape information for shape restoration after loop\n",
" inputs_shape = inputs.get_shape().with_rank_at_least(3) # [time_steps, batch_size, cell_input_size ]\n",
" batch_time_steps = inputs_shape[0]\n",
"\n",
" \"\"\"\n",
" prepare dynamic rnn inputs. \n",
" \"\"\"\n",
" # input_ta is a list of tensor. each tensor has the shape (batch_size,cell_input_size)\n",
" # the length of input_ta is the maximum number of timesteps in that batch \n",
" input_ta = tf.TensorArray(dtype=dtype, size=max_sequence_length, tensor_array_name=str(scope) + \"input_tensors\")\n",
" input_ta = input_ta.unstack(inputs) # store inputs in tensor array\n",
" \n",
" \"\"\"\n",
" prepare dynamic rnn outputs\n",
" \"\"\"\n",
" # output_ta is a list tensors. each tensor has the shape (batch_size,cell_output_size)\n",
" # len(output_ta) = len(input_ta)\n",
" output_ta = tf.TensorArray(dtype=dtype, size=max_sequence_length, tensor_array_name=str(scope) + \"output_tensors\") \n",
" # prepare zero output\n",
" # the dynamic rnn loop produces zero outputs when a sequence has ended\n",
" zero_output = tf.zeros( shape=(batch_size, cell.output_size) , dtype=dtype) \n",
" \n",
" # variable for keeping track at which timestep in the rnn we are\n",
" current_time_step = tf.constant(0, dtype=tf.int32, name=\"current_time_step\")\n",
"\n",
" \"\"\"\n",
" Define the loop body\n",
" \n",
" Calculates the outputs / new states for the rnn at the time step 'current_time_step'.\n",
" \"\"\" \n",
" def calculate_time_step(current_time_step, output_ta_t, input_state): # calculate outputs for one timestep (get next hidden state and next output).\n",
" # get input for current timestep\n",
" input_t = input_ta.read(current_time_step)\n",
" input_t.set_shape( inputs_shape[1:] ) # input_t has shape [batch_size, cell_input_size]\n",
"\n",
" def call_cell():\n",
" cell_output, new_state = cell(input_t, input_state) # invoke rnn cell\n",
" return [ cell_output, new_state.c , new_state.h ] \n",
"\n",
" def empty_update():\n",
" return [ zero_output, input_state.c, input_state.h ]\n",
"\n",
" output_t, new_state_c, new_state_h = tf.cond(\n",
" current_time_step >= max_sequence_length, # if current sequence is too long\n",
" empty_update, # do nothing\n",
" call_cell # otherwise calculation is required: copy some or all of it through\n",
" )\n",
"\n",
" # restore shape information of output and states\n",
" output_t.set_shape(zero_output.get_shape())\n",
" new_state_c.set_shape( state.c.get_shape() )\n",
" new_state_h.set_shape( state.h.get_shape() )\n",
" \n",
" # save output at the correct index of the output tensor array\n",
" output_ta_t = output_ta_t.write(current_time_step, output_t) # \n",
"\n",
" # convert new state array back to LSTM state tuple \n",
" new_state = nest.pack_sequence_as( structure=state, flat_sequence=[new_state_c, new_state_h])\n",
"\n",
" # update time step\n",
" next_time_step = current_time_step + 1\n",
" return (next_time_step, output_ta_t, new_state)\n",
" \"\"\"\n",
" define the while loop\n",
" \"\"\"\n",
" _, output_final_ta, final_state = tf.while_loop(\n",
" cond=lambda current_time_step, *_: current_time_step < max_sequence_length,\n",
" body=calculate_time_step,\n",
" loop_vars=(current_time_step, output_ta, state),\n",
" parallel_iterations=parallel_iterations,\n",
" )\n",
" \n",
" \"\"\"\n",
" postprocess output\n",
" \"\"\"\n",
" final_outputs = output_final_ta.stack() # convert list to tensors, e.g. for two timesteps [ (batch_size, cell_output_size), (batch_size, cell_output_size) ] => (2,batch_size, cell_output_size )\n",
" final_outputs.set_shape( shape = (batch_time_steps, batch_size, cell.output_size ) ) # restore shape information \n",
" final_outputs_batchmajor = tf.transpose(final_outputs, [1, 0, 2]) # convert outputs to batch major format (time_steps, batch_size, cell_output_size) => (batch_size, time_steps, cell_output_size) \n",
" \n",
" return (final_outputs_batchmajor, final_state)\n"
]
},
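{
"cell_type": "markdown",
"metadata": {},
"source": [
"One detail worth isolating: the loop body flattens the LSTM state into the list `[c, h]` (because the `tf.cond` branches return plain lists of tensors), and `nest.pack_sequence_as` rebuilds the `LSTMStateTuple` from that flat list. A minimal sketch of what this call does, using the imports from above (the tensors are illustrative placeholders):\n",
"```python\n",
"c = tf.zeros((batch_size, state_size))\n",
"h = tf.zeros((batch_size, state_size))\n",
"template = contrib_rnn.LSTMStateTuple(c, h) # structure to imitate\n",
"\n",
"new_state = nest.pack_sequence_as(\n",
"    structure=template,\n",
"    flat_sequence=[c + 1.0, h + 1.0]) # flat tensors, in flatten order [c, h]\n",
"# new_state is again an LSTMStateTuple: new_state.c, new_state.h\n",
"```"
]
},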
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3 Graph"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"tf_global_step = tf.Variable(0, trainable=False)\n",
"\n",
"# Inputs + Targets\n",
"with tf.variable_scope(\"Inputs\") as input_scope:\n",
" inputs_x = tf.placeholder(tf.int32, [batch_size, None], name=\"X\") # encoder inputs\n",
" input_embedding = tf.get_variable('input_embedding', [2, 2]) # each row is a dense vector for each word.\n",
" inputs_x_embedded = tf.nn.embedding_lookup(input_embedding, inputs_x) \n",
" # change input shape from batch_major to time major\n",
" # input in batchmajor is a tensor [batch_size, time_steps, input_vector_size ]\n",
" # input in timemajor is a tensor [ timesteps, batch_size, input_vector_size ]\n",
" # time_steps is the maximum number of timesteps for this batch\n",
" transposed_input = tf.transpose(inputs_x_embedded, [1, 0, 2], name=\"transposed_inputs\")\n",
" \n",
" \n",
" sequence_lengths = tf.placeholder(tf.int32, [batch_size], name=\"SequenceLengths\") # the length of the sequence\n",
" \n",
"with tf.variable_scope(\"Targets\") as target_scope:\n",
" target_y = tf.placeholder(tf.int32, [batch_size], \"Y\") # [batch_size] # the length of the sequence\n",
" target_y_onehot = tf.one_hot(indices=target_y, # [batch_size, target_classes ] \n",
" depth=target_classes, \n",
" on_value=1, off_value=0, \n",
" axis=None, dtype=None, name=\"Y_onehot\")\n",
" \n",
"\n",
"# Define cells\n",
"with tf.variable_scope(\"Encoder\") as encoder_scope:\n",
"\n",
" cell = contrib_rnn.LSTMCell(num_units=state_size, state_is_tuple=True)\n",
"\n",
" zero_state = cell.zero_state(batch_size, dtype)\n",
" \n",
" # encode inputs\n",
" outputs, last_hidden_state = dynamic_rnn(\n",
" cell=cell,\n",
" inputs=transposed_input, # [max_time_steps , batch_size, cell_input_size ]\n",
" sequence_length=sequence_lengths,\n",
" initial_state=zero_state,\n",
" #dtype=dtype, # tf.float64\n",
" #parallel_iterations=32, \n",
" #swap_memory=False, is cpu gpu memory swap enabled?\n",
" #time_major=False, \n",
" )\n",
" \n",
" # outputs are of the form: `[max_time_steps, batch_size, cell.output_size]`.\n",
" \n",
" # get last output\n",
" stacked = tf.stack([tf.range(batch_size), sequence_lengths-1]) \n",
" indices = tf.transpose(stacked)\n",
" last_output = tf.gather_nd(outputs, indices) # [batch_size, cell.output_size ]\n",
" \n",
" \n",
" # map last output to one hot\n",
"with tf.variable_scope(\"SoftmaxProjection\") as softmax_projection_scope:\n",
" weights = tf.Variable(tf.truncated_normal([state_size, target_classes], stddev=0.1))\n",
" bias = tf.Variable(tf.constant(0.1, shape=[target_classes]))\n",
" predicted_y_onehot = tf.nn.softmax(tf.matmul(last_output, weights) + bias) # [batch_size, target_classes]\n",
" "
]
},
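{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `tf.stack` / `tf.transpose` / `tf.gather_nd` combination above selects, for every example, the output at its last valid (non-padded) timestep. A small sketch of the index construction, with assumed values for illustration:\n",
"```python\n",
"# assume batch_size = 3 and sequence lengths [2, 4, 3]\n",
"seq_lens = tf.constant([2, 4, 3])\n",
"idx = tf.transpose(tf.stack([tf.range(3), seq_lens - 1]))\n",
"# idx == [[0, 1], [1, 3], [2, 2]]\n",
"# tf.gather_nd(outputs, idx) then picks outputs[0, 1], outputs[1, 3], outputs[2, 2]\n",
"# from the batch-major outputs, i.e. the last real output vector of each example\n",
"```"
]
},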
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 4 Cost function, optimizer, trainings step"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# cost function\n",
"loss = tf.nn.sparse_softmax_cross_entropy_with_logits( \n",
" logits=predicted_y_onehot, \n",
" labels=target_y # takes care of theone hot encoding\n",
")\n",
"cross_entropy_loss = tf.reduce_sum(loss)\n",
"\n",
"# get gradients for all trainable parameters with respect to our loss funciton\n",
"params = tf.trainable_variables()\n",
"optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate)\n",
"gradients = tf.gradients(cross_entropy_loss, params)\n",
"\n",
"# Do something to the gradients here, i.e. gradient clipping\n",
"\n",
"# Update operation\n",
"training_step = optimizer.apply_gradients(zip(gradients, params), global_step=tf_global_step)"
]
},
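{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `sparse_softmax_cross_entropy_with_logits` applies the softmax itself and consumes integer class labels directly, which is why the loss above is fed the raw `logits` rather than `predicted_y_onehot`. For comparison, a sketch of the equivalent dense formulation (equivalent up to floating-point error; not used in this notebook):\n",
"```python\n",
"dense_loss = tf.nn.softmax_cross_entropy_with_logits(\n",
"    logits=logits,\n",
"    labels=tf.one_hot(target_y, depth=target_classes, dtype=tf.float32))\n",
"```"
]
},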
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 5 Trainings Example Generation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def get_batch_dict(batch_size):\n",
" data_x = [] # sequences \n",
" data_y = [] # targets\n",
" data_s = np.random.randint(3,target_classes,batch_size) # sequence length\n",
" \n",
" max_len = np.amax(data_s)\n",
" \n",
" for sl in data_s:\n",
" data_x.append([1] * (sl) + [0]*(max_len-sl) ) #e.g [1,1,1,0,0]\n",
" data_y.append(sl) # e.g. 3 \n",
" \n",
" #return data_x, data_s, data_y\n",
" \n",
" return {\n",
" inputs_x:data_x,\n",
" sequence_lengths:data_y,\n",
" target_y:data_y\n",
" }\n",
"\n",
"def evaluate_batch(results):\n",
" correct = 0 \n",
" for i in xrange(batch_size):\n",
" predicted = np.argmax(results[3][i])\n",
" target = batch_dict[target_y][i]\n",
" if predicted == target: correct += 1 \n",
" print(\"Accuracy: {}\".format(correct/float(batch_size))) \n",
" return correct/batch_size\n",
"\n",
"# print sample batch\n",
"sample_batch = get_batch_dict(3)\n",
"sample_batch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 6 Trainings Session\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import os\n",
"if not os.path.exists(checkpoint_path): os.mkdir(checkpoint_path)\n",
"\n",
"# Saver\n",
"saver = tf.train.Saver(tf.global_variables())\n",
"\n",
"# Summaries\n",
"tf.summary.scalar(\"cross_entropy_loss\",tf.cast(cross_entropy_loss, tf.float32)) # summary for Cross Entropy Loss\n",
"\n",
"# Start session\n",
"session = tf.Session()\n",
"\n",
"all_summaries = tf.summary.merge_all()\n",
"summary_writer = tf.summary.FileWriter(checkpoint_path, graph=session.graph)\n",
"\n",
"# Initialize Varibles\n",
"session.run([\n",
" tf.local_variables_initializer(),\n",
" tf.global_variables_initializer(),\n",
" ])\n",
"\n",
"for current_step in xrange(1,max_steps+1): # start from 1 .. max_steps+1 to execute max steps \n",
" \n",
" # increase step counter\n",
" session.run(tf_global_step.assign(current_step)) \n",
"\n",
" # generate target batch \n",
" batch_dict=get_batch_dict(batch_size)\n",
" \n",
" # execute operations\n",
" results = session.run([\n",
" cross_entropy_loss, # calculate training loss \n",
" training_step, # calculate gradients, update gradients\n",
" all_summaries, # compile summaries and write to graph dir\n",
" predicted_y_onehot # what is actually predicted\n",
" ], feed_dict = batch_dict )\n",
" \n",
" # store loss\n",
" if current_step % add_summary_after_each_step == 0:\n",
" summary_writer.add_summary(results[2],current_step) \n",
" \n",
" # save checkpoint every xth step\n",
" if current_step % save_checkpoint_after_each_step==0:\n",
" print (\"Saving checkpoint\")\n",
" print(\"Step %i, Loss: \"%(current_step) + str(results[0]))\n",
" correct = evaluate_batch(results)\n",
" chkpoint_out_filename = os.path.join(checkpoint_path, \"DynamicRNN\")\n",
" saver.save(session, chkpoint_out_filename , global_step=current_step) \n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12+"
}
},
"nbformat": 4,
"nbformat_minor": 1
}