sentimix(sent).ipynb
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "sentimix(sent).ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "mount_file_id": "1O1rfo_liszIt7f1ZxITxsNCVLAkwb0u0",
      "authorship_tag": "ABX9TyPyTE9aPX+DSfk/Dk+2sExw",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/gist/ajason08/f5028a6b929c4fef6596f6e3aff5d69d/sentimix-sent.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "StdpsNK5Ynv_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "from torchtext import data\n",
        "import random\n",
        "\n",
        "# fix the RNG seeds so the dataset split and weight init are reproducible\n",
        "SEED = 1234\n",
        "torch.manual_seed(SEED)\n",
        "torch.backends.cudnn.deterministic = True"
      ],
      "execution_count": 0,
      "outputs": []
    },
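    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The data paths below live under `/content/drive`, which the notebook's `mount_file_id` normally auto-mounts. In case it is not mounted, a minimal sketch (the mount point is the standard Colab one):"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Sketch: mount Google Drive so the dataset paths below resolve.\n",
        "# Skip if Colab already auto-mounted the drive for this notebook.\n",
        "try:\n",
        "    from google.colab import drive\n",
        "    drive.mount('/content/drive')\n",
        "except ImportError:\n",
        "    print('Not on Colab; point mypath at a local copy of the data instead.')"
      ],
      "execution_count": 0,
      "outputs": []
    },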
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AumXjOCoiChv",
        "colab_type": "text"
      },
      "source": [
        "**Data loading**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uAS0N5CmoQzK",
        "colab_type": "code",
        "outputId": "92b16726-e56f-4ed8-a545-c9a39061c21e",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 93
        }
      },
      "source": [
        "ls \"/content/drive/My Drive/aDrive08/Mind/1_Objectives/sentimix/sentimix_colab/\""
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "answer.txt sentimix_ek_test.tsv sentimix_no-ek_dv.csv\n",
            "\u001b[0m\u001b[01;34mkfold\u001b[0m/ sentimix_ek_train_dev.csv sentimix_no-ek_tr.csv\n",
            "sentimix20april.ipynb sentimix_ek_tr.tsv\n",
            "sentimix_ek_dv.tsv Sentimix.ipynb\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QqlLV5kshTuy",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "mypath = \"/content/drive/My Drive/aDrive08/Mind/1_Objectives/sentimix_colab/\"\n",
        "\n",
        "myfilenamet = \"sentimix_no-ek_tr.csv\"  # \"sentimix_ek_tr.tsv\" # \"sentimix_ek_train_dev.csv\"\n",
        "myfilenamed = \"sentimix_no-ek_dv.csv\"  # \"sentimix_ek_dv.tsv\" # \"sentimix_ek_test.tsv\"\n",
        "\n",
        "\n",
        "# input in torchtext\n",
        "TEXT = data.Field(tokenize = 'spacy', lower = True)\n",
        "LABEL = data.LabelField()\n",
        "\n",
        "# the .csv files are tab-separated (they are read with sep=\"\\t\" later), hence format='tsv'\n",
        "train_data, dev_data = data.TabularDataset.splits(\n",
        "    path = mypath, train = myfilenamet, test = myfilenamed, format = 'tsv',\n",
        "    fields = [('label', LABEL),\n",
        "              ('tweet', TEXT)],\n",
        "    skip_header = True)\n",
        "# stratified split preserves the class distribution; it stratifies on the 'label' field by default\n",
        "train_data, valid_data = train_data.split(stratified = True, random_state = random.seed(SEED))"
      ],
      "execution_count": 0,
      "outputs": []
    },
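    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A one-line check that the two columns really map to `('label', 'tweet')` (the commented-out exploration cell below does this more thoroughly):"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Sketch: confirm the field mapping on one parsed example.\n",
        "print(vars(train_data[0]))"
      ],
      "execution_count": 0,
      "outputs": []
    },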
    {
      "cell_type": "code",
      "metadata": {
        "id": "CHhNoOAWsqH5",
        "colab_type": "code",
        "outputId": "4b11c6c8-1d91-47e2-ef2d-624d7ce1b56a",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 260
        }
      },
      "source": [
        "# # from https://www.spinningbytes.com/resources/wordembeddings\n",
        "# !curl -o myembed http://4530.hostserv.eu/resources/embed_tweets_es_200M_200D.zip\n",
        "# !unzip myembed\n",
        "\n",
        "from torchtext.vocab import Vectors as myCustomEmbeddings\n",
        "\n",
        "vec = myCustomEmbeddings('/content/es_embeddings_200M_200d/embedding_file')\n",
        "MAX_VOCAB_SIZE = 15_000\n",
        "\n",
        "TEXT.build_vocab(train_data,\n",
        "                 max_size = MAX_VOCAB_SIZE,\n",
        "                 vectors = vec,\n",
        "                 unk_init = torch.Tensor.normal_)\n",
        "LABEL.build_vocab(train_data)\n"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            " % Total % Received % Xferd Average Speed Time Time Time Current\n",
            " Dload Upload Total Spent Left Speed\n",
            "100 891M 100 891M 0 0 8400k 0 0:01:48 0:01:48 --:--:-- 8599k\n",
            "Archive: myembed\n",
            " creating: es_embeddings_200M_200d/\n",
            " inflating: es_embeddings_200M_200d/vocabulary.pickle \n",
            " inflating: es_embeddings_200M_200d/1gram \n",
            " inflating: es_embeddings_200M_200d/embedding_matrix.npy \n",
            " inflating: es_embeddings_200M_200d/0gram \n",
            " inflating: es_embeddings_200M_200d/config.json \n",
            " inflating: es_embeddings_200M_200d/embedding_file \n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            " 0%| | 0/201382 [00:00<?, ?it/s]Skipping token b'201382' with 1-dimensional vector [b'200']; likely a header\n",
            "100%|█████████▉| 200636/201382 [00:16<00:00, 12013.90it/s]"
          ],
          "name": "stderr"
        }
      ]
    },
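    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "How much of the 15k vocabulary actually receives a pretrained vector? The rest is initialized by `unk_init`. A small sketch, assuming the torchtext `Vectors` object exposes its token index as `vec.stoi`:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Sketch: embedding coverage of the built vocabulary.\n",
        "covered = sum(1 for tok in TEXT.vocab.itos if tok in vec.stoi)\n",
        "print(f'{covered}/{len(TEXT.vocab)} tokens covered '\n",
        "      f'({100 * covered / len(TEXT.vocab):.1f}%)')"
      ],
      "execution_count": 0,
      "outputs": []
    },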
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Nw5bhDpFbA34",
        "colab_type": "text"
      },
      "source": [
        "**Data Exploration**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "S1pE5OhGiiqT",
        "colab_type": "code",
        "outputId": "a0a2b475-49de-4fad-a8f7-5f3ab753ac8d",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 504
        }
      },
      "source": [
        "# # CSV level\n",
        "# print(\"=================================== At CSV level:\")\n",
        "# import pandas as pd\n",
        "# df = pd.read_csv(mypath + myfilenamet, sep=\"\\t\")\n",
        "# print(df.head(5))\n",
        "\n",
        "\n",
        "# # torch dataset level\n",
        "# print(\"\\n=================================== At Torch dataset level:\")\n",
        "# print(\"First labels in torch_field: \")\n",
        "# for i, j in enumerate(valid_data.label):\n",
        "#     print(i, j)\n",
        "#     if i == 4: break\n",
        "\n",
        "# print(f'Dataset lens: \\n\\t train: {len(train_data)},\\\n",
        "# \\n\\t val: {len(valid_data)},\\\n",
        "# \\n\\t dev: {len(dev_data)},\\\n",
        "# ')\n",
        "# print(f'Last sample in train data \\n\\t{vars(train_data[-1])}')\n",
        "\n",
        "# # vocab level\n",
        "# print(\"\\n=================================== At Torch vocab level:\")\n",
        "# print(f'Vocab: {len(TEXT.vocab)}\\n label map:{LABEL.vocab.stoi}')"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "=================================== At CSV level:\n",
            " Col2 Col0\n",
            "0 positive So that means tomorrow cruda segura lol\n",
            "1 positive Un MUST HAVE en esta navidad son estos aretes ...\n",
            "2 positive RT @Univision : ¡ Ya llego @bombaestereo ! @Ve...\n",
            "3 negative @javiidiaz06 No puedo mano , he tratado con hi...\n",
            "4 neutral @mirandaalexa17 hahahahahaha vpc 😂 a mi hasta ...\n",
            "\n",
            "=================================== At Torch dataset level:\n",
            "First labels in torch_field: \n",
            "0 negative\n",
            "1 negative\n",
            "2 negative\n",
            "3 negative\n",
            "4 negative\n",
            "Dataset lens: \n",
            "\t train: 8402, \n",
            "\t val: 3600, \n",
            "\t dev: 2998, \n",
            "Last sample in train data \n",
            "\t{'label': 'positive', 'tweet': ['dios', 'mio', ',', 'mi', 'dentista', 'es', 'una', 'diosa', 'bajada', 'del', 'olimpo', '.', '#', 'blessedher']}\n",
            "\n",
            "=================================== At Torch vocab level:\n",
            "Vocab: 15002\n",
            " label map:defaultdict(<function _default_unk_index at 0x7f92294c8ae8>, {'positive': 0, 'neutral': 1, 'negative': 2})\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GlvHbIqpi7N8",
        "colab_type": "text"
      },
      "source": [
        "**My model**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sWii9vO7hmgR",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "\n",
        "class CNN(nn.Module):\n",
        "    def __init__(self, vocab_size, output_dim, embedding_dim, pad_idx,\n",
        "                 n_filters, filter_sizes, dropout):\n",
        "\n",
        "        super().__init__()\n",
        "\n",
        "        # padding_idx keeps the <pad> embedding at zero during training\n",
        "        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)\n",
        "\n",
        "        self.convs = nn.ModuleList([\n",
        "            nn.Conv2d(in_channels = 1,\n",
        "                      out_channels = n_filters,\n",
        "                      kernel_size = (fs, embedding_dim))\n",
        "            for fs in filter_sizes\n",
        "        ])\n",
        "\n",
        "        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)\n",
        "\n",
        "        self.dropout = nn.Dropout(dropout)\n",
        "\n",
        "    def forward(self, text):\n",
        "\n",
        "        #text = [sent len, batch size]\n",
        "\n",
        "        text = text.permute(1, 0)\n",
        "\n",
        "        #text = [batch size, sent len]\n",
        "\n",
        "        embedded = self.embedding(text)\n",
        "\n",
        "        #embedded = [batch size, sent len, emb dim]\n",
        "\n",
        "        embedded = embedded.unsqueeze(1)\n",
        "\n",
        "        #embedded = [batch size, 1, sent len, emb dim]\n",
        "\n",
        "        conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]\n",
        "\n",
        "        #conved_n = [batch size, n_filters, sent len - filter_sizes[n] + 1]\n",
        "\n",
        "        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]\n",
        "\n",
        "        #pooled_n = [batch size, n_filters]\n",
        "\n",
        "        cat = self.dropout(torch.cat(pooled, dim = 1))\n",
        "\n",
        "        #cat = [batch size, n_filters * len(filter_sizes)]\n",
        "\n",
        "        return self.fc(cat)"
      ],
      "execution_count": 0,
      "outputs": []
    },
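    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A quick shape sanity check for the class above, with throwaway hyperparameters (every number here is arbitrary): a `[sent len, batch size]` batch should come out as `[batch size, output_dim]`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Sketch: exercise CNN.forward with dummy data (arbitrary sizes).\n",
        "tiny = CNN(vocab_size = 100, output_dim = 3, embedding_dim = 8, pad_idx = 1,\n",
        "           n_filters = 4, filter_sizes = [2, 3], dropout = 0.5)\n",
        "dummy = torch.randint(0, 100, (20, 5))  # [sent len = 20, batch size = 5]\n",
        "print(tiny(dummy).shape)  # expected: torch.Size([5, 3])"
      ],
      "execution_count": 0,
      "outputs": []
    },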
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7rh-ex2NmToZ",
        "colab_type": "text"
      },
      "source": [
        "**Training and prediction:** Strategy"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wHDCR3nzqJ6D",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import time\n",
        "def epoch_time(start_time, end_time):\n",
        "    elapsed_time = end_time - start_time\n",
        "    elapsed_mins = int(elapsed_time / 60)\n",
        "    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n",
        "    return elapsed_mins, elapsed_secs\n",
        "\n",
        "def categorical_accuracy(preds, y):\n",
        "    \"\"\"\n",
        "    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8\n",
        "    \"\"\"\n",
        "    max_preds = preds.argmax(dim = 1, keepdim = True)  # get the index of the max probability\n",
        "    correct = max_preds.squeeze(1).eq(y)\n",
        "    return correct.sum() / torch.FloatTensor([y.shape[0]])\n",
        "\n",
        "def train(model, iterator, optimizer, criterion):\n",
        "    epoch_loss = 0\n",
        "    epoch_acc = 0\n",
        "    model.train()\n",
        "\n",
        "    # forward pass, then backpropagate to update the weights\n",
        "    for batch in iterator:\n",
        "        optimizer.zero_grad()\n",
        "\n",
        "        predictions = model(batch.tweet)\n",
        "        loss = criterion(predictions, batch.label)\n",
        "        acc = categorical_accuracy(predictions, batch.label)\n",
        "\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        epoch_loss += loss.item()\n",
        "        epoch_acc += acc.item()\n",
        "    return epoch_loss / len(iterator), epoch_acc / len(iterator)\n",
        "\n",
        "def evaluate(model, iterator, criterion):\n",
        "    epoch_loss = 0\n",
        "    epoch_acc = 0\n",
        "    model.eval()\n",
        "\n",
        "    with torch.no_grad():  # predict without updating weights\n",
        "        for batch in iterator:\n",
        "\n",
        "            predictions = model(batch.tweet)\n",
        "            loss = criterion(predictions, batch.label)\n",
        "            acc = categorical_accuracy(predictions, batch.label)\n",
        "\n",
        "            epoch_loss += loss.item()\n",
        "            epoch_acc += acc.item()\n",
        "\n",
        "    return epoch_loss / len(iterator), epoch_acc / len(iterator)\n",
        "\n",
        "\n",
        "def predict_testset(model, iterator):\n",
        "    model.eval()\n",
        "    mypredictions = []\n",
        "    correct_answers = []\n",
        "    with torch.no_grad():  # predict without updating weights\n",
        "        for batch in iterator:\n",
        "            predictions = model(batch.tweet)\n",
        "            max_preds = predictions.argmax(dim = 1)\n",
        "            mypredictions.append(max_preds)\n",
        "            correct_answers.append(batch.label)\n",
        "    mypredictions = torch.cat(mypredictions).tolist()\n",
        "    correct_answers = torch.cat(correct_answers).tolist()\n",
        "    return mypredictions, correct_answers\n",
        "\n",
        "\n",
        "import spacy\n",
        "nlp = spacy.load('en')  # only the English model's tokenizer is used below\n",
        "\n",
        "def predict_custom_example(model, sentence, min_len = 4):\n",
        "    model.eval()\n",
        "    tokenized = [tok.text for tok in nlp.tokenizer(sentence)]\n",
        "    if len(tokenized) < min_len:\n",
        "        # pad short inputs so the widest convolution filter (size 4) fits\n",
        "        tokenized += ['<pad>'] * (min_len - len(tokenized))\n",
        "    indexed = [TEXT.vocab.stoi[t] for t in tokenized]\n",
        "    tensor = torch.LongTensor(indexed).to(device)\n",
        "    tensor = tensor.unsqueeze(1)\n",
        "    preds = model(tensor)\n",
        "    max_preds = preds.argmax(dim = 1)\n",
        "    return max_preds.item()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ebVeMuOzXPgZ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# iterator\n",
        "BATCH_SIZE = 64\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "\n",
        "train_iterator, valid_iterator, dev_iterator = data.BucketIterator.splits(\n",
        "    (train_data, valid_data, dev_data),\n",
        "    sort = False,\n",
        "    batch_size = BATCH_SIZE,\n",
        "    device = device)\n",
        "\n",
        "# model main arguments\n",
        "INPUT_DIM = len(TEXT.vocab)\n",
        "OUTPUT_DIM = len(LABEL.vocab)\n",
        "EMBEDDING_DIM = len(TEXT.vocab.vectors[0])\n",
        "PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]\n",
        "\n",
        "# model hyperparameters\n",
        "N_FILTERS = 100\n",
        "FILTER_SIZES = [2, 3, 4]\n",
        "DROPOUT = 0.5\n",
        "\n",
        "model = CNN(INPUT_DIM, OUTPUT_DIM, EMBEDDING_DIM, PAD_IDX, N_FILTERS, FILTER_SIZES, DROPOUT)\n",
        "\n",
        "# initial weights\n",
        "UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]\n",
        "model.embedding.weight.data.copy_(TEXT.vocab.vectors)\n",
        "model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)\n",
        "model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)\n",
        "\n",
        "# optimizer and criterion\n",
        "import torch.optim as optim\n",
        "\n",
        "optimizer = optim.Adam(model.parameters())\n",
        "criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "model = model.to(device)\n",
        "criterion = criterion.to(device)"
      ],
      "execution_count": 0,
      "outputs": []
    },
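    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before training, a small helper to report model size (the bulk of the parameters sits in the 15,002 x 200 embedding table):"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Sketch: count the trainable parameters of the instantiated model.\n",
        "def count_parameters(model):\n",
        "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
        "\n",
        "print(f'The model has {count_parameters(model):,} trainable parameters')"
      ],
      "execution_count": 0,
      "outputs": []
    },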
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L2xcbjr6sA5u",
        "colab_type": "text"
      },
      "source": [
        "**Training and prediction:** Execution"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "A0Sdags8YwKa",
        "colab_type": "code",
        "outputId": "89b97c26-db91-4354-8bb2-f0a930164a3a",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 185
        }
      },
      "source": [
        "N_EPOCHS = 5\n",
        "best_valid_loss = float('inf')\n",
        "\n",
        "for epoch in range(N_EPOCHS):\n",
        "    start_time = time.time()\n",
        "\n",
        "    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)\n",
        "    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)\n",
        "\n",
        "    end_time = time.time()\n",
        "    epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
        "\n",
        "    # checkpoint the model whenever validation loss improves\n",
        "    if valid_loss < best_valid_loss:\n",
        "        best_valid_loss = valid_loss\n",
        "        torch.save(model.state_dict(), 'tut5-model.pt')\n",
        "\n",
        "    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n",
        "    print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n",
        "    print(f'\\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "\r100%|█████████▉| 200636/201382 [00:30<00:00, 12013.90it/s]"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "Epoch: 01 | Epoch Time: 0m 23s\n",
            "\tTrain Loss: 0.989 | Train Acc: 50.34%\n",
            "\t Val. Loss: 0.944 | Val. Acc: 52.03%\n",
            "Epoch: 02 | Epoch Time: 0m 22s\n",
            "\tTrain Loss: 0.879 | Train Acc: 57.14%\n",
            "\t Val. Loss: 0.916 | Val. Acc: 54.28%\n",
            "Epoch: 03 | Epoch Time: 0m 22s\n",
            "\tTrain Loss: 0.724 | Train Acc: 69.19%\n",
            "\t Val. Loss: 0.966 | Val. Acc: 52.58%\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IcksuDj6ms0U",
        "colab_type": "text"
      },
      "source": [
        "**Initial predictions**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zNBaPRpZvfgA",
        "colab_type": "code",
        "outputId": "163fd144-93d4-4594-89c9-74ce132f3513",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "# predict custom example\n",
        "my_custom_example = \"\"\"\n",
        "<user> : voy a llorar <elongated>\n",
        "\"\"\"\n",
        "pred_class = predict_custom_example(model, my_custom_example)\n",
        "print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Predicted class is: 0 = positive\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MOyzq1m_vdsE",
        "colab_type": "code",
        "outputId": "6b4014da-0261-43fc-c96d-f27468c5783e",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "# reload the best checkpoint and re-evaluate; this runs on valid_iterator,\n",
        "# so the \"Test\" figures below are really validation scores\n",
        "model.load_state_dict(torch.load('tut5-model.pt'))\n",
        "test_loss, test_acc = evaluate(model, valid_iterator, criterion)\n",
        "print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Test Loss: 0.922 | Test Acc: 54.21%\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZYSFfQignSrA",
        "colab_type": "text"
      },
      "source": [
        "**Dev-test set predictions**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8NOSLP1zDyl6",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "model.load_state_dict(torch.load('tut5-model.pt'))\n",
        "mypredictions, correct_answers = predict_testset(model, iterator = dev_iterator)\n",
        "# \"human labels\", the original string labels\n",
        "mypredictionsH = [LABEL.vocab.itos[pred_class] for pred_class in mypredictions]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "T839ek71zJGW",
        "colab_type": "code",
        "outputId": "ab971ced-d191-43de-a3ae-0210b6a78646",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 204
        }
      },
      "source": [
        "import pandas as pd\n",
        "from sklearn.metrics import confusion_matrix, precision_recall_fscore_support\n",
        "\n",
        "cmatrix = confusion_matrix(correct_answers, mypredictions,\n",
        "                           labels = list(LABEL.vocab.stoi.values()))\n",
        "title = LABEL.vocab.itos\n",
        "cmtx = pd.DataFrame(cmatrix, index = title, columns = title)\n",
        "#cmtx.index.name = 'Gold'\n",
        "cmtx.columns.name = 'Gold \\ Pred'\n",
        "print(cmtx)\n",
        "\n",
        "\n",
        "myavg = None  # \"macro\" # None micro macro weighted\n",
        "scores = precision_recall_fscore_support(correct_answers, mypredictions, average = myavg)\n",
        "scores_labels = [\"precision\", \"recall\", \"fscore\", \"support\"]\n",
        "scores_dict = dict(zip(scores_labels, scores))\n",
        "scores_df = pd.DataFrame(scores_dict, index = title)\n",
        "scores_df.columns.name = 'class \\ metric'\n",
        "print(f'\\nAnalysis using average = {myavg}')\n",
        "print(scores_df)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Gold \\ Pred positive neutral negative\n",
            "positive 909 349 240\n",
            "neutral 438 354 202\n",
            "negative 120 157 229\n",
            "\n",
            "Analysis using average = None\n",
            "class \\ metric precision recall fscore support\n",
            "positive 0.619632 0.606809 0.613153 1498\n",
            "neutral 0.411628 0.356137 0.381877 994\n",
            "negative 0.341282 0.452569 0.389125 506\n"
          ],
          "name": "stdout"
        }
      ]
    },
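    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The per-class table above collapses into a single number via the `average` argument (switch `myavg`, or call `f1_score` directly). A short sketch of the three common averages:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Sketch: single-number summaries of the per-class scores above.\n",
        "from sklearn.metrics import f1_score\n",
        "for avg in ('micro', 'macro', 'weighted'):\n",
        "    print(f'{avg} F1: {f1_score(correct_answers, mypredictions, average = avg):.3f}')"
      ],
      "execution_count": 0,
      "outputs": []
    },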
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l2DQVLjjnZDh",
        "colab_type": "text"
      },
      "source": [
        "**Saving answers for delivery**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1mYZrUMnt8V-",
        "colab_type": "code",
        "outputId": "e9f032a4-75af-4af3-b6c1-2357c3dcf930",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "# CSV level\n",
        "import pandas as pd\n",
        "df = pd.read_csv(mypath + myfilenamed, sep=\"\\t\")\n",
        "df[\"prediction\"] = mypredictionsH\n",
        "del df['ekphrased']  # drop the tweet text, keeping id + prediction\n",
        "df.columns = ['Uid', 'Sentiment']\n",
        "df.to_csv(mypath + \"answer.txt\", sep=\",\", index=False)\n",
        "\n",
        "# re-read the delivered file as a sanity check\n",
        "df2 = pd.read_csv(mypath + \"answer.txt\", sep=\",\")\n",
        "df2.shape"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(3789, 2)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 75
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5B565xCauW0B",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}