@SharanSMenon
Created February 5, 2020 16:24
question-classification-cnn.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "question-classification-cnn.ipynb",
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "bFGMr3CQGFJz",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"from torchtext import data\n",
"from torchtext import datasets\n",
"import random\n",
"\n",
"SEED = 1234\n",
"\n",
"torch.manual_seed(SEED)\n",
"torch.backends.cudnn.deterministic = True\n",
"\n",
"TEXT = data.Field(tokenize = 'spacy')\n",
"LABEL = data.LabelField()\n",
"\n",
"train_data, test_data = datasets.TREC.splits(TEXT, LABEL, fine_grained=True)\n",
"\n",
"train_data, valid_data = train_data.split(random_state = random.seed(SEED))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "oDMpsS7oGH2V",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "c8814a4d-15fb-4b9f-a5d8-d8243096c3d8"
},
"source": [
"vars(train_data[-1])"
],
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'label': 'DESC:def', 'text': ['What', 'is', 'a', 'Cartesian', 'Diver', '?']}"
]
},
"metadata": {
"tags": []
},
"execution_count": 4
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "laidwNkvGRwX",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"outputId": "bc61e40f-82c2-4f3a-f091-07e61f507608"
},
"source": [
"MAX_VOCAB_SIZE = 45_000\n",
"\n",
"TEXT.build_vocab(train_data, \n",
" max_size = MAX_VOCAB_SIZE, \n",
" vectors = \"glove.6B.100d\", \n",
" unk_init = torch.Tensor.normal_)\n",
"\n",
"LABEL.build_vocab(train_data)"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
".vector_cache/glove.6B.zip: 862MB [06:32, 2.20MB/s] \n",
"100%|█████████▉| 398009/400000 [00:16<00:00, 24050.78it/s]"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "07mXbobyGghA",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 54
},
"outputId": "e86b6e4b-a093-4015-cd66-0915b7e75bfd"
},
"source": [
"print(LABEL.vocab.stoi)"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": [
"defaultdict(<function _default_unk_index at 0x7f7fb0960158>, {'HUM:ind': 0, 'LOC:other': 1, 'DESC:def': 2, 'NUM:count': 3, 'DESC:manner': 4, 'DESC:desc': 5, 'NUM:date': 6, 'DESC:reason': 7, 'HUM:gr': 8, 'ENTY:other': 9, 'ENTY:cremat': 10, 'LOC:country': 11, 'LOC:city': 12, 'ENTY:animal': 13, 'ENTY:dismed': 14, 'ENTY:food': 15, 'ENTY:termeq': 16, 'ABBR:exp': 17, 'NUM:money': 18, 'NUM:period': 19, 'LOC:state': 20, 'ENTY:event': 21, 'ENTY:sport': 22, 'HUM:desc': 23, 'NUM:other': 24, 'ENTY:product': 25, 'ENTY:color': 26, 'ENTY:techmeth': 27, 'ENTY:substance': 28, 'ENTY:word': 29, 'ENTY:veh': 30, 'NUM:dist': 31, 'HUM:title': 32, 'NUM:perc': 33, 'LOC:mount': 34, 'ENTY:body': 35, 'ABBR:abb': 36, 'ENTY:lang': 37, 'ENTY:instru': 38, 'ENTY:plant': 39, 'NUM:code': 40, 'NUM:temp': 41, 'NUM:volsize': 42, 'NUM:weight': 43, 'ENTY:letter': 44, 'ENTY:symbol': 45, 'ENTY:religion': 46, 'NUM:ord': 47, 'NUM:speed': 48, 'ENTY:currency': 49})\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "gKQbNTskIQGp",
"colab_type": "code",
"colab": {}
},
"source": [
"BATCH_SIZE = 64\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(\n",
" (train_data, valid_data, test_data), \n",
" batch_size = BATCH_SIZE, \n",
" device = device)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "PjCKXbh5IW-D",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "a41eb153-a353-494d-b2f3-e4133b6b4c77"
},
"source": [
"device"
],
"execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"device(type='cuda')"
]
},
"metadata": {
"tags": []
},
"execution_count": 8
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "FjdHxNeGIX3E",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "ZqBeymR9IaKS",
"colab_type": "code",
"colab": {}
},
"source": [
"class CNN(nn.Module):\n",
" def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, \n",
" dropout, pad_idx): \n",
" super().__init__() \n",
" self.embedding = nn.Embedding(vocab_size, embedding_dim) \n",
" self.convs = nn.ModuleList([\n",
" nn.Conv2d(in_channels = 1, \n",
" out_channels = n_filters, \n",
" kernel_size = (fs, embedding_dim)) \n",
" for fs in filter_sizes\n",
" ]) \n",
" self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim) \n",
" self.dropout = nn.Dropout(dropout)\n",
" \n",
" def forward(self, text):\n",
" text = text.permute(1, 0)\n",
" embedded = self.embedding(text)\n",
" embedded = embedded.unsqueeze(1)\n",
" conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]\n",
" pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]\n",
" cat = self.dropout(torch.cat(pooled, dim = 1))\n",
" return self.fc(cat)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "YLUvwoEmIkZ4",
"colab_type": "code",
"colab": {}
},
"source": [
"INPUT_DIM = len(TEXT.vocab)\n",
"EMBEDDING_DIM = 100\n",
"N_FILTERS = 100\n",
"FILTER_SIZES = [2,3,4]\n",
"OUTPUT_DIM = len(LABEL.vocab)\n",
"DROPOUT = 0.5\n",
"PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]\n",
"\n",
"model = CNN(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM, DROPOUT, PAD_IDX)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "5Asdh-TZImgX",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 187
},
"outputId": "0cb48953-4a8c-4051-8058-ffec3d6e4fdf"
},
"source": [
"print(model)"
],
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"text": [
"CNN(\n",
" (embedding): Embedding(7503, 100)\n",
" (convs): ModuleList(\n",
" (0): Conv2d(1, 100, kernel_size=(2, 100), stride=(1, 1))\n",
" (1): Conv2d(1, 100, kernel_size=(3, 100), stride=(1, 1))\n",
" (2): Conv2d(1, 100, kernel_size=(4, 100), stride=(1, 1))\n",
" )\n",
" (fc): Linear(in_features=300, out_features=50, bias=True)\n",
" (dropout): Dropout(p=0.5, inplace=False)\n",
")\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "A6YCAJKmIn5i",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "75060a64-db33-4a6f-ddda-7f4e0c1520e0"
},
"source": [
"def count_parameters(model):\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"print(f'The model has {count_parameters(model):,} trainable parameters')"
],
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"text": [
"The model has 855,650 trainable parameters\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Q3OmgQkzIpuq",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 136
},
"outputId": "0a84a28c-4188-47e1-83ff-1e89de0c4133"
},
"source": [
"pretrained_embeddings = TEXT.vocab.vectors\n",
"\n",
"model.embedding.weight.data.copy_(pretrained_embeddings)"
],
"execution_count": 14,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[-0.1117, -0.4966, 0.1631, ..., 1.2647, -0.2753, -0.1325],\n",
" [-0.8555, -0.7208, 1.3755, ..., 0.0825, -1.1314, 0.3997],\n",
" [ 0.1638, 0.6046, 1.0789, ..., -0.3140, 0.1844, 0.3624],\n",
" ...,\n",
" [-0.3110, -0.3398, 1.0308, ..., 0.5317, 0.2836, -0.0640],\n",
" [ 0.0091, 0.2810, 0.7356, ..., -0.7508, 0.8967, -0.7631],\n",
" [ 0.4306, 1.2011, 0.0873, ..., 0.8817, 0.3722, 0.3458]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 14
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ZmzT0-vNIrga",
"colab_type": "code",
"colab": {}
},
"source": [
"UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]\n",
"\n",
"model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)\n",
"model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "miLY5rvXIs_s",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch.optim as optim\n",
"\n",
"optimizer = optim.Adam(model.parameters())\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"\n",
"model = model.to(device)\n",
"criterion = criterion.to(device)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "dC5vB7DYIuyV",
"colab_type": "code",
"colab": {}
},
"source": [
"def categorical_accuracy(preds, y):\n",
" max_preds = preds.argmax(dim = 1, keepdim = True)\n",
" correct = max_preds.squeeze(1).eq(y)\n",
" return correct.sum() / torch.FloatTensor([y.shape[0]])"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "r_ING1JEIyIx",
"colab_type": "code",
"colab": {}
},
"source": [
"def train(model, iterator, optimizer, criterion):\n",
" epoch_loss = 0\n",
" epoch_acc = 0\n",
" model.train()\n",
" for batch in iterator:\n",
" optimizer.zero_grad()\n",
" predictions = model(batch.text)\n",
" loss = criterion(predictions, batch.label)\n",
" acc = categorical_accuracy(predictions, batch.label)\n",
" loss.backward()\n",
" optimizer.step()\n",
" epoch_loss += loss.item()\n",
" epoch_acc += acc.item()\n",
" return epoch_loss / len(iterator), epoch_acc / len(iterator)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "TtfacWc-I4D4",
"colab_type": "code",
"colab": {}
},
"source": [
"def evaluate(model, iterator, criterion):\n",
" epoch_loss = 0\n",
" epoch_acc = 0\n",
" model.eval()\n",
" with torch.no_grad():\n",
" for batch in iterator:\n",
" predictions = model(batch.text)\n",
" loss = criterion(predictions, batch.label)\n",
" acc = categorical_accuracy(predictions, batch.label)\n",
" epoch_loss += loss.item()\n",
" epoch_acc += acc.item()\n",
" return epoch_loss / len(iterator), epoch_acc / len(iterator)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "s1gqEvvjI9OY",
"colab_type": "code",
"colab": {}
},
"source": [
"import time\n",
"\n",
"def epoch_time(start_time, end_time):\n",
" elapsed_time = end_time - start_time\n",
" elapsed_mins = int(elapsed_time / 60)\n",
" elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n",
" return elapsed_mins, elapsed_secs"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "dlTp7h7PI-yw",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 170
},
"outputId": "68c7b04c-9402-4d3f-a018-c6c31cbc5c1a"
},
"source": [
"N_EPOCHS = 18\n",
"\n",
"best_valid_loss = float('inf')\n",
"\n",
"for epoch in range(N_EPOCHS):\n",
"\n",
" start_time = time.time()\n",
" \n",
" train_loss, train_acc = train(model, train_iterator, optimizer, criterion)\n",
" valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)\n",
" \n",
" end_time = time.time()\n",
"\n",
" epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
" \n",
" if valid_loss < best_valid_loss:\n",
" best_valid_loss = valid_loss\n",
" torch.save(model.state_dict(), 'tut5-model.pt')\n",
" \n",
" print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n",
" print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n",
" print(f'\\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')"
],
"execution_count": 24,
"outputs": [
{
"output_type": "stream",
"text": [
"Epoch: 01 | Epoch Time: 0m 0s\n",
"\tTrain Loss: 0.166 | Train Acc: 96.91%\n",
"\t Val. Loss: 0.924 | Val. Acc: 76.97%\n",
"Epoch: 02 | Epoch Time: 0m 0s\n",
"\tTrain Loss: 0.138 | Train Acc: 97.78%\n",
"\t Val. Loss: 0.926 | Val. Acc: 76.14%\n",
"Epoch: 03 | Epoch Time: 0m 0s\n",
"\tTrain Loss: 0.122 | Train Acc: 97.84%\n",
"\t Val. Loss: 0.918 | Val. Acc: 76.74%\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "qV4_NbJ_JCn6",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "3d954780-057d-41e5-ffd9-eefaf50904ef"
},
"source": [
"model.load_state_dict(torch.load('tut5-model.pt'))\n",
"\n",
"test_loss, test_acc = evaluate(model, test_iterator, criterion)\n",
"\n",
"print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')"
],
"execution_count": 25,
"outputs": [
{
"output_type": "stream",
"text": [
"Test Loss: 0.888 | Test Acc: 75.78%\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "xdjKc2K6JSLl",
"colab_type": "code",
"colab": {}
},
"source": [
"import spacy\n",
"nlp = spacy.load('en')\n",
"\n",
"def predict_class(model, sentence, min_len = 4):\n",
" model.eval()\n",
" tokenized = [tok.text for tok in nlp.tokenizer(sentence)]\n",
" if len(tokenized) < min_len:\n",
" tokenized += ['<pad>'] * (min_len - len(tokenized))\n",
" indexed = [TEXT.vocab.stoi[t] for t in tokenized]\n",
" tensor = torch.LongTensor(indexed).to(device)\n",
" tensor = tensor.unsqueeze(1)\n",
" preds = model(tensor)\n",
" max_preds = preds.argmax(dim = 1)\n",
" return max_preds.item()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "JUV5rLcIJUVh",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "cd43a816-433e-4d71-ffb3-974f75f6cda1"
},
"source": [
"pred_class = predict_class(model, \"Who is Keyser Söze?\")\n",
"print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')"
],
"execution_count": 27,
"outputs": [
{
"output_type": "stream",
"text": [
"Predicted class is: 23 = HUM:desc\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "lfqNlaTkJWf0",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "b90ee9eb-4bc6-4a01-95cf-6de89bae0036"
},
"source": [
"pred_class = predict_class(model, \"How many minutes are in six hundred and eighteen hours?\")\n",
"print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')"
],
"execution_count": 28,
"outputs": [
{
"output_type": "stream",
"text": [
"Predicted class is: 3 = NUM:count\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "HNPZ59e_JbXx",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "93383363-e477-40fa-f303-b202995fd00c"
},
"source": [
"pred_class = predict_class(model, \"What continent is Bulgaria in?\")\n",
"print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')"
],
"execution_count": 29,
"outputs": [
{
"output_type": "stream",
"text": [
"Predicted class is: 1 = LOC:other\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "M3l56FQvJdLq",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "a37dd7bd-ad6e-47ff-ae9d-c0c633139256"
},
"source": [
"pred_class = predict_class(model, \"What does WYSIWYG stand for?\")\n",
"print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')"
],
"execution_count": 30,
"outputs": [
{
"output_type": "stream",
"text": [
"Predicted class is: 17 = ABBR:exp\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "swn5acMVJe_U",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "28548e80-f2e0-46b1-fc68-e8451be62070"
},
"source": [
"pred_class = predict_class(model, \"Where is New York City?\")\n",
"print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')"
],
"execution_count": 31,
"outputs": [
{
"output_type": "stream",
"text": [
"Predicted class is: 1 = LOC:other\n"
],
"name": "stdout"
}
]
}
]
}