Created
February 5, 2020 16:24
-
-
Save SharanSMenon/89137be312ad3bb2d9402ed1659a42ca to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "question-classification-cnn.ipynb", | |
"provenance": [] | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "bFGMr3CQGFJz", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import random\n", | |
"\n", | |
"import torch\n", | |
"from torchtext import data\n", | |
"from torchtext import datasets\n", | |
"\n", | |
"SEED = 1234\n", | |
"\n", | |
"# Seed every RNG used below so Restart & Run All reproduces the same run.\n", | |
"random.seed(SEED)\n", | |
"torch.manual_seed(SEED)\n", | |
"torch.backends.cudnn.deterministic = True\n", | |
"\n", | |
"# spaCy-tokenised questions; labels are the fine-grained TREC classes.\n", | |
"TEXT = data.Field(tokenize = 'spacy')\n", | |
"LABEL = data.LabelField()\n", | |
"\n", | |
"train_data, test_data = datasets.TREC.splits(TEXT, LABEL, fine_grained=True)\n", | |
"\n", | |
"# Dataset.split takes a state tuple as returned by random.getstate();\n", | |
"# random.seed() returns None, so passing its result only worked through the\n", | |
"# re-seeding side effect. Re-seed right before the split and hand over the\n", | |
"# state explicitly (same train/valid partition as seeding inline).\n", | |
"random.seed(SEED)\n", | |
"train_data, valid_data = train_data.split(random_state = random.getstate())" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "oDMpsS7oGH2V", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "c8814a4d-15fb-4b9f-a5d8-d8243096c3d8" | |
}, | |
"source": [ | |
"# Peek at the last training example: its token list and its label.\n", | |
"vars(train_data[-1])" | |
], | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{'label': 'DESC:def', 'text': ['What', 'is', 'a', 'Cartesian', 'Diver', '?']}" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 4 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "laidwNkvGRwX", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 51 | |
}, | |
"outputId": "bc61e40f-82c2-4f3a-f091-07e61f507608" | |
}, | |
"source": [ | |
"MAX_VOCAB_SIZE = 45_000\n", | |
"\n", | |
"# Vocabularies come from the training split only. Attach 100-d GloVe vectors;\n", | |
"# tokens GloVe lacks are initialised with normal_ instead of the zero default.\n", | |
"TEXT.build_vocab(\n", | |
"    train_data,\n", | |
"    max_size = MAX_VOCAB_SIZE,\n", | |
"    vectors = \"glove.6B.100d\",\n", | |
"    unk_init = torch.Tensor.normal_,\n", | |
")\n", | |
"\n", | |
"LABEL.build_vocab(train_data)" | |
], | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
".vector_cache/glove.6B.zip: 862MB [06:32, 2.20MB/s] \n", | |
"100%|█████████▉| 398009/400000 [00:16<00:00, 24050.78it/s]" | |
], | |
"name": "stderr" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "07mXbobyGghA", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 54 | |
}, | |
"outputId": "e86b6e4b-a093-4015-cd66-0915b7e75bfd" | |
}, | |
"source": [ | |
"# Label-to-index mapping built from the training split.\n", | |
"print(LABEL.vocab.stoi)" | |
], | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"defaultdict(<function _default_unk_index at 0x7f7fb0960158>, {'HUM:ind': 0, 'LOC:other': 1, 'DESC:def': 2, 'NUM:count': 3, 'DESC:manner': 4, 'DESC:desc': 5, 'NUM:date': 6, 'DESC:reason': 7, 'HUM:gr': 8, 'ENTY:other': 9, 'ENTY:cremat': 10, 'LOC:country': 11, 'LOC:city': 12, 'ENTY:animal': 13, 'ENTY:dismed': 14, 'ENTY:food': 15, 'ENTY:termeq': 16, 'ABBR:exp': 17, 'NUM:money': 18, 'NUM:period': 19, 'LOC:state': 20, 'ENTY:event': 21, 'ENTY:sport': 22, 'HUM:desc': 23, 'NUM:other': 24, 'ENTY:product': 25, 'ENTY:color': 26, 'ENTY:techmeth': 27, 'ENTY:substance': 28, 'ENTY:word': 29, 'ENTY:veh': 30, 'NUM:dist': 31, 'HUM:title': 32, 'NUM:perc': 33, 'LOC:mount': 34, 'ENTY:body': 35, 'ABBR:abb': 36, 'ENTY:lang': 37, 'ENTY:instru': 38, 'ENTY:plant': 39, 'NUM:code': 40, 'NUM:temp': 41, 'NUM:volsize': 42, 'NUM:weight': 43, 'ENTY:letter': 44, 'ENTY:symbol': 45, 'ENTY:religion': 46, 'NUM:ord': 47, 'NUM:speed': 48, 'ENTY:currency': 49})\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "gKQbNTskIQGp", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"BATCH_SIZE = 64\n", | |
"\n", | |
"# Use the GPU when present; the iterators place every batch on `device`.\n", | |
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n", | |
"\n", | |
"# BucketIterator batches examples of similar length together to limit padding.\n", | |
"train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(\n", | |
"    (train_data, valid_data, test_data),\n", | |
"    batch_size = BATCH_SIZE,\n", | |
"    device = device,\n", | |
")" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "PjCKXbh5IW-D", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "a41eb153-a353-494d-b2f3-e4133b6b4c77" | |
}, | |
"source": [ | |
"# Which device training and inference will run on.\n", | |
"device" | |
], | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"device(type='cuda')" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 8 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "FjdHxNeGIX3E", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ZqBeymR9IaKS", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"class CNN(nn.Module):\n", | |
"    \"\"\"Convolutional sentence classifier with several parallel filter widths.\n", | |
"\n", | |
"    One Conv2d per entry of `filter_sizes` spans `fs` tokens and the full\n", | |
"    embedding width. Each ReLU feature map is max-pooled over time, the pooled\n", | |
"    vectors are concatenated, dropout is applied and a linear layer produces\n", | |
"    `output_dim` logits.\n", | |
"\n", | |
"    Args:\n", | |
"        vocab_size: number of rows in the embedding table.\n", | |
"        embedding_dim: size of each word vector.\n", | |
"        n_filters: output channels per filter width.\n", | |
"        filter_sizes: convolution heights in tokens; every input needs at\n", | |
"            least max(filter_sizes) tokens or the convolution fails.\n", | |
"        output_dim: number of classes (logits per example).\n", | |
"        dropout: dropout probability on the concatenated pooled features.\n", | |
"        pad_idx: vocabulary index of the padding token; its embedding row gets\n", | |
"            no gradient, so it keeps whatever value it is given.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim,\n", | |
"                 dropout, pad_idx):\n", | |
"        super().__init__()\n", | |
"        # Forward pad_idx as padding_idx so the <pad> vector is excluded from\n", | |
"        # gradient updates instead of being trained like a real word.\n", | |
"        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)\n", | |
"        self.convs = nn.ModuleList([\n", | |
"            nn.Conv2d(in_channels = 1,\n", | |
"                      out_channels = n_filters,\n", | |
"                      kernel_size = (fs, embedding_dim))\n", | |
"            for fs in filter_sizes\n", | |
"        ])\n", | |
"        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)\n", | |
"        self.dropout = nn.Dropout(dropout)\n", | |
"\n", | |
"    def forward(self, text):\n", | |
"        \"\"\"Map token ids [seq_len, batch] to class logits [batch, output_dim].\"\"\"\n", | |
"        # [seq_len, batch] -> [batch, seq_len]\n", | |
"        text = text.permute(1, 0)\n", | |
"        # [batch, seq_len, emb] -> [batch, 1, seq_len, emb]: a single input channel\n", | |
"        embedded = self.embedding(text).unsqueeze(1)\n", | |
"        # Kernels span the whole embedding, so the last dim is 1 and is squeezed;\n", | |
"        # each entry is [batch, n_filters, seq_len - fs + 1]\n", | |
"        conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]\n", | |
"        # Max over the entire time axis -> [batch, n_filters] per filter width\n", | |
"        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]\n", | |
"        cat = self.dropout(torch.cat(pooled, dim = 1))\n", | |
"        return self.fc(cat)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "YLUvwoEmIkZ4", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Sizes derived from the data\n", | |
"INPUT_DIM = len(TEXT.vocab)\n", | |
"OUTPUT_DIM = len(LABEL.vocab)\n", | |
"PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]\n", | |
"\n", | |
"# Model hyperparameters\n", | |
"EMBEDDING_DIM = 100        # must equal the GloVe vector size (glove.6B.100d)\n", | |
"N_FILTERS = 100            # feature maps per filter width\n", | |
"FILTER_SIZES = [2, 3, 4]   # widths in tokens; inputs need >= max(FILTER_SIZES) tokens\n", | |
"DROPOUT = 0.5\n", | |
"\n", | |
"model = CNN(\n", | |
"    INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES,\n", | |
"    OUTPUT_DIM, DROPOUT, PAD_IDX,\n", | |
")" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "5Asdh-TZImgX", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 187 | |
}, | |
"outputId": "0cb48953-4a8c-4051-8058-ffec3d6e4fdf" | |
}, | |
"source": [ | |
"# Layer summary: the fc input width should be len(FILTER_SIZES) * N_FILTERS.\n", | |
"print(model)" | |
], | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"CNN(\n", | |
" (embedding): Embedding(7503, 100)\n", | |
" (convs): ModuleList(\n", | |
" (0): Conv2d(1, 100, kernel_size=(2, 100), stride=(1, 1))\n", | |
" (1): Conv2d(1, 100, kernel_size=(3, 100), stride=(1, 1))\n", | |
" (2): Conv2d(1, 100, kernel_size=(4, 100), stride=(1, 1))\n", | |
" )\n", | |
" (fc): Linear(in_features=300, out_features=50, bias=True)\n", | |
" (dropout): Dropout(p=0.5, inplace=False)\n", | |
")\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "A6YCAJKmIn5i", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "75060a64-db33-4a6f-ddda-7f4e0c1520e0" | |
}, | |
"source": [ | |
"def count_parameters(model):\n", | |
"    \"\"\"Return how many scalar parameters of `model` have requires_grad set.\"\"\"\n", | |
"    trainable = (p for p in model.parameters() if p.requires_grad)\n", | |
"    return sum(map(torch.numel, trainable))\n", | |
"\n", | |
"print(f'The model has {count_parameters(model):,} trainable parameters')" | |
], | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"The model has 855,650 trainable parameters\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Q3OmgQkzIpuq", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 136 | |
}, | |
"outputId": "0a84a28c-4188-47e1-83ff-1e89de0c4133" | |
}, | |
"source": [ | |
"pretrained_embeddings = TEXT.vocab.vectors\n", | |
"assert pretrained_embeddings.shape == model.embedding.weight.shape, (\n", | |
"    'EMBEDDING_DIM / INPUT_DIM must match the GloVe vectors attached to TEXT.vocab')\n", | |
"\n", | |
"# Copy under no_grad rather than through `.data`, which hides the write from\n", | |
"# autograd; as a statement it also no longer dumps the whole tensor as output.\n", | |
"with torch.no_grad():\n", | |
"    model.embedding.weight.copy_(pretrained_embeddings)" | |
], | |
"execution_count": 14, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[-0.1117, -0.4966, 0.1631, ..., 1.2647, -0.2753, -0.1325],\n", | |
" [-0.8555, -0.7208, 1.3755, ..., 0.0825, -1.1314, 0.3997],\n", | |
" [ 0.1638, 0.6046, 1.0789, ..., -0.3140, 0.1844, 0.3624],\n", | |
" ...,\n", | |
" [-0.3110, -0.3398, 1.0308, ..., 0.5317, 0.2836, -0.0640],\n", | |
" [ 0.0091, 0.2810, 0.7356, ..., -0.7508, 0.8967, -0.7631],\n", | |
" [ 0.4306, 1.2011, 0.0873, ..., 0.8817, 0.3722, 0.3458]])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 14 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ZmzT0-vNIrga", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]\n", | |
"\n", | |
"# <unk> and <pad> are not GloVe words, so their rows hold no pretrained\n", | |
"# vector; reset both to zeros. Writing under no_grad (instead of via .data)\n", | |
"# keeps the in-place update visible to autograd's version tracking, and the\n", | |
"# scalar fill works whatever device the embedding currently lives on.\n", | |
"with torch.no_grad():\n", | |
"    model.embedding.weight[UNK_IDX] = 0\n", | |
"    model.embedding.weight[PAD_IDX] = 0" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "miLY5rvXIs_s", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import torch.optim as optim\n", | |
"\n", | |
"# Move the model to `device` before building its optimizer, as the torch.optim\n", | |
"# docs recommend, so the optimizer is created over the final parameter tensors.\n", | |
"model = model.to(device)\n", | |
"criterion = nn.CrossEntropyLoss().to(device)\n", | |
"\n", | |
"# Adam with PyTorch's default hyperparameters.\n", | |
"optimizer = optim.Adam(model.parameters())" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "dC5vB7DYIuyV", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def categorical_accuracy(preds, y):\n", | |
"    \"\"\"Return the fraction of rows of `preds` whose argmax equals `y`.\n", | |
"\n", | |
"    Args:\n", | |
"        preds: [batch, n_classes] scores or logits.\n", | |
"        y: [batch] integer class indices on the same device as `preds`.\n", | |
"\n", | |
"    Returns:\n", | |
"        0-dim float tensor on the device of `preds`; call .item() for a float.\n", | |
"    \"\"\"\n", | |
"    correct = preds.argmax(dim = 1).eq(y)\n", | |
"    # Averaging on-device avoids dividing a CUDA tensor by a freshly built CPU\n", | |
"    # FloatTensor, which can raise a device-mismatch error on newer PyTorch.\n", | |
"    return correct.float().mean()" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "r_ING1JEIyIx", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def train(model, iterator, optimizer, criterion):\n", | |
"    \"\"\"Run one training epoch over `iterator`.\n", | |
"\n", | |
"    Returns (mean loss, mean accuracy), both averaged per batch rather than\n", | |
"    per example, so a short final batch weighs as much as a full one.\n", | |
"    \"\"\"\n", | |
"    model.train()\n", | |
"    total_loss, total_acc = 0, 0\n", | |
"    for batch in iterator:\n", | |
"        optimizer.zero_grad()\n", | |
"        logits = model(batch.text)\n", | |
"        batch_loss = criterion(logits, batch.label)\n", | |
"        batch_acc = categorical_accuracy(logits, batch.label)\n", | |
"        batch_loss.backward()\n", | |
"        optimizer.step()\n", | |
"        total_loss += batch_loss.item()\n", | |
"        total_acc += batch_acc.item()\n", | |
"    n_batches = len(iterator)\n", | |
"    return total_loss / n_batches, total_acc / n_batches" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "TtfacWc-I4D4", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def evaluate(model, iterator, criterion):\n", | |
"    \"\"\"Score `model` on `iterator` in eval mode without tracking gradients.\n", | |
"\n", | |
"    Returns (mean loss, mean accuracy), both averaged per batch.\n", | |
"    \"\"\"\n", | |
"    model.eval()\n", | |
"    batch_losses, batch_accs = [], []\n", | |
"    with torch.no_grad():\n", | |
"        for batch in iterator:\n", | |
"            logits = model(batch.text)\n", | |
"            batch_losses.append(criterion(logits, batch.label).item())\n", | |
"            batch_accs.append(categorical_accuracy(logits, batch.label).item())\n", | |
"    return sum(batch_losses) / len(iterator), sum(batch_accs) / len(iterator)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "s1gqEvvjI9OY", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import time\n", | |
"\n", | |
"def epoch_time(start_time, end_time):\n", | |
"    \"\"\"Split the seconds between two time.time() stamps into whole (minutes, seconds).\"\"\"\n", | |
"    seconds_total = end_time - start_time\n", | |
"    minutes = int(seconds_total / 60)\n", | |
"    return minutes, int(seconds_total - minutes * 60)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "dlTp7h7PI-yw", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 170 | |
}, | |
"outputId": "68c7b04c-9402-4d3f-a018-c6c31cbc5c1a" | |
}, | |
"source": [ | |
"# Train for N_EPOCHS epochs on top of the model's current weights.\n", | |
"# NOTE(review): re-running this cell keeps training the same `model` and\n", | |
"# `optimizer`; re-run the model/optimizer cells first for a fresh run.\n", | |
"N_EPOCHS = 18\n", | |
"\n", | |
"# Lowest validation loss seen so far (the checkpointing threshold).\n", | |
"best_valid_loss = float('inf')\n", | |
"\n", | |
"for epoch in range(N_EPOCHS):\n", | |
"\n", | |
"    start_time = time.time()\n", | |
"    \n", | |
"    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)\n", | |
"    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)\n", | |
"    \n", | |
"    end_time = time.time()\n", | |
"\n", | |
"    epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n", | |
"    \n", | |
"    # Checkpoint only on improvement, so 'tut5-model.pt' holds the\n", | |
"    # best-validation weights rather than the final epoch's.\n", | |
"    if valid_loss < best_valid_loss:\n", | |
"        best_valid_loss = valid_loss\n", | |
"        torch.save(model.state_dict(), 'tut5-model.pt')\n", | |
"    \n", | |
"    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n", | |
"    print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n", | |
"    print(f'\\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')" | |
], | |
"execution_count": 24, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Epoch: 01 | Epoch Time: 0m 0s\n", | |
"\tTrain Loss: 0.166 | Train Acc: 96.91%\n", | |
"\t Val. Loss: 0.924 | Val. Acc: 76.97%\n", | |
"Epoch: 02 | Epoch Time: 0m 0s\n", | |
"\tTrain Loss: 0.138 | Train Acc: 97.78%\n", | |
"\t Val. Loss: 0.926 | Val. Acc: 76.14%\n", | |
"Epoch: 03 | Epoch Time: 0m 0s\n", | |
"\tTrain Loss: 0.122 | Train Acc: 97.84%\n", | |
"\t Val. Loss: 0.918 | Val. Acc: 76.74%\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "qV4_NbJ_JCn6", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "3d954780-057d-41e5-ffd9-eefaf50904ef" | |
}, | |
"source": [ | |
"# Restore the best-validation checkpoint saved during training, then score the\n", | |
"# held-out test split once. map_location keeps the load working when the\n", | |
"# checkpoint was written on a different device than the current one.\n", | |
"model.load_state_dict(torch.load('tut5-model.pt', map_location = device))\n", | |
"\n", | |
"test_loss, test_acc = evaluate(model, test_iterator, criterion)\n", | |
"\n", | |
"print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')" | |
], | |
"execution_count": 25, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Test Loss: 0.888 | Test Acc: 75.78%\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "xdjKc2K6JSLl", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import spacy\n", | |
"nlp = spacy.load('en')\n", | |
"\n", | |
"def predict_class(model, sentence, min_len = 4):\n", | |
"    \"\"\"Return the predicted label index for one raw `sentence`.\n", | |
"\n", | |
"    The sentence is tokenised with spaCy and padded with TEXT.pad_token up to\n", | |
"    `min_len` tokens (the default 4 matches the widest filter in\n", | |
"    FILTER_SIZES). It is then numericalised with TEXT.vocab and scored as a\n", | |
"    batch of one. Map the result through LABEL.vocab.itos to get the label string.\n", | |
"    \"\"\"\n", | |
"    model.eval()\n", | |
"    tokenized = [tok.text for tok in nlp.tokenizer(sentence)]\n", | |
"    if len(tokenized) < min_len:\n", | |
"        tokenized += [TEXT.pad_token] * (min_len - len(tokenized))\n", | |
"    indexed = [TEXT.vocab.stoi[t] for t in tokenized]\n", | |
"    # [seq_len] -> [seq_len, 1]: a batch of one in the model's sequence-first layout\n", | |
"    tensor = torch.LongTensor(indexed).to(device).unsqueeze(1)\n", | |
"    # Inference only: no autograd graph is needed.\n", | |
"    with torch.no_grad():\n", | |
"        preds = model(tensor)\n", | |
"    return preds.argmax(dim = 1).item()" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "JUV5rLcIJUVh", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "cd43a816-433e-4d71-ffb3-974f75f6cda1" | |
}, | |
"source": [ | |
"# Sample prediction for a 'who' question; the index is mapped to its label via LABEL.vocab.itos.\n", | |
"pred_class = predict_class(model, \"Who is Keyser Söze?\")\n", | |
"print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')" | |
], | |
"execution_count": 27, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Predicted class is: 23 = HUM:desc\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "lfqNlaTkJWf0", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "b90ee9eb-4bc6-4a01-95cf-6de89bae0036" | |
}, | |
"source": [ | |
"# Sample prediction for a 'how many' question.\n", | |
"pred_class = predict_class(model, \"How many minutes are in six hundred and eighteen hours?\")\n", | |
"print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')" | |
], | |
"execution_count": 28, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Predicted class is: 3 = NUM:count\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "HNPZ59e_JbXx", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "93383363-e477-40fa-f303-b202995fd00c" | |
}, | |
"source": [ | |
"# Sample prediction for a location question.\n", | |
"pred_class = predict_class(model, \"What continent is Bulgaria in?\")\n", | |
"print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')" | |
], | |
"execution_count": 29, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Predicted class is: 1 = LOC:other\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "M3l56FQvJdLq", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "a37dd7bd-ad6e-47ff-ae9d-c0c633139256" | |
}, | |
"source": [ | |
"# Sample prediction for an abbreviation question.\n", | |
"pred_class = predict_class(model, \"What does WYSIWYG stand for?\")\n", | |
"print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')" | |
], | |
"execution_count": 30, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Predicted class is: 17 = ABBR:exp\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "swn5acMVJe_U", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "28548e80-f2e0-46b1-fc68-e8451be62070" | |
}, | |
"source": [ | |
"# Sample prediction for a 'where' question.\n", | |
"pred_class = predict_class(model, \"Where is New York City?\")\n", | |
"print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')" | |
], | |
"execution_count": 31, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Predicted class is: 1 = LOC:other\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Gm65tDkyJj7Q", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment