@JAEarly
Last active January 13, 2023 20:12
A Python notebook on how to use bi-directional LSTMs in PyTorch.
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Understanding_BiLSTM_Outputs.ipynb","provenance":[],"authorship_tag":"ABX9TyMZJlcd+EyD6r7okEeSGNVB"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["# Understanding Bi-Directional LSTM Outputs in PyTorch\n","\n","This is a really quick example of the outputs of a PyTorch Bi-Directional LSTM.\n","\n","First, let's create a random batch of data to serve as an example:"],"metadata":{"id":"lj3AW_oSGGKy"}},{"cell_type":"code","source":["import numpy as np\n","import torch\n","from torch import nn\n","\n","# Batch size of one just to keep it simple\n","n_batches = 1\n","# Six tokens in our batch\n","n_tokens = 6\n","# Three features per token\n","n_features = 3\n","\n","# Generate a random batch\n","rand_batch = torch.rand((n_batches, n_tokens, n_features))\n","rand_batch"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"5I1rXV3KnONd","executionInfo":{"status":"ok","timestamp":1646391384770,"user_tz":0,"elapsed":230,"user":{"displayName":"Joseph Early","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjB-IA8aPPgVEv4uVdJngeZShL-XF_Z0dxEfm_f=s64","userId":"05624770450897554531"}},"outputId":"bc44f563-a15f-430a-d167-4760a3115dd1"},"execution_count":14,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[[0.0118, 0.5465, 0.6403],\n"," [0.7693, 0.0807, 0.0816],\n"," [0.8544, 0.3673, 0.2540],\n"," [0.3607, 0.0892, 0.9755],\n"," [0.5160, 0.7301, 0.1910],\n"," [0.6655, 0.8530, 0.1662]]])"]},"metadata":{},"execution_count":14}]},{"cell_type":"markdown","source":["In the `rand_batch` output above, each row is a token and each column is a feature.\n","\n","Now let's create our Bi-Directional LSTM:"],"metadata":{"id":"y4hBaqBPGtnJ"}},{"cell_type":"code","source":["# Create our Bi-Directional LSTM with a hidden layer size of four, and two layers\n","d_hid = 4\n","n_layers = 2\n","bi_lstm = 
nn.LSTM(input_size=n_features, hidden_size=d_hid, num_layers=n_layers,\n"," bidirectional=True, batch_first=True)\n","bi_lstm"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"NCo6RqOVoKqh","executionInfo":{"status":"ok","timestamp":1646391385032,"user_tz":0,"elapsed":10,"user":{"displayName":"Joseph Early","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjB-IA8aPPgVEv4uVdJngeZShL-XF_Z0dxEfm_f=s64","userId":"05624770450897554531"}},"outputId":"6fc5b308-964d-4f09-a855-93ee90a89558"},"execution_count":15,"outputs":[{"output_type":"execute_result","data":{"text/plain":["LSTM(3, 4, num_layers=2, batch_first=True, bidirectional=True)"]},"metadata":{},"execution_count":15}]},{"cell_type":"markdown","source":["We'll now pass the example batch through the model.\n","\n","This produces three outputs: `out`, `h_n`, `c_n`. From the PyTorch docs:\n","* `out` is the output features from the last layer of the LSTM, for each token\n","* `h_n` contains the final hidden state for each layer in the model\n","* `c_n` contains the final cell state for each layer in the model\n","\n","So `out` gives us information \"across tokens\", whereas `h_n` and `c_n` give us information \"across layers\"."],"metadata":{"id":"7PNaLSL5HGTq"}},{"cell_type":"code","source":["# Pass our random batch through our model\n","out, (h_n, c_n) = bi_lstm(rand_batch)\n","print('out shape:', out.shape)\n","print('h_n shape:', h_n.shape)\n","print('c_n shape:', c_n.shape)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"kGoFCbhZoXWT","executionInfo":{"status":"ok","timestamp":1646391385033,"user_tz":0,"elapsed":10,"user":{"displayName":"Joseph Early","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjB-IA8aPPgVEv4uVdJngeZShL-XF_Z0dxEfm_f=s64","userId":"05624770450897554531"}},"outputId":"f1e221de-7113-4281-de5d-b021dd52c0bf"},"execution_count":16,"outputs":[{"output_type":"stream","name":"stdout","text":["out shape: torch.Size([1, 6, 8])\n","h_n shape: 
torch.Size([4, 1, 4])\n","c_n shape: torch.Size([4, 1, 4])\n"]}]},{"cell_type":"markdown","source":["So what is contained in these outputs?\n","* `out` is a tensor of shape: `(n_batches, n_tokens, d_hid * 2)`\n"," * So, for each token, we have a tensor of length twice the length of our model's hidden dimension. It is twice the length as it contains information from both the forward and backward directions. This concatenated tensor is the final layer output for each direction (forward and backward).\n"," * When we created our model, we specified `batch_first=True`, otherwise this tensor would be `(n_tokens, n_batches, d_hid * 2)`. Having it batch first just makes more sense to me.\n","* `h_n` is a tensor of shape `(n_layers * 2, n_batches, d_hid)`.\n"," * For each layer in our model, we are given its output, which is a tensor of length `d_hid`. Our model consists of `n_layers * 2` layers, as we have the same number of layers (`n_layers`) for both the forward and backward direction.\n"," * This tensor always has `n_batches` in the second dimension, regardless of what we set `batch_first` to.\n"," * It is possible that the layer outputs are not equal to `d_hid` if an output projection is used, i.e., a non-zero value of the model argument `proj_size`. 
However, for our model, we kept `proj_size` equal to its default value of zero, so our layer outputs remain at size `d_hid`.\n","* `c_n` is a tensor of shape `(n_layers * 2, n_batches, d_hid)`.\n"," * This gives the cell state of each layer, in the exact same way that `h_n` gives the output state of each layer.\n"," * The cell state outputs are always equal to `d_hid`, regardless of the value of `proj_size`.\n","\n","Having both the forward and backward directions contained in the same tensors is a bit annoying, so let's separate the directions out:"],"metadata":{"id":"14T-KRzPa0M2"}},{"cell_type":"code","source":["# Split out between the forward and backward directions\n","out_split = out.view(n_batches, n_tokens, 2, d_hid)\n","out_forward = out_split[:, :, 0, :]\n","out_backward = out_split[:, :, 1, :]\n","\n","# Split h_n between the forward and backward directions\n","h_n_forward = h_n[::2, :, :]\n","h_n_backward = h_n[1::2, :, :]\n","print(' Out split shape:', out_split.shape)\n","print(' Out forward shape:', out_forward.shape)\n","print('Out backward shape:', out_backward.shape)\n","print(' h_n forward shape:', h_n_forward.shape)\n","print('h_n backward shape:', h_n_backward.shape)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"SONSJX3PpDJD","executionInfo":{"status":"ok","timestamp":1646391385033,"user_tz":0,"elapsed":9,"user":{"displayName":"Joseph Early","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjB-IA8aPPgVEv4uVdJngeZShL-XF_Z0dxEfm_f=s64","userId":"05624770450897554531"}},"outputId":"5ffd171a-2b41-4e75-86bd-7d635564112a"},"execution_count":17,"outputs":[{"output_type":"stream","name":"stdout","text":[" Out split shape: torch.Size([1, 6, 2, 4])\n"," Out forward shape: torch.Size([1, 6, 4])\n","Out backward shape: torch.Size([1, 6, 4])\n"," h_n forward shape: torch.Size([2, 1, 4])\n","h_n backward shape: torch.Size([2, 1, 4])\n"]}]},{"cell_type":"markdown","source":["So what did we do here?\n","\n","For `out`, the forward and 
backward information is mixed in the output tensors for each token. We can split these tensors in two using `out.view...`, such that the `2 * d_hid` tensors become two tensors of length `d_hid`. The documentation tells us that the first half of the tensor contains the forward direction, and the second half contains the backward direction. In the end, we're left with two tensors of shape `(n_batches, n_tokens, d_hid)`, where one is the token outputs in the forward direction, and the other is the token outputs in the backward direction.\n","\n","For `h_n` (and by extension `c_n`), the output tensor actually alternates directions, i.e., \"even\" rows contain the forward direction, and the \"odd\" rows contain the backward direction. Therefore, if we split based on alternate rows, we're left with two tensors of shape `(n_layers, n_batches, d_hid)`, where one is the hidden layer outputs for the forward direction, and the other is the hidden layer outputs for the backward direction.\n","\n","If `out` contains the final layer outputs for each token, and `h_n` contains all the layer outputs just for the final token, is there an overlap?"],"metadata":{"id":"-DQKkKnDkvGO"}},{"cell_type":"code","source":["# Compare out and h_n\n","# Indexing is weird here because out_backward is in the reverse direction\n","# And batch_first doesn't change the direction of h_n\n","print(' out_forward final', out_forward[:, -1, :])\n","print(' h_n_forward final', h_n_forward[-1, :, :])\n","print('out_backward final', out_backward[:, 0, :])\n","print('h_n_backward final', h_n_backward[-1, :, :])"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"OaOTsbNH0N2m","executionInfo":{"status":"ok","timestamp":1646392088250,"user_tz":0,"elapsed":200,"user":{"displayName":"Joseph 
Early","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjB-IA8aPPgVEv4uVdJngeZShL-XF_Z0dxEfm_f=s64","userId":"05624770450897554531"}},"outputId":"816b0ba9-0cf0-43e3-a98f-4fbbdd27f80a"},"execution_count":20,"outputs":[{"output_type":"stream","name":"stdout","text":[" out_forward final tensor([[-0.3253, 0.1898, 0.2064, -0.0506]], grad_fn=<SliceBackward0>)\n"," h_n_forward final tensor([[-0.3253, 0.1898, 0.2064, -0.0506]], grad_fn=<SliceBackward0>)\n","out_backward final tensor([[-0.0491, 0.2519, -0.0656, -0.2746]], grad_fn=<SliceBackward0>)\n","h_n_backward final tensor([[-0.0491, 0.2519, -0.0656, -0.2746]], grad_fn=<SliceBackward0>)\n"]}]},{"cell_type":"markdown","source":["Yes!\n","\n","If we compare the last entry in `out_forward` with the last entry in `h_n_forward`, we can see they're the same. This makes sense; `out_forward` is the last forward hidden layer output for every token, and `h_n_forward` is the hidden layer output of the last token for every forward layer, so the final entry of both is the last forward hidden layer output for the last token.\n","\n","We can do a similar thing for `out_backward` and `h_n_backward`, except here we want to take the first entry in `out_backward` as it processes the tokens in the reverse order (e.g., for a sentence \"The quick brown fox...\", the last token in the reverse direction is \"the\").\n","\n","What do we use for classification? 
Well, we want the final layer output for both the forward and backward direction, such that the entire sequence has been processed in both directions, i.e., the final element of `out_forward` and the first element of `out_backward`:"],"metadata":{"id":"_sF8gu-En81Z"}},{"cell_type":"code","source":["# Create final sequence representation\n","final_repr = torch.cat([out_forward[:, -1, :], out_backward[:, 0, :]], dim=1)\n","print(' Final repr shape:', final_repr.shape)\n","print(' Final repr:', final_repr.data)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Q_qx0RRo1qhv","executionInfo":{"status":"ok","timestamp":1646394837433,"user_tz":0,"elapsed":199,"user":{"displayName":"Joseph Early","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjB-IA8aPPgVEv4uVdJngeZShL-XF_Z0dxEfm_f=s64","userId":"05624770450897554531"}},"outputId":"4983c76b-6900-44ac-9ed6-23e60185d6ed"},"execution_count":27,"outputs":[{"output_type":"stream","name":"stdout","text":[" Final repr shape: torch.Size([1, 8])\n"," Final repr: tensor([[-0.3253, 0.1898, 0.2064, -0.0506, -0.0491, 0.2519, -0.0656, -0.2746]])\n"]}]},{"cell_type":"markdown","source":["We now know that this is equivalent to some elements of `h_n`, so here is a one-liner to get the output, independently of the number of layers:"],"metadata":{"id":"7flvu-rqxh6y"}},{"cell_type":"code","source":["# One-liner to get final sequence representation\n","final_repr = torch.cat([h_n[-2, :, :], h_n[-1, :, :]], dim=1)\n","print(' Final repr shape:', final_repr.shape)\n","print(' Final repr:', final_repr.data)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"lx0V_Vz0xxXx","executionInfo":{"status":"ok","timestamp":1646394839573,"user_tz":0,"elapsed":187,"user":{"displayName":"Joseph 
Early","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjB-IA8aPPgVEv4uVdJngeZShL-XF_Z0dxEfm_f=s64","userId":"05624770450897554531"}},"outputId":"1ee82bda-94ba-45ad-8e44-2e9702929553"},"execution_count":28,"outputs":[{"output_type":"stream","name":"stdout","text":[" Final repr shape: torch.Size([1, 8])\n"," Final repr: tensor([[-0.3253, 0.1898, 0.2064, -0.0506, -0.0491, 0.2519, -0.0656, -0.2746]])\n"]}]},{"cell_type":"markdown","source":["And with that, you now hopefully have a better idea of what the outputs of a Bi-Directional LSTM in PyTorch mean!"],"metadata":{"id":"_r9YEFMFzjeI"}}]}