Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@damienpontifex
Last active May 10, 2020 00:38
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 3 You must be signed in to fork a gist
  • Save damienpontifex/74561a9e6bf43b59b813e5487257aa91 to your computer and use it in GitHub Desktop.
Save damienpontifex/74561a9e6bf43b59b813e5487257aa91 to your computer and use it in GitHub Desktop.
Getting my head around TensorFlow RNN inputs, outputs and the appropriate shapes
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# `dynamic_rnn`"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true,
"scrolled": false
},
"outputs": [],
"source": [
"tf.reset_default_graph()\n",
"\n",
"# values has shape [batch_size=2, sequence_length=3, num_features=1]\n",
"values = tf.constant(np.array([\n",
" [[1], [2], [3]],\n",
" [[2], [3], [4]]\n",
"]), dtype=tf.float32)\n",
"\n",
"lstm_cell = tf.contrib.rnn.LSTMCell(100)\n",
"\n",
"outputs, state = tf.nn.dynamic_rnn(cell=lstm_cell, dtype=tf.float32, inputs=values)\n",
"\n",
"with tf.Session() as sess:\n",
" sess.run(tf.global_variables_initializer())\n",
" output_run, state_run = sess.run([outputs, state])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.all(output_run[:,-1] == state_run.h)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor 'rnn/transpose:0' shape=(2, 3, 100) dtype=float32>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"outputs"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor 'rnn/while/Exit_2:0' shape=(2, 100) dtype=float32>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"state.c"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor 'rnn/while/Exit_3:0' shape=(2, 100) dtype=float32>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"state.h"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# `bidirectional_dynamic_rnn`"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true,
"scrolled": false
},
"outputs": [],
"source": [
"tf.reset_default_graph()\n",
"\n",
"# values has shape [batch_size=2, sequence_length=3, num_features=1]\n",
"values = tf.constant(np.array([\n",
" [[1], [2], [3]],\n",
" [[2], [3], [4]]\n",
"]), dtype=tf.float32)\n",
"\n",
"lstm_cell_fw = tf.contrib.rnn.LSTMCell(100)\n",
"lstm_cell_bw = tf.contrib.rnn.LSTMCell(105) # 105 units (vs. 100 for the forward cell) so the backward cell is distinguishable in the output shapes\n",
"\n",
"(output_fw, output_bw), (output_state_fw, output_state_bw) = tf.nn.bidirectional_dynamic_rnn(\n",
" cell_fw=lstm_cell_fw, \n",
" cell_bw=lstm_cell_bw, \n",
" inputs=values,\n",
" dtype=tf.float32)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor 'bidirectional_rnn/fw/fw/transpose:0' shape=(2, 3, 100) dtype=float32>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output_fw"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor 'ReverseV2:0' shape=(2, 3, 105) dtype=float32>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output_bw"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor 'bidirectional_concat_outputs:0' shape=(2, 3, 205) dtype=float32>"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"outputs = tf.concat((output_fw, output_bw), axis=2, name='bidirectional_concat_outputs')\n",
"outputs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_2:0' shape=(2, 100) dtype=float32>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output_state_fw.c"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_3:0' shape=(2, 100) dtype=float32>"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output_state_fw.h"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_2:0' shape=(2, 105) dtype=float32>"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output_state_bw.c"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_3:0' shape=(2, 105) dtype=float32>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output_state_bw.h"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor 'bidirectional_concat_memory_cell:0' shape=(2, 205) dtype=float32>"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.concat((output_state_fw.c, output_state_bw.c), axis=1, name='bidirectional_concat_memory_cell')"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor 'bidirectional_concat_hidden_state:0' shape=(2, 205) dtype=float32>"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.concat((output_state_fw.h, output_state_bw.h), axis=1, name='bidirectional_concat_hidden_state')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GRU"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"tf.reset_default_graph()\n",
"# values has shape [batch_size=2, sequence_length=3, num_features=1]\n",
"values = tf.constant(np.array([\n",
" [[1], [2], [3]],\n",
" [[2], [3], [4]]\n",
"]), dtype=tf.float32)\n",
"gru_cell = tf.contrib.rnn.GRUCell(100)\n",
"outputs, state = tf.nn.dynamic_rnn(cell=gru_cell, dtype=tf.float32, inputs=values)\n",
"\n",
"with tf.Session() as sess:\n",
" sess.run(tf.global_variables_initializer())\n",
" output_run, state_run = sess.run([outputs, state])"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.all(output_run[:,-1] == state_run)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor 'rnn/transpose:0' shape=(2, 3, 100) dtype=float32>"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"outputs"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor 'rnn/while/Exit_2:0' shape=(2, 100) dtype=float32>"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"state"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Multi RNN"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"tf.reset_default_graph()\n",
"\n",
"# values has shape [batch_size=2, sequence_length=3, num_features=1]\n",
"values = tf.constant(np.array([\n",
" [[1], [2], [3]],\n",
" [[2], [3], [4]]\n",
"]), dtype=tf.float32)\n",
"\n",
"lstm_cell = lambda: tf.contrib.rnn.LSTMCell(100)\n",
"multi_cell = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(3)])\n",
"\n",
"outputs, state = tf.nn.dynamic_rnn(cell=multi_cell, dtype=tf.float32, inputs=values)\n",
"\n",
"with tf.Session() as sess:\n",
" sess.run(tf.global_variables_initializer())\n",
" output_run, state_run = sess.run([outputs, state])"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.all(output_run[:,-1] == state_run[-1].h)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor 'rnn/transpose:0' shape=(2, 3, 100) dtype=float32>"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"outputs"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_2:0' shape=(2, 100) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_3:0' shape=(2, 100) dtype=float32>),\n",
" LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_4:0' shape=(2, 100) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_5:0' shape=(2, 100) dtype=float32>),\n",
" LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_6:0' shape=(2, 100) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_7:0' shape=(2, 100) dtype=float32>))"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"state"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@ashwin4ever
Copy link

Your tutorial is a lifesaver. Brilliant and spot on!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment