Skip to content

Instantly share code, notes, and snippets.

@rsk2327
Created March 21, 2020 00:19
Show Gist options
  • Save rsk2327/b6160fb0c2bf736b64327fe818bf9985 to your computer and use it in GitHub Desktop.
Save rsk2327/b6160fb0c2bf736b64327fe818bf9985 to your computer and use it in GitHub Desktop.
BiDirectional RNN
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false",
"colab_type": "text",
"id": "8XYJH7ufTlcd"
},
"source": [
"## **Building a Bi-Directional RNN**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false",
"colab": {},
"colab_type": "code",
"id": "3jgVpIJDTk0Z"
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn, Tensor, matmul\n",
"\n",
"# NOTE(review): `seq` is not defined in this excerpt. It is assumed to come from an earlier\n",
"# part of the notebook as a float tensor of shape [1, 4, 3] (batch, time steps, features),\n",
"# consistent with input_size=3, batch_first=True and the [1, 4, 4] output shown below.\n",
"\n",
"# Defining the RNN layer\n",
"rnn = nn.RNN(input_size=3, hidden_size=2, num_layers=1, bias=True, batch_first=True, bidirectional=True)"
]
},
{
"cell_type": "code",
"execution_count": 167,
"metadata": {
"Collapsed": "false",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"colab_type": "code",
"id": "Jhh3b2OwTvim",
"outputId": "48e819da-4dd2-4364-e585-b6be06670d33"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Out all shape : torch.Size([1, 4, 4])\n",
"Out last shape : torch.Size([2, 1, 2])\n"
]
}
],
"source": [
"# The RNN returns a pair:\n",
"#  - out_all : hidden states for every time step, forward and backward halves concatenated\n",
"#  - out_last: final hidden state of each direction, stacked along the first dimension\n",
"out_all, out_last = rnn(seq)\n",
"print(f\"Out all shape : {out_all.shape}\")\n",
"print(f\"Out last shape : {out_last.shape}\")"
]
},
{
"cell_type": "code",
"execution_count": 168,
"metadata": {
"Collapsed": "false",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
},
"colab_type": "code",
"id": "7TsBEhHQTx0w",
"outputId": "2aa4712d-f1e1-40e7-fb2d-9c8b03f8d7e5"
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 0.8117, -0.4865, 0.4549, 0.4985],\n",
" [-0.0333, -0.6678, 0.3855, 0.6282],\n",
" [ 0.4834, -0.8639, 0.1292, 0.3069],\n",
" [-0.3773, -0.8821, 0.2727, 0.7177]]], grad_fn=<TransposeBackward1>)"
]
},
"execution_count": 168,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"out_all"
]
},
{
"cell_type": "code",
"execution_count": 169,
"metadata": {
"Collapsed": "false",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
},
"colab_type": "code",
"id": "dLgB31A3T8yq",
"outputId": "359d2bef-aefb-40c3-bd21-021cea3a206c"
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[-0.3773, -0.8821]],\n",
"\n",
" [[ 0.4549, 0.4985]]], grad_fn=<StackBackward>)"
]
},
"execution_count": 169,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"out_last"
]
},
{
"cell_type": "code",
"execution_count": 170,
"metadata": {
"Collapsed": "false",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 221
},
"colab_type": "code",
"id": "rqHXEQ3GUAHb",
"outputId": "0c64d880-afe5-47d3-fb33-c197ca49fd38"
},
"outputs": [
{
"data": {
"text/plain": [
"OrderedDict([('weight_ih_l0', tensor([[ 0.5186, -0.4766, 0.1410],\n",
" [ 0.2750, -0.6602, -0.6266]])),\n",
" ('weight_hh_l0', tensor([[-0.6757, 0.2885],\n",
" [ 0.2265, -0.4132]])),\n",
" ('bias_ih_l0', tensor([0.6045, 0.0802])),\n",
" ('bias_hh_l0', tensor([0.3446, 0.4002])),\n",
" ('weight_ih_l0_reverse', tensor([[-0.1216, -0.1432, 0.3163],\n",
" [-0.6950, 0.2082, 0.1613]])),\n",
" ('weight_hh_l0_reverse', tensor([[-0.3441, 0.0915],\n",
" [-0.2372, 0.2422]])),\n",
" ('bias_ih_l0_reverse', tensor([0.3067, 0.6804])),\n",
" ('bias_hh_l0_reverse', tensor([0.2079, 0.1318]))])"
]
},
"execution_count": 170,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"rnn.state_dict()"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false",
"colab_type": "text",
"id": "zt_SvzJtUrN0"
},
"source": [
"### **Computing outputs - Forward Direction** \n",
"\n",
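"For a bidirectional RNN layer with a hidden size (`hidden_size`) of 2 and an input sequence of length 4, we get an output of shape `[1, 4, 4]` (batch size, sequence length, 2 * `hidden_size`), i.e. a 4x4 matrix for our single sequence.\n",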
"\n",
"In the output, each row captures the hidden state for a given time step. In the previous (unidirectional) example, each time step was represented by a vector of length 2 (because `hidden_size` = 2). Now, since the RNN is bidirectional, each time step is represented by a vector of length 4 (2 + 2).\n",
"\n",
"\n",
"For each time step, the first 2 values correspond to the forward run of the RNN and the last 2 values correspond to the backward run of the RNN."
]
},
{
"cell_type": "code",
"execution_count": 193,
"metadata": {
"Collapsed": "false",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
},
"colab_type": "code",
"id": "TMNnHv21fjIf",
"outputId": "a0e93fa6-edc2-4c7e-bfb7-a74ade4cdf4d"
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 0.8117, -0.4865, 0.4549, 0.4985],\n",
" [-0.0333, -0.6678, 0.3855, 0.6282],\n",
" [ 0.4834, -0.8639, 0.1292, 0.3069],\n",
" [-0.3773, -0.8821, 0.2727, 0.7177]]], grad_fn=<TransposeBackward1>)"
]
},
"execution_count": 193,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"out_all"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false",
"colab_type": "text",
"id": "_R71xsGYUx3p"
},
"source": [
"#### Hidden State 1 - Forward Direction"
]
},
{
"cell_type": "code",
"execution_count": 194,
"metadata": {
"Collapsed": "false",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "90GtkAzKUlU3",
"outputId": "d17db2c4-70ce-4a16-f59f-70b490a24f13"
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0.8117, -0.4865]], grad_fn=<TanhBackward>)"
]
},
"execution_count": 194,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"# Forward-direction parameters of layer 0\n",
"wih = rnn.weight_ih_l0\n",
"whh = rnn.weight_hh_l0\n",
"\n",
"bih = rnn.bias_ih_l0\n",
"bhh = rnn.bias_hh_l0\n",
"\n",
"# We represent all reverse (backward-direction) weights using a '_' suffix\n",
"wih_ = rnn.weight_ih_l0_reverse\n",
"whh_ = rnn.weight_hh_l0_reverse\n",
"\n",
"bih_ = rnn.bias_ih_l0_reverse\n",
"bhh_ = rnn.bias_hh_l0_reverse\n",
"\n",
"x = seq[0][0] # The first time step (a vector of input_size=3 features) of the first sequence in the batch\n",
"\n",
"# Computing the hidden state for time = 1; the initial hidden state h0 is all zeros\n",
"h0 = torch.zeros([1, rnn.hidden_size])\n",
"h1 = torch.tanh(matmul(x, wih.T) + bih + matmul(h0, whh.T) + bhh)\n",
"h1\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false",
"colab_type": "text",
"id": "af01UA3Oqhqd"
},
"source": [
"#### Computing all states - Forward Direction"
]
},
{
"cell_type": "code",
"execution_count": 174,
"metadata": {
"Collapsed": "false",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
},
"colab_type": "code",
"id": "ykBJY-6VVFdg",
"outputId": "dc4482f4-4fad-4c69-d383-69581449db91"
},
"outputs": [
{
"data": {
"text/plain": [
"[tensor([[ 0.8117, -0.4865]], grad_fn=<TanhBackward>),\n",
" tensor([[-0.0333, -0.6678]], grad_fn=<TanhBackward>),\n",
" tensor([[ 0.4834, -0.8639]], grad_fn=<TanhBackward>),\n",
" tensor([[-0.3773, -0.8821]], grad_fn=<TanhBackward>)]"
]
},
"execution_count": 174,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"output = []\n",
"\n",
"# Every hidden state has shape [1, hidden_size] ([1, 2] here); the initial state is all zeros\n",
"h_previous = torch.zeros([1, rnn.hidden_size])\n",
"\n",
"# With batch_first=True, seq.shape[1] is the number of time steps\n",
"for i in range(seq.shape[1]):\n",
"    x = seq[0][i]\n",
"    h_current = torch.tanh(matmul(x, wih.T) + bih + matmul(h_previous, whh.T) + bhh)\n",
"    h_previous = h_current\n",
"    output.append(h_current)\n",
"\n",
"output"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false",
"colab_type": "text",
"id": "bW69y2okraB3"
},
"source": [
"At this stage, we can compare the computed hidden states with the RNN layer output `out_all`. The computed states match the first 2 elements of each time step in the RNN layer output."
]
},
{
"cell_type": "code",
"execution_count": 177,
"metadata": {
"Collapsed": "false",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
},
"colab_type": "code",
"id": "JdpuayH8foHQ",
"outputId": "1c3efe2a-b107-417c-8018-ae2bcfef91d8"
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 0.8117, -0.4865],\n",
" [-0.0333, -0.6678],\n",
" [ 0.4834, -0.8639],\n",
" [-0.3773, -0.8821]]], grad_fn=<SliceBackward>)"
]
},
"execution_count": 177,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"out_all[:,:,:2]"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false",
"colab_type": "text",
"id": "ne-5OQ5Tw6b5"
},
"source": [
"### **Computing Outputs - Backward Direction**"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false",
"colab_type": "text",
"id": "upca8GrGsNXR"
},
"source": [
"#### Hidden State 1 - Backward direction (time = 4)"
]
},
{
"cell_type": "code",
"execution_count": 190,
"metadata": {
"Collapsed": "false",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "OOc-joZtrrhe",
"outputId": "0131a779-b4f5-4452-fb46-08b22220875a"
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.2727, 0.7177]], grad_fn=<TanhBackward>)"
]
},
"execution_count": 190,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"x = seq[0][-1] # The very last element of the sequence is now treated as the first element in the backward run\n",
"\n",
"# Computing the hidden state for time = 4; the backward run also starts from an all-zero state\n",
"h4_ = torch.tanh(matmul(x, wih_.T) + bih_ + matmul(torch.zeros([1, rnn.hidden_size]), whh_.T) + bhh_)\n",
"h4_\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false",
"colab_type": "text",
"id": "q-F07utxs_nD"
},
"source": [
"#### Hidden State 2 - Backward direction (time = 3)"
]
},
{
"cell_type": "code",
"execution_count": 195,
"metadata": {
"Collapsed": "false",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "puhmMxodtC0U",
"outputId": "3cd6fa20-ff92-4b33-a0f4-31b08be9a0be"
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.1292, 0.3069]], grad_fn=<TanhBackward>)"
]
},
"execution_count": 195,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"x = seq[0][-2] # Second-to-last element of the sequence\n",
"\n",
"# Computing the hidden state for time = 3, fed by the backward state from time = 4\n",
"h3_ = torch.tanh(matmul(x, wih_.T) + bih_ + matmul(h4_, whh_.T) + bhh_)\n",
"h3_\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false",
"colab_type": "text",
"id": "B98x-5IlvZmZ"
},
"source": [
"#### Hidden State 3 - Backward direction (time = 2)"
]
},
{
"cell_type": "code",
"execution_count": 196,
"metadata": {
"Collapsed": "false",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "aK0ZdtGYtC3R",
"outputId": "700b6640-91ff-4fb8-d722-dd04df1e2932"
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.3855, 0.6282]], grad_fn=<TanhBackward>)"
]
},
"execution_count": 196,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"x = seq[0][-3] # Third-to-last element of the sequence\n",
"\n",
"# Computing the hidden state for time = 2, fed by the backward state from time = 3\n",
"h2_ = torch.tanh(matmul(x, wih_.T) + bih_ + matmul(h3_, whh_.T) + bhh_)\n",
"h2_\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false",
"colab_type": "text",
"id": "4ImcG2Tpvd7r"
},
"source": [
"#### Hidden State 4 - Backward direction (time = 1)"
]
},
{
"cell_type": "code",
"execution_count": 197,
"metadata": {
"Collapsed": "false",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "gjWVsIkSvgdI",
"outputId": "e6002138-13b1-4d0a-b25d-22af6da29771"
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.4549, 0.4985]], grad_fn=<TanhBackward>)"
]
},
"execution_count": 197,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"x = seq[0][-4] # First element of the sequence, processed last in the backward run\n",
"\n",
"# Computing the hidden state for time = 1, fed by the backward state from time = 2\n",
"h1_ = torch.tanh(matmul(x, wih_.T) + bih_ + matmul(h2_, whh_.T) + bhh_)\n",
"h1_\n"
]
},
{
"cell_type": "code",
"execution_count": 199,
"metadata": {
"Collapsed": "false",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
},
"colab_type": "code",
"id": "fBFjqx3Ovm8W",
"outputId": "0148b1d6-c4ee-450f-d97d-50ee519e1444"
},
"outputs": [
{
"data": {
"text/plain": [
"[tensor([[0.4549, 0.4985]], grad_fn=<TanhBackward>),\n",
" tensor([[0.3855, 0.6282]], grad_fn=<TanhBackward>),\n",
" tensor([[0.1292, 0.3069]], grad_fn=<TanhBackward>),\n",
" tensor([[0.2727, 0.7177]], grad_fn=<TanhBackward>)]"
]
},
"execution_count": 199,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"# Order the backward states by time step (t = 1..4), the same order used by out_all\n",
"output_ = [h1_, h2_, h3_, h4_]\n",
"output_"
]
},
{
"cell_type": "code",
"execution_count": 209,
"metadata": {
"Collapsed": "false",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
},
"colab_type": "code",
"id": "6msSr2Tgrvr3",
"outputId": "5f606d67-746d-433f-afe2-9f60301a602c"
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[0.4549, 0.4985],\n",
" [0.3855, 0.6282],\n",
" [0.1292, 0.3069],\n",
" [0.2727, 0.7177]]], grad_fn=<SliceBackward>)"
]
},
"execution_count": 209,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"out_all[:,:,2:] #Checking only the 2nd half of the RNN layer output"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false",
"colab_type": "text",
"id": "USl9f4_FwYN1"
},
"source": [
"The final RNN layer output is the concatenation of the hidden states from the forward and backward runs at each time step. Concatenating our manually computed states in the same way lets us compare them with the RNN layer output."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false",
"colab": {},
"colab_type": "code",
"id": "thI9zWszv-C3"
},
"outputs": [],
"source": [
"# Both runs must provide one hidden state per time step\n",
"assert len(output) == len(output_)\n",
"\n",
"# Concatenate the forward and backward states of each time step along the feature dimension\n",
"fullOutput = [torch.cat((h_fwd, h_bwd), dim=1) for h_fwd, h_bwd in zip(output, output_)]\n",
"\n",
"# Sanity check: stacking the per-time-step results should reproduce the RNN layer output\n",
"assert torch.allclose(torch.cat(fullOutput, dim=0), out_all[0], atol=1e-6)"
]
},
{
"cell_type": "code",
"execution_count": 207,
"metadata": {
"Collapsed": "false",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
},
"colab_type": "code",
"id": "ZnIRYJo4wB7q",
"outputId": "259a1ce5-bc85-452a-9d1d-99ebbb108e52"
},
"outputs": [
{
"data": {
"text/plain": [
"[tensor([[ 0.8117, -0.4865, 0.4549, 0.4985]], grad_fn=<CatBackward>),\n",
" tensor([[-0.0333, -0.6678, 0.3855, 0.6282]], grad_fn=<CatBackward>),\n",
" tensor([[ 0.4834, -0.8639, 0.1292, 0.3069]], grad_fn=<CatBackward>),\n",
" tensor([[-0.3773, -0.8821, 0.2727, 0.7177]], grad_fn=<CatBackward>)]"
]
},
"execution_count": 207,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"fullOutput"
]
},
{
"cell_type": "code",
"execution_count": 208,
"metadata": {
"Collapsed": "false",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
},
"colab_type": "code",
"id": "wqvIReyUwS6J",
"outputId": "1bed521f-da43-4ce3-9d64-3675484be0e2"
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 0.8117, -0.4865, 0.4549, 0.4985],\n",
" [-0.0333, -0.6678, 0.3855, 0.6282],\n",
" [ 0.4834, -0.8639, 0.1292, 0.3069],\n",
" [-0.3773, -0.8821, 0.2727, 0.7177]]], grad_fn=<TransposeBackward1>)"
]
},
"execution_count": 208,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"out_all"
]
}
],
"metadata": {
"colab": {
"name": "Understanding RNNs.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment