LSTMs in PyTorch 🔥
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "LSTMs in PyTorch 🔥",
      "provenance": [],
      "collapsed_sections": [
        "HL-CfT0quEz8",
        "rltlpuWwuNzg",
        "Drpvydbp-GkP",
        "nmI_j69k9i-H",
        "Wuss5ZGQ9r0x",
        "FWRwT28S9y2a",
        "_OyrH4ta9038",
        "BKAA2rR0-B-3",
        "L2q-PMVa_D7d"
      ],
      "authorship_tag": "ABX9TyO7XULsWcMLi7HcAbvQaNZZ",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/gist/SauravMaheshkar/168f0817f0cd29dd4048868fb0dd4401/lstms-in-pytorch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5nEHdYcCt-m3"
      },
      "source": [
        "### Author: [@SauravMaheshkar](https://twitter.com/MaheshkarSaurav)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HL-CfT0quEz8"
      },
      "source": [
        "# Packages 📦 and Basic Setup\n",
        "\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rltlpuWwuNzg"
      },
      "source": [
        "## Install Packages"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HxKLQT_ot6rH"
      },
      "source": [
        "%%capture\n",
        "\n",
        "## Install the latest version of the wandb client 🔥🔥\n",
        "!pip install -q --upgrade wandb\n",
        "\n",
        "import torch\n",
        "from torch import nn\n",
        "import torch.nn.functional as F\n",
        "from torchtext.legacy import data\n",
        "from torchtext.legacy import datasets\n",
        "from torchtext.vocab import Vectors, GloVe"
      ],
      "execution_count": 1,
      "outputs": []
    },
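    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "**Note (added, not part of the original gist):** the `torchtext.legacy` module was removed in torchtext 0.12, so the imports above need torchtext ≤ 0.11. If they fail, pinning an older release is one workaround. The sketch below assumes `torchtext==0.10.0` (which may also downgrade `torch`); restart the runtime after installing."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Added sketch: pin a torchtext release that still ships the legacy API.\n",
        "# The exact version is an assumption, not part of the original notebook.\n",
        "!pip install -q torchtext==0.10.0"
      ],
      "execution_count": null,
      "outputs": []
    },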
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Drpvydbp-GkP"
      },
      "source": [
        "## Project Configuration using `wandb.config`"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "UihlfsH3-JRq"
      },
      "source": [
        "import os\n",
        "import wandb\n",
        "\n",
        "# Paste your api key here\n",
        "os.environ[\"WANDB_API_KEY\"] = '...'\n",
        "\n",
        "# Feel free to change these and experiment !!\n",
        "config = wandb.config\n",
        "config.learning_rate = 2e-5\n",
        "config.batch_size = 32\n",
        "config.output_size = 2\n",
        "config.hidden_size = 256\n",
        "config.embedding_length = 300\n",
        "config.epochs = 10"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nmI_j69k9i-H"
      },
      "source": [
        "# 💿 The Dataset\n",
        "\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9Xy_k0At-UGT"
      },
      "source": [
        "In this code cell, we use torchtext's legacy API to load the IMDB dataset, build a GloVe-initialised vocabulary, and create bucketed train/validation/test iterators."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1kFBF4DEwwPR",
        "outputId": "51a7dc40-8904-4ebf-a565-074ad7e3371b"
      },
      "source": [
        "# Ported from: https://github.com/prakashpandey9/Text-Classification-Pytorch/blob/master/load_data.py\n",
        "\n",
        "def load_dataset():\n",
        "\n",
        "    # Tokenise on whitespace; pad/truncate every review to 200 tokens\n",
        "    tokenize = lambda x: x.split()\n",
        "    TEXT = data.Field(sequential=True, tokenize=tokenize, lower=True, include_lengths=True, batch_first=True, fix_length=200)\n",
        "    LABEL = data.LabelField()\n",
        "    train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)\n",
        "    # Build the vocabulary from the training set, initialised with 300-d GloVe vectors\n",
        "    TEXT.build_vocab(train_data, vectors=GloVe(name='6B', dim=300))\n",
        "    LABEL.build_vocab(train_data)\n",
        "\n",
        "    word_embeddings = TEXT.vocab.vectors\n",
        "\n",
        "    # Hold out a validation split and bucket batches by length to minimise padding\n",
        "    train_data, valid_data = train_data.split()\n",
        "    train_iter, valid_iter, test_iter = data.BucketIterator.splits((train_data, valid_data, test_data),\n",
        "                                                                    batch_size=config.batch_size,\n",
        "                                                                    sort_key=lambda x: len(x.text),\n",
        "                                                                    repeat=False, shuffle=True)\n",
        "\n",
        "    vocab_size = len(TEXT.vocab)\n",
        "\n",
        "    return TEXT, vocab_size, word_embeddings, train_iter, valid_iter, test_iter\n",
        "\n",
        "TEXT, vocab_size, word_embeddings, train_iter, valid_iter, test_iter = load_dataset()"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "downloading aclImdb_v1.tar.gz\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "aclImdb_v1.tar.gz: 100%|██████████| 84.1M/84.1M [00:01<00:00, 43.0MB/s]\n",
            ".vector_cache/glove.6B.zip: 862MB [02:40, 5.36MB/s] \n",
            "100%|█████████▉| 399999/400000 [00:52<00:00, 7643.86it/s]\n"
          ]
        }
      ]
    },
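    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what the iterators yield, the added sketch below (not part of the original gist) peeks at one training batch: because the field was built with `include_lengths=True` and `batch_first=True`, `batch.text` is a pair of a `(batch_size, 200)` tensor of token ids and a tensor of true lengths, alongside a label tensor."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Added sketch: inspect one batch from the training iterator\n",
        "print(f'Vocabulary size: {vocab_size}')\n",
        "batch = next(iter(train_iter))\n",
        "text, lengths = batch.text  # include_lengths=True -> (padded ids, true lengths)\n",
        "print(text.shape, lengths.shape, batch.label.shape)\n",
        "# expected: torch.Size([32, 200]) torch.Size([32]) torch.Size([32])"
      ],
      "execution_count": null,
      "outputs": []
    },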
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Wuss5ZGQ9r0x"
      },
      "source": [
        "# ✍️ Model Architecture\n",
        "---"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xY7FMji7wycY"
      },
      "source": [
        "class LSTMClassifier(nn.Module):\n",
        "    def __init__(self, batch_size, output_size, hidden_size, vocab_size, embedding_length, weights):\n",
        "        super(LSTMClassifier, self).__init__()\n",
        "        self.batch_size = batch_size\n",
        "        self.output_size = output_size\n",
        "        self.hidden_size = hidden_size\n",
        "        self.vocab_size = vocab_size\n",
        "        self.embedding_length = embedding_length\n",
        "\n",
        "        self.word_embeddings = nn.Embedding(vocab_size, embedding_length)\n",
        "        # Initialise the embedding layer with the pre-trained GloVe vectors and freeze it\n",
        "        self.word_embeddings.weight = nn.Parameter(weights, requires_grad=False)\n",
        "        self.lstm = nn.LSTM(embedding_length, hidden_size) # Our main hero for this tutorial\n",
        "        self.label = nn.Linear(hidden_size, output_size)\n",
        "\n",
        "    def forward(self, input_sentence, batch_size=None):\n",
        "        if batch_size is None:\n",
        "            batch_size = self.batch_size\n",
        "        # (batch, seq_len) -> (batch, seq_len, embedding_length)\n",
        "        embedded = self.word_embeddings(input_sentence)\n",
        "        # nn.LSTM expects (seq_len, batch, embedding_length) by default\n",
        "        embedded = embedded.permute(1, 0, 2)\n",
        "        # Zero-initialise the hidden and cell states on the input's device\n",
        "        # (torch.autograd.Variable is deprecated, so plain tensors are used)\n",
        "        h_0 = torch.zeros(1, batch_size, self.hidden_size, device=embedded.device)\n",
        "        c_0 = torch.zeros(1, batch_size, self.hidden_size, device=embedded.device)\n",
        "        output, (final_hidden_state, final_cell_state) = self.lstm(embedded, (h_0, c_0))\n",
        "        # Classify from the final hidden state: (batch, hidden_size) -> (batch, output_size)\n",
        "        final_output = self.label(final_hidden_state[-1])\n",
        "\n",
        "        return final_output"
      ],
      "execution_count": 4,
      "outputs": []
    },
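    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before training, a quick sanity check helps confirm the shapes line up. The cell below is an added sketch, not part of the original gist: assuming the cells above have run on a GPU runtime, it feeds a batch of random token ids through the classifier and checks that the logits come out as `(batch_size, output_size)`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Added sketch: shape sanity-check with random token ids (assumes a GPU runtime).\n",
        "check_model = LSTMClassifier(config.batch_size, config.output_size, config.hidden_size, vocab_size, config.embedding_length, word_embeddings).cuda()\n",
        "# fix_length=200 in load_dataset, so every full batch is (batch_size, 200)\n",
        "dummy_batch = torch.randint(0, vocab_size, (config.batch_size, 200)).cuda()\n",
        "with torch.no_grad():\n",
        "    logits = check_model(dummy_batch)\n",
        "print(logits.shape)  # expected: torch.Size([32, 2])\n",
        "assert logits.shape == (config.batch_size, config.output_size)"
      ],
      "execution_count": null,
      "outputs": []
    },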
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FWRwT28S9y2a"
      },
      "source": [
        "# 🧱 + 🏗 = 🏠 Training\n",
        "\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_OyrH4ta9038"
      },
      "source": [
        "## 🥼 Helper Functions"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zxlsEhz4w3mT"
      },
      "source": [
        "def clip_gradient(model, clip_value):\n",
        "    # Clamp every gradient to [-clip_value, clip_value] to avoid exploding gradients\n",
        "    params = list(filter(lambda p: p.grad is not None, model.parameters()))\n",
        "    for p in params:\n",
        "        p.grad.data.clamp_(-clip_value, clip_value)\n",
        "\n",
        "def train_model(model, train_iter, epoch):\n",
        "    total_epoch_loss = 0\n",
        "    total_epoch_acc = 0\n",
        "    model.cuda()\n",
        "    # Optimise only the trainable parameters (the embeddings are frozen);\n",
        "    # note that re-creating the optimiser every epoch resets Adam's running statistics\n",
        "    optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)\n",
        "    steps = 0\n",
        "    model.train()\n",
        "    for idx, batch in enumerate(train_iter):\n",
        "        text = batch.text[0]\n",
        "        target = batch.label.long()\n",
        "        if torch.cuda.is_available():\n",
        "            text = text.cuda()\n",
        "            target = target.cuda()\n",
        "        # Skip the last, smaller batch so the hidden-state shapes always match\n",
        "        if text.size(0) != config.batch_size:\n",
        "            continue\n",
        "        optim.zero_grad()\n",
        "        prediction = model(text)\n",
        "        loss = loss_fn(prediction, target)\n",
        "        wandb.log({\"Training Loss\": loss.item()})\n",
        "        num_corrects = (torch.max(prediction, 1)[1].view(target.size()).data == target.data).float().sum()\n",
        "        acc = 100.0 * num_corrects/len(batch)\n",
        "        wandb.log({\"Training Accuracy\": acc.item()})\n",
        "        loss.backward()\n",
        "        clip_gradient(model, 1e-1)\n",
        "        optim.step()\n",
        "        steps += 1\n",
        "\n",
        "        if steps % 100 == 0:\n",
        "            print(f'Epoch: {epoch+1}, Idx: {idx+1}, Training Loss: {loss.item():.4f}, Training Accuracy: {acc.item():.2f}%')\n",
        "\n",
        "        total_epoch_loss += loss.item()\n",
        "        total_epoch_acc += acc.item()\n",
        "\n",
        "    return total_epoch_loss/len(train_iter), total_epoch_acc/len(train_iter)\n",
        "\n",
        "def eval_model(model, val_iter):\n",
        "    total_epoch_loss = 0\n",
        "    total_epoch_acc = 0\n",
        "    model.eval()\n",
        "    with torch.no_grad():\n",
        "        for idx, batch in enumerate(val_iter):\n",
        "            text = batch.text[0]\n",
        "            if text.size(0) != config.batch_size:\n",
        "                continue\n",
        "            target = batch.label.long()\n",
        "            if torch.cuda.is_available():\n",
        "                text = text.cuda()\n",
        "                target = target.cuda()\n",
        "            prediction = model(text)\n",
        "            loss = loss_fn(prediction, target)\n",
        "            wandb.log({\"Evaluation Loss\": loss.item()})\n",
        "            num_corrects = (torch.max(prediction, 1)[1].view(target.size()).data == target.data).float().sum()\n",
        "            acc = 100.0 * num_corrects/len(batch)\n",
        "            wandb.log({\"Evaluation Accuracy\": acc.item()})\n",
        "            total_epoch_loss += loss.item()\n",
        "            total_epoch_acc += acc.item()\n",
        "\n",
        "    return total_epoch_loss/len(val_iter), total_epoch_acc/len(val_iter)"
      ],
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BKAA2rR0-B-3"
      },
      "source": [
        "## Training"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "n1XaWIg0w6XZ"
      },
      "source": [
        "model = LSTMClassifier(config.batch_size, config.output_size, config.hidden_size, vocab_size, config.embedding_length, word_embeddings)\n",
        "loss_fn = F.cross_entropy\n",
        "\n",
        "# Create a wandb run to log all your metrics\n",
        "run = wandb.init(project='...', entity='...', reinit=True)\n",
        "\n",
        "# Log gradients and parameter histograms for the model\n",
        "wandb.watch(model)\n",
        "\n",
        "for epoch in range(config.epochs):\n",
        "    train_loss, train_acc = train_model(model, train_iter, epoch)\n",
        "    val_loss, val_acc = eval_model(model, valid_iter)\n",
        "\n",
        "    print(f'Epoch: {epoch+1:02}, Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.2f}%, Val. Loss: {val_loss:.3f}, Val. Acc: {val_acc:.2f}%')\n",
        "\n",
        "run.finish()"
      ],
      "execution_count": null,
      "outputs": []
    },
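    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Finally, an added sketch that is not part of the original gist: since `load_dataset` also returns a `test_iter`, the same `eval_model` helper can score the trained model on the held-out IMDB test split."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Added sketch: reuse eval_model to evaluate on the IMDB test split\n",
        "test_loss, test_acc = eval_model(model, test_iter)\n",
        "print(f'Test Loss: {test_loss:.3f}, Test Acc: {test_acc:.2f}%')"
      ],
      "execution_count": null,
      "outputs": []
    }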
  ]
}