{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "LSTMs in PyTorch 🔥",
"provenance": [],
"collapsed_sections": [
"HL-CfT0quEz8",
"rltlpuWwuNzg",
"Drpvydbp-GkP",
"nmI_j69k9i-H",
"Wuss5ZGQ9r0x",
"FWRwT28S9y2a",
"_OyrH4ta9038",
"BKAA2rR0-B-3",
"L2q-PMVa_D7d"
],
"authorship_tag": "ABX9TyO7XULsWcMLi7HcAbvQaNZZ",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/SauravMaheshkar/168f0817f0cd29dd4048868fb0dd4401/lstms-in-pytorch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5nEHdYcCt-m3"
},
"source": [
"### Author: [@SauravMaheshkar](https://twitter.com/MaheshkarSaurav)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HL-CfT0quEz8"
},
"source": [
"# Packages 📦 and Basic Setup\n",
"\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rltlpuWwuNzg"
},
"source": [
"## Install Packages"
]
},
{
"cell_type": "code",
"metadata": {
"id": "HxKLQT_ot6rH"
},
"source": [
"%%capture\n",
"\n",
"## Install the latest version of wandb client 🔥🔥\n",
"!pip install -q --upgrade wandb\n",
"\n",
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"from torchtext.legacy import data\n",
"from torchtext.legacy import datasets\n",
"from torchtext.vocab import GloVe"
],
"execution_count": 1,
"outputs": []
},
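{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note:** the `torchtext.legacy` namespace only exists in older torchtext releases (it was removed in 0.12). If the imports above fail in a newer environment, one option is to pin a compatible pair of versions; the exact pins below are an assumption, so match them to your runtime:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"## Optional: pin releases that still ship torchtext.legacy\n",
"## (version numbers are an assumption; adjust for your runtime)\n",
"# !pip install -q \"torch==1.10.2\" \"torchtext==0.11.2\""
],
"execution_count": null,
"outputs": []
},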
{
"cell_type": "markdown",
"metadata": {
"id": "Drpvydbp-GkP"
},
"source": [
"## Project Configuration using `wandb.config`"
]
},
{
"cell_type": "code",
"metadata": {
"id": "UihlfsH3-JRq"
},
"source": [
"import os\n",
"import wandb\n",
"\n",
"# Paste your api key here\n",
"os.environ[\"WANDB_API_KEY\"] = '...'\n",
"\n",
"# Feel free to change these and experiment !!\n",
"config = wandb.config\n",
"config.learning_rate = 2e-5\n",
"config.batch_size = 32\n",
"config.output_size = 2\n",
"config.hidden_size = 256\n",
"config.embedding_length = 300\n",
"config.epochs = 10"
],
"execution_count": 2,
"outputs": []
},
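{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you'd rather not paste the API key into an environment variable, `wandb.login()` offers an interactive alternative (a minimal sketch, to run instead of setting `WANDB_API_KEY` above):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"## Alternative to hard-coding WANDB_API_KEY: prompts for the key interactively\n",
"# wandb.login()"
],
"execution_count": null,
"outputs": []
},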
{
"cell_type": "markdown",
"metadata": {
"id": "nmI_j69k9i-H"
},
"source": [
"# 💿 The Dataset\n",
"\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9Xy_k0At-UGT"
},
"source": [
"In this code cell we use the `torchtext.legacy` module to download the IMDB dataset, build a GloVe-backed vocabulary, and create bucketed train/validation/test iterators."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1kFBF4DEwwPR",
"outputId": "51a7dc40-8904-4ebf-a565-074ad7e3371b"
},
"source": [
"# Ported from: https://github.com/prakashpandey9/Text-Classification-Pytorch/blob/master/load_data.py\n",
"\n",
"def load_dataset(test_sen=None):\n",
"\n",
"    tokenize = lambda x: x.split()\n",
"    TEXT = data.Field(sequential=True, tokenize=tokenize, lower=True, include_lengths=True, batch_first=True, fix_length=200)\n",
"    LABEL = data.LabelField()\n",
"\n",
"    # Download the IMDB movie-review dataset and build the vocabulary\n",
"    # from pre-trained 300-dimensional GloVe vectors\n",
"    train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)\n",
"    TEXT.build_vocab(train_data, vectors=GloVe(name='6B', dim=300))\n",
"    LABEL.build_vocab(train_data)\n",
"\n",
"    word_embeddings = TEXT.vocab.vectors\n",
"\n",
"    # Hold out part of the training set for validation, then bucket\n",
"    # examples of similar length together to minimise padding\n",
"    train_data, valid_data = train_data.split()\n",
"    train_iter, valid_iter, test_iter = data.BucketIterator.splits(\n",
"        (train_data, valid_data, test_data),\n",
"        batch_size=config.batch_size,\n",
"        sort_key=lambda x: len(x.text),\n",
"        repeat=False, shuffle=True)\n",
"\n",
"    vocab_size = len(TEXT.vocab)\n",
"\n",
"    return TEXT, vocab_size, word_embeddings, train_iter, valid_iter, test_iter\n",
"\n",
"TEXT, vocab_size, word_embeddings, train_iter, valid_iter, test_iter = load_dataset()"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"downloading aclImdb_v1.tar.gz\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"aclImdb_v1.tar.gz: 100%|██████████| 84.1M/84.1M [00:01<00:00, 43.0MB/s]\n",
".vector_cache/glove.6B.zip: 862MB [02:40, 5.36MB/s] \n",
"100%|█████████▉| 399999/400000 [00:52<00:00, 7643.86it/s]\n"
]
}
]
},
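{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check we can pull a single batch and confirm the shapes the model will receive. This is a minimal sketch (not part of the original pipeline) that assumes `load_dataset()` above ran successfully:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"## batch.text is a (padded tensor, lengths) tuple because the Field\n",
"## was built with include_lengths=True\n",
"sample_batch = next(iter(train_iter))\n",
"text, lengths = sample_batch.text\n",
"print(text.shape)                # (batch_size, fix_length) = (32, 200)\n",
"print(sample_batch.label.shape)  # (batch_size,)\n",
"print(word_embeddings.shape)     # (vocab_size, embedding_length)"
],
"execution_count": null,
"outputs": []
},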
{
"cell_type": "markdown",
"metadata": {
"id": "Wuss5ZGQ9r0x"
},
"source": [
"# ✍️ Model Architecture\n",
"---"
]
},
{
"cell_type": "code",
"metadata": {
"id": "xY7FMji7wycY"
},
"source": [
"class LSTMClassifier(nn.Module):\n",
"    def __init__(self, batch_size, output_size, hidden_size, vocab_size, embedding_length, weights):\n",
"        super(LSTMClassifier, self).__init__()\n",
"        self.batch_size = batch_size\n",
"        self.output_size = output_size\n",
"        self.hidden_size = hidden_size\n",
"        self.vocab_size = vocab_size\n",
"        self.embedding_length = embedding_length\n",
"\n",
"        self.word_embeddings = nn.Embedding(vocab_size, embedding_length)\n",
"        # Initialise with the pre-trained GloVe vectors and freeze them\n",
"        self.word_embeddings.weight = nn.Parameter(weights, requires_grad=False)\n",
"        self.lstm = nn.LSTM(embedding_length, hidden_size) # Our main hero for this tutorial\n",
"        self.label = nn.Linear(hidden_size, output_size)\n",
"\n",
"    def forward(self, input_sentence, batch_size=None):\n",
"        batch_size = batch_size if batch_size is not None else self.batch_size\n",
"        # (batch, seq_len) -> (batch, seq_len, embedding_length)\n",
"        embedded = self.word_embeddings(input_sentence)\n",
"        # nn.LSTM defaults to batch_first=False, so move seq_len up front\n",
"        embedded = embedded.permute(1, 0, 2)\n",
"        h_0 = torch.zeros(1, batch_size, self.hidden_size, device=embedded.device)\n",
"        c_0 = torch.zeros(1, batch_size, self.hidden_size, device=embedded.device)\n",
"        output, (final_hidden_state, final_cell_state) = self.lstm(embedded, (h_0, c_0))\n",
"        # Classify from the final hidden state of the last LSTM layer\n",
"        final_output = self.label(final_hidden_state[-1])\n",
"\n",
"        return final_output"
],
"execution_count": 4,
"outputs": []
},
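{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `permute` in `forward` exists because `nn.LSTM` defaults to `batch_first=False`: it expects input of shape `(seq_len, batch, input_size)` and returns hidden states of shape `(num_layers, batch, hidden_size)`. The standalone sketch below, on dummy data with arbitrary sizes, illustrates that convention:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"## Shape conventions of nn.LSTM, demonstrated on dummy data\n",
"demo_lstm = nn.LSTM(input_size=300, hidden_size=256)\n",
"dummy = torch.zeros(200, 4, 300)  # (seq_len, batch, input_size)\n",
"h0 = torch.zeros(1, 4, 256)       # (num_layers, batch, hidden_size)\n",
"c0 = torch.zeros(1, 4, 256)\n",
"out, (hn, cn) = demo_lstm(dummy, (h0, c0))\n",
"print(out.shape)  # torch.Size([200, 4, 256]) -- one output per timestep\n",
"print(hn.shape)   # torch.Size([1, 4, 256]) -- final hidden state, fed to the classifier head"
],
"execution_count": null,
"outputs": []
},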
{
"cell_type": "markdown",
"metadata": {
"id": "FWRwT28S9y2a"
},
"source": [
"# 🧱 + 🏗 = 🏠 Training\n",
"\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_OyrH4ta9038"
},
"source": [
"## 🥼 Helper Function"
]
},
{
"cell_type": "code",
"metadata": {
"id": "zxlsEhz4w3mT"
},
"source": [
"def clip_gradient(model, clip_value):\n",
"    \"\"\"Clamp every parameter gradient to [-clip_value, clip_value].\"\"\"\n",
"    params = list(filter(lambda p: p.grad is not None, model.parameters()))\n",
"    for p in params:\n",
"        p.grad.data.clamp_(-clip_value, clip_value)\n",
"\n",
"def train_model(model, optim, train_iter, epoch):\n",
"    total_epoch_loss = 0\n",
"    total_epoch_acc = 0\n",
"    steps = 0\n",
"    model.train()\n",
"    for idx, batch in enumerate(train_iter):\n",
"        text = batch.text[0]\n",
"        target = batch.label.long()\n",
"        if torch.cuda.is_available():\n",
"            text = text.cuda()\n",
"            target = target.cuda()\n",
"        if text.size(0) != config.batch_size: # skip the ragged final batch\n",
"            continue\n",
"        optim.zero_grad()\n",
"        prediction = model(text)\n",
"        loss = loss_fn(prediction, target)\n",
"        wandb.log({\"Training Loss\": loss.item()})\n",
"        num_corrects = (torch.max(prediction, 1)[1].view(target.size()) == target).float().sum()\n",
"        acc = 100.0 * num_corrects / len(batch)\n",
"        wandb.log({\"Training Accuracy\": acc.item()})\n",
"        loss.backward()\n",
"        clip_gradient(model, 1e-1)\n",
"        optim.step()\n",
"        steps += 1\n",
"\n",
"        if steps % 100 == 0:\n",
"            print(f'Epoch: {epoch+1}, Idx: {idx+1}, Training Loss: {loss.item():.4f}, Training Accuracy: {acc.item():.2f}%')\n",
"\n",
"        total_epoch_loss += loss.item()\n",
"        total_epoch_acc += acc.item()\n",
"\n",
"    return total_epoch_loss / len(train_iter), total_epoch_acc / len(train_iter)\n",
"\n",
"def eval_model(model, val_iter):\n",
"    total_epoch_loss = 0\n",
"    total_epoch_acc = 0\n",
"    model.eval()\n",
"    with torch.no_grad():\n",
"        for idx, batch in enumerate(val_iter):\n",
"            text = batch.text[0]\n",
"            if text.size(0) != config.batch_size: # skip the ragged final batch\n",
"                continue\n",
"            target = batch.label.long()\n",
"            if torch.cuda.is_available():\n",
"                text = text.cuda()\n",
"                target = target.cuda()\n",
"            prediction = model(text)\n",
"            loss = loss_fn(prediction, target)\n",
"            wandb.log({\"Evaluation Loss\": loss.item()})\n",
"            num_corrects = (torch.max(prediction, 1)[1].view(target.size()) == target).float().sum()\n",
"            acc = 100.0 * num_corrects / len(batch)\n",
"            wandb.log({\"Evaluation Accuracy\": acc.item()})\n",
"            total_epoch_loss += loss.item()\n",
"            total_epoch_acc += acc.item()\n",
"\n",
"    return total_epoch_loss / len(val_iter), total_epoch_acc / len(val_iter)"
],
"execution_count": 5,
"outputs": []
},
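{
"cell_type": "markdown",
"metadata": {},
"source": [
"Aside: PyTorch ships a built-in equivalent of the `clip_gradient` helper, so the one-liner below (called between `loss.backward()` and `optim.step()`) would work just as well:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"## Built-in value clipping, equivalent to clip_gradient(model, 1e-1)\n",
"# torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1e-1)"
],
"execution_count": null,
"outputs": []
},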
{
"cell_type": "markdown",
"metadata": {
"id": "BKAA2rR0-B-3"
},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"metadata": {
"id": "n1XaWIg0w6XZ"
},
"source": [
"model = LSTMClassifier(config.batch_size, config.output_size, config.hidden_size, vocab_size, config.embedding_length, word_embeddings)\n",
"loss_fn = F.cross_entropy\n",
"\n",
"if torch.cuda.is_available():\n",
"    model = model.cuda()\n",
"\n",
"# Create the optimiser once, outside the epoch loop, so that Adam's\n",
"# running statistics persist across epochs\n",
"optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)\n",
"\n",
"# Create a wandb run to log all your metrics\n",
"run = wandb.init(project='...', entity='...', reinit=True)\n",
"\n",
"wandb.watch(model)\n",
"\n",
"for epoch in range(config.epochs):\n",
"    train_loss, train_acc = train_model(model, optim, train_iter, epoch)\n",
"    val_loss, val_acc = eval_model(model, valid_iter)\n",
"\n",
"    print(f'Epoch: {epoch+1:02}, Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.2f}%, Val. Loss: {val_loss:.3f}, Val. Acc: {val_acc:.2f}%')\n",
"\n",
"run.finish()"
],
"execution_count": null,
"outputs": []
}
]
}