Created
June 10, 2022 10:23
-
-
Save iamirmasoud/82308f68f930fddf66ed2df6c184606e to your computer and use it in GitHub Desktop.
PyTorch LSTM minimal example: Feeding two sentences to an LSTM
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Minimal Example: Feeding 2 Sentences to an LSTM" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"\n", | |
"import torch\n", | |
"import torch.nn as nn" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Create Example Data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"sentences = ['i like fresh bread', 'i hate stale bread']" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Create Vocabulary" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'stale', 'like', 'bread', 'i', 'hate', 'fresh'}\n" | |
] | |
} | |
], | |
"source": [ | |
"vocabulary = set()\n", | |
"for sent in sentences:\n", | |
" for word in sent.split():\n", | |
" vocabulary.add(word)\n", | |
"\n", | |
"print(vocabulary)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Create Indices to Vectorize Words" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'<PAD>': 0, '<UNK>': 1, '<SOS>': 2, '<EOS>': 3, 'stale': 4, 'like': 5, 'bread': 6, 'i': 7, 'hate': 8, 'fresh': 9}\n" | |
] | |
} | |
], | |
"source": [ | |
"# Most NLP tasks require some basic tokens, e.g., for padding, unknown words, start of sequence, end of sequence\n", | |
"word2index = {'<PAD>': 0, '<UNK>': 1, '<SOS>': 2, '<EOS>': 3}\n", | |
"\n", | |
"for word in vocabulary:\n", | |
" word2index[word] = len(word2index)\n", | |
" \n", | |
"print(word2index)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Vectorize Sentences" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[[7, 5, 9, 6], [7, 8, 4, 6]]\n", | |
"\n", | |
"[[7 5 9 6]\n", | |
" [7 8 4 6]]\n", | |
"The shape of batch is: (2, 4)\n", | |
"\n", | |
"tensor([[7, 5, 9, 6],\n", | |
" [7, 8, 4, 6]])\n", | |
"The shape of batch is: torch.Size([2, 4])\n" | |
] | |
} | |
], | |
"source": [ | |
"batch = [ [ word2index[word] for word in sent.split() ] for sent in sentences]\n", | |
"\n", | |
"print(batch)\n", | |
"print()\n", | |
"\n", | |
"# Let's make a numpy array out of it\n", | |
"batch = np.array(batch)\n", | |
"\n", | |
"print(batch)\n", | |
"print('The shape of batch is:', batch.shape)\n", | |
"print()\n", | |
"\n", | |
"# Let's make a PyTorch tensor out of it\n", | |
"batch = torch.tensor(batch, dtype=torch.long)\n", | |
"print(batch)\n", | |
"print('The shape of batch is:', batch.shape)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Now we have completed the first step: our data has the shape `(batch_size, seq_len)`. To feed it into an LSTM, we still need the `input_size` dimension." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Embedding Layer" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### Define layer" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"vocab_size = len(word2index) # vocab_size reflects all known words in the index\n", | |
"embed_dim = 10 # Let's assume the word embeddings are of size 10 to keep it simple\n", | |
"\n", | |
"word_embedding_layer = nn.Embedding(vocab_size, embed_dim)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### Push batch through layer" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"The shape of batch is: torch.Size([2, 4, 10])\n", | |
"\n", | |
"tensor([[[-0.3179, 0.5848, -0.0688, -0.6153, -0.5866, -0.2300, -0.0115,\n", | |
" 1.6450, 2.4355, 0.5729],\n", | |
" [-0.3548, 0.6907, -2.0556, -1.5537, 0.3275, -1.2800, 0.6300,\n", | |
" -0.5206, 1.8570, -0.3607],\n", | |
" [-0.2171, -0.4977, -0.5973, -1.4402, -0.8320, -0.4389, -1.2304,\n", | |
" -0.2994, 0.2303, -0.4284],\n", | |
" [ 1.9364, -1.8846, -0.4562, -0.1178, 0.8705, -0.1335, 0.1754,\n", | |
" -1.7738, -1.0396, 0.1272]],\n", | |
"\n", | |
" [[-0.3179, 0.5848, -0.0688, -0.6153, -0.5866, -0.2300, -0.0115,\n", | |
" 1.6450, 2.4355, 0.5729],\n", | |
" [-1.5821, 0.2974, -1.3868, -0.3796, 1.5824, -0.6330, -0.0880,\n", | |
" 0.2109, -1.1064, -0.6549],\n", | |
" [-1.5234, -0.3309, -0.3816, 0.9839, 1.6590, -1.2978, -0.1062,\n", | |
" 0.0335, -1.5911, -1.8855],\n", | |
" [ 1.9364, -1.8846, -0.4562, -0.1178, 0.8705, -0.1335, 0.1754,\n", | |
" -1.7738, -1.0396, 0.1272]]], grad_fn=<EmbeddingBackward>)\n" | |
] | |
} | |
], | |
"source": [ | |
"batch = word_embedding_layer(batch)\n", | |
"\n", | |
"print('The shape of batch is:', batch.shape)\n", | |
"print()\n", | |
"print(batch)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Now we have what we want: `(batch_size, seq_len, input_size)`" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### LSTM Layer (`batch_first=True`)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### Define LSTM layer" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"hidden_dim = 32 # Let's set the size of the hidden dimension to 32\n", | |
"\n", | |
"lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### Push batch through LSTM layer" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"The shape of lstm_out is: torch.Size([2, 4, 32])\n", | |
"The shape of h is: torch.Size([1, 2, 32])\n" | |
] | |
} | |
], | |
"source": [ | |
"# Initialize hidden state (here with zeros)\n", | |
"batch_size = batch.shape[0]\n", | |
"(h, c) = (torch.zeros(1, batch_size, hidden_dim), torch.zeros(1, batch_size, hidden_dim)) \n", | |
"\n", | |
"lstm_out, (h, c) = lstm(batch, (h, c))\n", | |
"\n", | |
"print('The shape of lstm_out is:', lstm_out.shape) # (batch_size, seq_len, hidden_dim)\n", | |
"print('The shape of h is:', h.shape) # (num_layers*num_directions, batch_size, hidden_dim)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### LSTM Layer (`batch_first=False`)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### Define LSTM layer" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"hidden_dim = 32 # Let's set the size of the hidden dimension to 32\n", | |
"\n", | |
"lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=False)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### Push batch through LSTM layer" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"The shape of lstm_out is: torch.Size([4, 2, 32])\n", | |
"The shape of h is: torch.Size([1, 2, 32])\n" | |
] | |
} | |
], | |
"source": [ | |
"# Initialize hidden state (here with zeros)\n", | |
"batch_size = batch.shape[0]\n", | |
"(h, c) = (torch.zeros(1, batch_size, hidden_dim), torch.zeros(1, batch_size, hidden_dim)) \n", | |
"\n", | |
"# We need to transpose the batch (swap its first two axes) from (batch_size, seq_len, input_size) to (seq_len, batch_size, input_size)\n", | |
"batch = batch.transpose(0,1)\n", | |
"\n", | |
"lstm_out, (h, c) = lstm(batch, (h, c))\n", | |
"\n", | |
"print('The shape of lstm_out is:', lstm_out.shape) # (seq_len, batch_size, hidden_dim)\n", | |
"print('The shape of h is:', h.shape) # (num_layers*num_directions, batch_size, hidden_dim)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python [conda env:py36]", | |
"language": "python", | |
"name": "conda-env-py36-py" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment