Created
December 4, 2023 14:21
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 95-865: Sentiment Analysis with IMDb Reviews\n", | |
"\n", | |
"Author: George H. Chen (georgechen [at symbol] cmu.edu)\n", | |
"\n", | |
"Last updated: Dec 4, 2023\n", | |
"\n", | |
"This demo shows how to train an LSTM model for sentiment analysis with IMDb reviews. This is a binary classification task: for each review, we classify it as having positive or negative sentiment." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%matplotlib inline\n", | |
"import matplotlib.pyplot as plt\n", | |
"import numpy as np\n", | |
"import random\n", | |
"import os\n", | |
"\n", | |
"os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # to help make code deterministic\n", | |
"\n", | |
"from glob import glob\n", | |
"\n", | |
"import torch\n", | |
"torch.use_deterministic_algorithms(True) # to help make code deterministic\n", | |
"torch.backends.cudnn.benchmark = False # to help make code deterministic\n", | |
"import torch.nn as nn\n", | |
"from torchinfo import summary\n", | |
"\n", | |
"np.random.seed(0) # to help make code deterministic\n", | |
"torch.manual_seed(0) # to help make code deterministic\n", | |
"random.seed(0) # to help make code deterministic\n", | |
"\n", | |
"from UDA_pytorch_utils import UDA_pytorch_classifier_fit, \\\n", | |
" UDA_plot_train_val_accuracy_vs_epoch, UDA_pytorch_classifier_predict, \\\n", | |
" UDA_compute_accuracy, UDA_get_rnn_last_time_step_outputs\n", | |
"\n", | |
"# these next two lines are needed on my old Intel Mac laptop due to some weird software update issue and also a memory issue\n", | |
"import os\n", | |
"os.environ['LOKY_MAX_CPU_COUNT'] = '1'" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Load the dataset\n", | |
"\n", | |
"Here, we downloaded the IMDb dataset from: http://ai.stanford.edu/~amaas/data/sentiment/\n", | |
"\n", | |
"Place the file `aclImdb_v1.tar.gz` into `./data/` and uncompress it within that directory. After uncompressing, you should have access to the directories `./data/aclImdb/train` and `./data/aclImdb/test`, along with other files such as the README file `./data/aclImdb/README`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_dataset = []\n", | |
"\n", | |
"for filename in sorted(glob('./data/aclImdb/train/pos/*.txt')):\n", | |
" with open(filename, 'r', encoding='utf-8') as f:\n", | |
" train_dataset.append((f.read(), 1)) # 1 means `positive` sentiment\n", | |
"\n", | |
"for filename in sorted(glob('./data/aclImdb/train/neg/*.txt')):\n", | |
" with open(filename, 'r', encoding='utf-8') as f:\n", | |
" train_dataset.append((f.read(), 0)) # 0 means `negative` sentiment" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"# proper training data points: 20000\n", | |
"# validation data points: 5000\n" | |
] | |
} | |
], | |
"source": [ | |
"proper_train_size = int(len(train_dataset) * 0.8)\n", | |
"val_size = len(train_dataset) - proper_train_size\n", | |
"print('# proper training data points:', proper_train_size)\n", | |
"print('# validation data points:', val_size)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"proper_train_dataset, val_dataset = torch.utils.data.random_split(train_dataset,\n", | |
" [proper_train_size,\n", | |
" val_size])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(\"Master cinéaste Alain Resnais likes to work with those actors who are a part of his family.In this film too we see Resnais' family members like Pierre Arditi, Sabine Azema, André Dussolier and Fanny Ardant dealing with serious themes like death,religion,suicide,love and their overall implications on our daily lives.The formal nature of relationship shared by these people is evident as even friends, they address each other using a formal you.In 1984,while making L'amour à mort,Resnais dealt with time,memory and space to unravel the mysteries of a fundamental question of human existence :Is love stronger than death ? It was 16 years ago in 1968 that Resnais made a somewhat similar film Je t'aime Je t'aime which was also about love and memories.Message of this film is loud and clear :true and deep love can even put science to shame as dead lovers regain their lost lives leaving doctors to care for their reputation.L'amour à mort is like a game which is not at all didactic.It is a film in which the musical score is in perfect tandem with its images.This is one of the reasons why this film can easily be grasped.\",\n", | |
" 1)" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"proper_train_dataset[0] # this is a tuple of the format (text, label)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"`torchtext` provides a wrapper around spaCy that we can use for tokenization." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from torchtext.data import get_tokenizer" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"tokenizer_cased = get_tokenizer('spacy', language='en_core_web_sm')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"['Master',\n", | |
" 'cinéaste',\n", | |
" 'Alain',\n", | |
" 'Resnais',\n", | |
" 'likes',\n", | |
" 'to',\n", | |
" 'work',\n", | |
" 'with',\n", | |
" 'those',\n", | |
" 'actors',\n", | |
" 'who',\n", | |
" 'are',\n", | |
" 'a',\n", | |
" 'part',\n", | |
" 'of',\n", | |
" 'his',\n", | |
" 'family',\n", | |
" '.',\n", | |
" 'In',\n", | |
" 'this',\n", | |
" 'film',\n", | |
" 'too',\n", | |
" 'we',\n", | |
" 'see',\n", | |
" 'Resnais',\n", | |
" \"'\",\n", | |
" 'family',\n", | |
" 'members',\n", | |
" 'like',\n", | |
" 'Pierre',\n", | |
" 'Arditi',\n", | |
" ',',\n", | |
" 'Sabine',\n", | |
" 'Azema',\n", | |
" ',',\n", | |
" 'André',\n", | |
" 'Dussolier',\n", | |
" 'and',\n", | |
" 'Fanny',\n", | |
" 'Ardant',\n", | |
" 'dealing',\n", | |
" 'with',\n", | |
" 'serious',\n", | |
" 'themes',\n", | |
" 'like',\n", | |
" 'death',\n", | |
" ',',\n", | |
" 'religion',\n", | |
" ',',\n", | |
" 'suicide',\n", | |
" ',',\n", | |
" 'love',\n", | |
" 'and',\n", | |
" 'their',\n", | |
" 'overall',\n", | |
" 'implications',\n", | |
" 'on',\n", | |
" 'our',\n", | |
" 'daily',\n", | |
" 'lives',\n", | |
" '.',\n", | |
" 'The',\n", | |
" 'formal',\n", | |
" 'nature',\n", | |
" 'of',\n", | |
" 'relationship',\n", | |
" 'shared',\n", | |
" 'by',\n", | |
" 'these',\n", | |
" 'people',\n", | |
" 'is',\n", | |
" 'evident',\n", | |
" 'as',\n", | |
" 'even',\n", | |
" 'friends',\n", | |
" ',',\n", | |
" 'they',\n", | |
" 'address',\n", | |
" 'each',\n", | |
" 'other',\n", | |
" 'using',\n", | |
" 'a',\n", | |
" 'formal',\n", | |
" 'you',\n", | |
" '.',\n", | |
" 'In',\n", | |
" '1984,while',\n", | |
" 'making',\n", | |
" \"L'amour\",\n", | |
" 'à',\n", | |
" 'mort',\n", | |
" ',',\n", | |
" 'Resnais',\n", | |
" 'dealt',\n", | |
" 'with',\n", | |
" 'time',\n", | |
" ',',\n", | |
" 'memory',\n", | |
" 'and',\n", | |
" 'space',\n", | |
" 'to',\n", | |
" 'unravel',\n", | |
" 'the',\n", | |
" 'mysteries',\n", | |
" 'of',\n", | |
" 'a',\n", | |
" 'fundamental',\n", | |
" 'question',\n", | |
" 'of',\n", | |
" 'human',\n", | |
" 'existence',\n", | |
" ':',\n", | |
" 'Is',\n", | |
" 'love',\n", | |
" 'stronger',\n", | |
" 'than',\n", | |
" 'death',\n", | |
" '?',\n", | |
" 'It',\n", | |
" 'was',\n", | |
" '16',\n", | |
" 'years',\n", | |
" 'ago',\n", | |
" 'in',\n", | |
" '1968',\n", | |
" 'that',\n", | |
" 'Resnais',\n", | |
" 'made',\n", | |
" 'a',\n", | |
" 'somewhat',\n", | |
" 'similar',\n", | |
" 'film',\n", | |
" 'Je',\n", | |
" \"t'aime\",\n", | |
" 'Je',\n", | |
" \"t'aime\",\n", | |
" 'which',\n", | |
" 'was',\n", | |
" 'also',\n", | |
" 'about',\n", | |
" 'love',\n", | |
" 'and',\n", | |
" 'memories',\n", | |
" '.',\n", | |
" 'Message',\n", | |
" 'of',\n", | |
" 'this',\n", | |
" 'film',\n", | |
" 'is',\n", | |
" 'loud',\n", | |
" 'and',\n", | |
" 'clear',\n", | |
" ':',\n", | |
" 'true',\n", | |
" 'and',\n", | |
" 'deep',\n", | |
" 'love',\n", | |
" 'can',\n", | |
" 'even',\n", | |
" 'put',\n", | |
" 'science',\n", | |
" 'to',\n", | |
" 'shame',\n", | |
" 'as',\n", | |
" 'dead',\n", | |
" 'lovers',\n", | |
" 'regain',\n", | |
" 'their',\n", | |
" 'lost',\n", | |
" 'lives',\n", | |
" 'leaving',\n", | |
" 'doctors',\n", | |
" 'to',\n", | |
" 'care',\n", | |
" 'for',\n", | |
" 'their',\n", | |
" 'reputation',\n", | |
" '.',\n", | |
" \"L'amour\",\n", | |
" 'à',\n", | |
" 'mort',\n", | |
" 'is',\n", | |
" 'like',\n", | |
" 'a',\n", | |
" 'game',\n", | |
" 'which',\n", | |
" 'is',\n", | |
" 'not',\n", | |
" 'at',\n", | |
" 'all',\n", | |
" 'didactic',\n", | |
" '.',\n", | |
" 'It',\n", | |
" 'is',\n", | |
" 'a',\n", | |
" 'film',\n", | |
" 'in',\n", | |
" 'which',\n", | |
" 'the',\n", | |
" 'musical',\n", | |
" 'score',\n", | |
" 'is',\n", | |
" 'in',\n", | |
" 'perfect',\n", | |
" 'tandem',\n", | |
" 'with',\n", | |
" 'its',\n", | |
" 'images',\n", | |
" '.',\n", | |
" 'This',\n", | |
" 'is',\n", | |
" 'one',\n", | |
" 'of',\n", | |
" 'the',\n", | |
" 'reasons',\n", | |
" 'why',\n", | |
" 'this',\n", | |
" 'film',\n", | |
" 'can',\n", | |
" 'easily',\n", | |
" 'be',\n", | |
" 'grasped',\n", | |
" '.']" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tokenizer_cased(proper_train_dataset[0][0])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Note that the above code cell output just consists of tokens (already converted to strings) found by spaCy. You learned how to do precisely this back in the first demo of this course!\n", | |
"\n", | |
"You'll notice that the above tokens keep track of whether words are uppercase or lowercase. We will be using a word embedding that is \"uncased\", so we will convert everything to lowercase with the following function." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def tokenizer(text):\n", | |
" return [token.lower() for token in tokenizer_cased(text)]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"proper_train_dataset_as_tokens_without_labels = [tokenizer(text) for text, label in proper_train_dataset]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"['master',\n", | |
" 'cinéaste',\n", | |
" 'alain',\n", | |
" 'resnais',\n", | |
" 'likes',\n", | |
" 'to',\n", | |
" 'work',\n", | |
" 'with',\n", | |
" 'those',\n", | |
" 'actors',\n", | |
" 'who',\n", | |
" 'are',\n", | |
" 'a',\n", | |
" 'part',\n", | |
" 'of',\n", | |
" 'his',\n", | |
" 'family',\n", | |
" '.',\n", | |
" 'in',\n", | |
" 'this',\n", | |
" 'film',\n", | |
" 'too',\n", | |
" 'we',\n", | |
" 'see',\n", | |
" 'resnais',\n", | |
" \"'\",\n", | |
" 'family',\n", | |
" 'members',\n", | |
" 'like',\n", | |
" 'pierre',\n", | |
" 'arditi',\n", | |
" ',',\n", | |
" 'sabine',\n", | |
" 'azema',\n", | |
" ',',\n", | |
" 'andré',\n", | |
" 'dussolier',\n", | |
" 'and',\n", | |
" 'fanny',\n", | |
" 'ardant',\n", | |
" 'dealing',\n", | |
" 'with',\n", | |
" 'serious',\n", | |
" 'themes',\n", | |
" 'like',\n", | |
" 'death',\n", | |
" ',',\n", | |
" 'religion',\n", | |
" ',',\n", | |
" 'suicide',\n", | |
" ',',\n", | |
" 'love',\n", | |
" 'and',\n", | |
" 'their',\n", | |
" 'overall',\n", | |
" 'implications',\n", | |
" 'on',\n", | |
" 'our',\n", | |
" 'daily',\n", | |
" 'lives',\n", | |
" '.',\n", | |
" 'the',\n", | |
" 'formal',\n", | |
" 'nature',\n", | |
" 'of',\n", | |
" 'relationship',\n", | |
" 'shared',\n", | |
" 'by',\n", | |
" 'these',\n", | |
" 'people',\n", | |
" 'is',\n", | |
" 'evident',\n", | |
" 'as',\n", | |
" 'even',\n", | |
" 'friends',\n", | |
" ',',\n", | |
" 'they',\n", | |
" 'address',\n", | |
" 'each',\n", | |
" 'other',\n", | |
" 'using',\n", | |
" 'a',\n", | |
" 'formal',\n", | |
" 'you',\n", | |
" '.',\n", | |
" 'in',\n", | |
" '1984,while',\n", | |
" 'making',\n", | |
" \"l'amour\",\n", | |
" 'à',\n", | |
" 'mort',\n", | |
" ',',\n", | |
" 'resnais',\n", | |
" 'dealt',\n", | |
" 'with',\n", | |
" 'time',\n", | |
" ',',\n", | |
" 'memory',\n", | |
" 'and',\n", | |
" 'space',\n", | |
" 'to',\n", | |
" 'unravel',\n", | |
" 'the',\n", | |
" 'mysteries',\n", | |
" 'of',\n", | |
" 'a',\n", | |
" 'fundamental',\n", | |
" 'question',\n", | |
" 'of',\n", | |
" 'human',\n", | |
" 'existence',\n", | |
" ':',\n", | |
" 'is',\n", | |
" 'love',\n", | |
" 'stronger',\n", | |
" 'than',\n", | |
" 'death',\n", | |
" '?',\n", | |
" 'it',\n", | |
" 'was',\n", | |
" '16',\n", | |
" 'years',\n", | |
" 'ago',\n", | |
" 'in',\n", | |
" '1968',\n", | |
" 'that',\n", | |
" 'resnais',\n", | |
" 'made',\n", | |
" 'a',\n", | |
" 'somewhat',\n", | |
" 'similar',\n", | |
" 'film',\n", | |
" 'je',\n", | |
" \"t'aime\",\n", | |
" 'je',\n", | |
" \"t'aime\",\n", | |
" 'which',\n", | |
" 'was',\n", | |
" 'also',\n", | |
" 'about',\n", | |
" 'love',\n", | |
" 'and',\n", | |
" 'memories',\n", | |
" '.',\n", | |
" 'message',\n", | |
" 'of',\n", | |
" 'this',\n", | |
" 'film',\n", | |
" 'is',\n", | |
" 'loud',\n", | |
" 'and',\n", | |
" 'clear',\n", | |
" ':',\n", | |
" 'true',\n", | |
" 'and',\n", | |
" 'deep',\n", | |
" 'love',\n", | |
" 'can',\n", | |
" 'even',\n", | |
" 'put',\n", | |
" 'science',\n", | |
" 'to',\n", | |
" 'shame',\n", | |
" 'as',\n", | |
" 'dead',\n", | |
" 'lovers',\n", | |
" 'regain',\n", | |
" 'their',\n", | |
" 'lost',\n", | |
" 'lives',\n", | |
" 'leaving',\n", | |
" 'doctors',\n", | |
" 'to',\n", | |
" 'care',\n", | |
" 'for',\n", | |
" 'their',\n", | |
" 'reputation',\n", | |
" '.',\n", | |
" \"l'amour\",\n", | |
" 'à',\n", | |
" 'mort',\n", | |
" 'is',\n", | |
" 'like',\n", | |
" 'a',\n", | |
" 'game',\n", | |
" 'which',\n", | |
" 'is',\n", | |
" 'not',\n", | |
" 'at',\n", | |
" 'all',\n", | |
" 'didactic',\n", | |
" '.',\n", | |
" 'it',\n", | |
" 'is',\n", | |
" 'a',\n", | |
" 'film',\n", | |
" 'in',\n", | |
" 'which',\n", | |
" 'the',\n", | |
" 'musical',\n", | |
" 'score',\n", | |
" 'is',\n", | |
" 'in',\n", | |
" 'perfect',\n", | |
" 'tandem',\n", | |
" 'with',\n", | |
" 'its',\n", | |
" 'images',\n", | |
" '.',\n", | |
" 'this',\n", | |
" 'is',\n", | |
" 'one',\n", | |
" 'of',\n", | |
" 'the',\n", | |
" 'reasons',\n", | |
" 'why',\n", | |
" 'this',\n", | |
" 'film',\n", | |
" 'can',\n", | |
" 'easily',\n", | |
" 'be',\n", | |
" 'grasped',\n", | |
" '.']" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"proper_train_dataset_as_tokens_without_labels[0]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We now build the vocabulary. We already saw how to manually do this with spaCy in the second demo of this course! `torchtext` provides a helper function to do this automatically for us." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from torchtext.vocab import build_vocab_from_iterator" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"vocab = build_vocab_from_iterator(proper_train_dataset_as_tokens_without_labels,\n", | |
" specials=[\"<unk>\"])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"vocab.set_default_index(vocab[\"<unk>\"])" | |
] | |
}, | |
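{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"To make the vocabulary helper less of a black box, here is a minimal pure-Python sketch of roughly what building a vocabulary involves. It is a simplified stand-in and not `torchtext`'s actual implementation: the names `build_toy_vocab` and `encode` are hypothetical, and the tie-breaking rule (alphabetical order among equally frequent tokens) is an assumption made purely for illustration.\n", | |
"\n", | |
"```python\n", | |
"from collections import Counter\n", | |
"\n", | |
"def build_toy_vocab(tokenized_docs, unk_token='<unk>'):\n", | |
"    # count how often each token appears across all documents\n", | |
"    counts = Counter(token for doc in tokenized_docs for token in doc)\n", | |
"    # most frequent tokens first; ties broken alphabetically (an assumption)\n", | |
"    ordered = sorted(counts, key=lambda token: (-counts[token], token))\n", | |
"    index_to_token = [unk_token] + ordered  # index 0 is reserved for unknown tokens\n", | |
"    token_to_index = {token: idx for idx, token in enumerate(index_to_token)}\n", | |
"    return token_to_index, index_to_token\n", | |
"\n", | |
"def encode(tokens, token_to_index, unk_token='<unk>'):\n", | |
"    # tokens missing from the vocabulary fall back to the index of unk_token\n", | |
"    unk_index = token_to_index[unk_token]\n", | |
"    return [token_to_index.get(token, unk_index) for token in tokens]\n", | |
"\n", | |
"toy_docs = [['a', 'good', 'film'], ['a', 'bad', 'film'], ['a', 'film']]\n", | |
"token_to_index, index_to_token = build_toy_vocab(toy_docs)\n", | |
"encoded = encode(['a', 'great', 'film'], token_to_index)\n", | |
"decoded = [index_to_token[idx] for idx in encoded]\n", | |
"```\n", | |
"\n", | |
"In this toy example, the unseen token `great` is mapped to the index reserved for `<unk>`, which mirrors the role of `set_default_index` above." | |
] | |
}, | |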
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[1259,\n", | |
" 59266,\n", | |
" 11261,\n", | |
" 16475,\n", | |
" 1225,\n", | |
" 7,\n", | |
" 171,\n", | |
" 20,\n", | |
" 162,\n", | |
" 169,\n", | |
" 43,\n", | |
" 31,\n", | |
" 5,\n", | |
" 186,\n", | |
" 6,\n", | |
" 33,\n", | |
" 233,\n", | |
" 3,\n", | |
" 10,\n", | |
" 12,\n", | |
" 24,\n", | |
" 110,\n", | |
" 80,\n", | |
" 77,\n", | |
" 16475,\n", | |
" 53,\n", | |
" 233,\n", | |
" 1068,\n", | |
" 45,\n", | |
" 6177,\n", | |
" 54976,\n", | |
" 2,\n", | |
" 36409,\n", | |
" 55541,\n", | |
" 2,\n", | |
" 15191,\n", | |
" 29297,\n", | |
" 4,\n", | |
" 12106,\n", | |
" 32513,\n", | |
" 1943,\n", | |
" 20,\n", | |
" 614,\n", | |
" 1318,\n", | |
" 45,\n", | |
" 361,\n", | |
" 2,\n", | |
" 2310,\n", | |
" 2,\n", | |
" 1752,\n", | |
" 2,\n", | |
" 131,\n", | |
" 4,\n", | |
" 78,\n", | |
" 579,\n", | |
" 10754,\n", | |
" 26,\n", | |
" 281,\n", | |
" 3045,\n", | |
" 474,\n", | |
" 3,\n", | |
" 1,\n", | |
" 11373,\n", | |
" 887,\n", | |
" 6,\n", | |
" 646,\n", | |
" 5281,\n", | |
" 40,\n", | |
" 149,\n", | |
" 94,\n", | |
" 8,\n", | |
" 3721,\n", | |
" 19,\n", | |
" 67,\n", | |
" 379,\n", | |
" 2,\n", | |
" 39,\n", | |
" 5559,\n", | |
" 267,\n", | |
" 96,\n", | |
" 787,\n", | |
" 5,\n", | |
" 11373,\n", | |
" 25,\n", | |
" 3,\n", | |
" 10,\n", | |
" 52626,\n", | |
" 247,\n", | |
" 26695,\n", | |
" 13321,\n", | |
" 35429,\n", | |
" 2,\n", | |
" 16475,\n", | |
" 3345,\n", | |
" 20,\n", | |
" 68,\n", | |
" 2,\n", | |
" 1897,\n", | |
" 4,\n", | |
" 844,\n", | |
" 7,\n", | |
" 8888,\n", | |
" 1,\n", | |
" 4146,\n", | |
" 6,\n", | |
" 5,\n", | |
" 8964,\n", | |
" 905,\n", | |
" 6,\n", | |
" 415,\n", | |
" 2081,\n", | |
" 93,\n", | |
" 8,\n", | |
" 131,\n", | |
" 3535,\n", | |
" 85,\n", | |
" 361,\n", | |
" 60,\n", | |
" 9,\n", | |
" 18,\n", | |
" 3207,\n", | |
" 168,\n", | |
" 617,\n", | |
" 10,\n", | |
" 5174,\n", | |
" 13,\n", | |
" 16475,\n", | |
" 104,\n", | |
" 5,\n", | |
" 631,\n", | |
" 730,\n", | |
" 24,\n", | |
" 7627,\n", | |
" 8295,\n", | |
" 7627,\n", | |
" 8295,\n", | |
" 72,\n", | |
" 18,\n", | |
" 99,\n", | |
" 51,\n", | |
" 131,\n", | |
" 4,\n", | |
" 1851,\n", | |
" 3,\n", | |
" 758,\n", | |
" 6,\n", | |
" 12,\n", | |
" 24,\n", | |
" 8,\n", | |
" 1313,\n", | |
" 4,\n", | |
" 790,\n", | |
" 93,\n", | |
" 304,\n", | |
" 4,\n", | |
" 929,\n", | |
" 131,\n", | |
" 70,\n", | |
" 67,\n", | |
" 282,\n", | |
" 1027,\n", | |
" 7,\n", | |
" 930,\n", | |
" 19,\n", | |
" 358,\n", | |
" 1799,\n", | |
" 9250,\n", | |
" 78,\n", | |
" 430,\n", | |
" 474,\n", | |
" 1200,\n", | |
" 5721,\n", | |
" 7,\n", | |
" 468,\n", | |
" 21,\n", | |
" 78,\n", | |
" 2703,\n", | |
" 3,\n", | |
" 26695,\n", | |
" 13321,\n", | |
" 35429,\n", | |
" 8,\n", | |
" 45,\n", | |
" 5,\n", | |
" 501,\n", | |
" 72,\n", | |
" 8,\n", | |
" 30,\n", | |
" 38,\n", | |
" 37,\n", | |
" 16907,\n", | |
" 3,\n", | |
" 9,\n", | |
" 8,\n", | |
" 5,\n", | |
" 24,\n", | |
" 10,\n", | |
" 72,\n", | |
" 1,\n", | |
" 638,\n", | |
" 626,\n", | |
" 8,\n", | |
" 10,\n", | |
" 425,\n", | |
" 21189,\n", | |
" 20,\n", | |
" 105,\n", | |
" 1241,\n", | |
" 3,\n", | |
" 12,\n", | |
" 8,\n", | |
" 36,\n", | |
" 6,\n", | |
" 1,\n", | |
" 1037,\n", | |
" 152,\n", | |
" 12,\n", | |
" 24,\n", | |
" 70,\n", | |
" 727,\n", | |
" 35,\n", | |
" 19175,\n", | |
" 3]" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# we can now convert any piece of text first into tokens and then from tokens into indices into the vocabulary\n", | |
"vocab(tokenizer(proper_train_dataset[0][0]))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'master'" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# we can also look up what any specific word index refers to\n", | |
"vocab.lookup_token(1259)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"proper_train_encoded = [vocab(tokens) for tokens in proper_train_dataset_as_tokens_without_labels]\n", | |
"# note that another way to write the above line is the line below (but this would repeat the tokenization work that we already did):\n", | |
"# proper_train_encoded = [vocab(tokenizer(text)) for text, label in proper_train_dataset]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[1259,\n", | |
" 59266,\n", | |
" 11261,\n", | |
" 16475,\n", | |
" 1225,\n", | |
" 7,\n", | |
" 171,\n", | |
" 20,\n", | |
" 162,\n", | |
" 169,\n", | |
" 43,\n", | |
" 31,\n", | |
" 5,\n", | |
" 186,\n", | |
" 6,\n", | |
" 33,\n", | |
" 233,\n", | |
" 3,\n", | |
" 10,\n", | |
" 12,\n", | |
" 24,\n", | |
" 110,\n", | |
" 80,\n", | |
" 77,\n", | |
" 16475,\n", | |
" 53,\n", | |
" 233,\n", | |
" 1068,\n", | |
" 45,\n", | |
" 6177,\n", | |
" 54976,\n", | |
" 2,\n", | |
" 36409,\n", | |
" 55541,\n", | |
" 2,\n", | |
" 15191,\n", | |
" 29297,\n", | |
" 4,\n", | |
" 12106,\n", | |
" 32513,\n", | |
" 1943,\n", | |
" 20,\n", | |
" 614,\n", | |
" 1318,\n", | |
" 45,\n", | |
" 361,\n", | |
" 2,\n", | |
" 2310,\n", | |
" 2,\n", | |
" 1752,\n", | |
" 2,\n", | |
" 131,\n", | |
" 4,\n", | |
" 78,\n", | |
" 579,\n", | |
" 10754,\n", | |
" 26,\n", | |
" 281,\n", | |
" 3045,\n", | |
" 474,\n", | |
" 3,\n", | |
" 1,\n", | |
" 11373,\n", | |
" 887,\n", | |
" 6,\n", | |
" 646,\n", | |
" 5281,\n", | |
" 40,\n", | |
" 149,\n", | |
" 94,\n", | |
" 8,\n", | |
" 3721,\n", | |
" 19,\n", | |
" 67,\n", | |
" 379,\n", | |
" 2,\n", | |
" 39,\n", | |
" 5559,\n", | |
" 267,\n", | |
" 96,\n", | |
" 787,\n", | |
" 5,\n", | |
" 11373,\n", | |
" 25,\n", | |
" 3,\n", | |
" 10,\n", | |
" 52626,\n", | |
" 247,\n", | |
" 26695,\n", | |
" 13321,\n", | |
" 35429,\n", | |
" 2,\n", | |
" 16475,\n", | |
" 3345,\n", | |
" 20,\n", | |
" 68,\n", | |
" 2,\n", | |
" 1897,\n", | |
" 4,\n", | |
" 844,\n", | |
" 7,\n", | |
" 8888,\n", | |
" 1,\n", | |
" 4146,\n", | |
" 6,\n", | |
" 5,\n", | |
" 8964,\n", | |
" 905,\n", | |
" 6,\n", | |
" 415,\n", | |
" 2081,\n", | |
" 93,\n", | |
" 8,\n", | |
" 131,\n", | |
" 3535,\n", | |
" 85,\n", | |
" 361,\n", | |
" 60,\n", | |
" 9,\n", | |
" 18,\n", | |
" 3207,\n", | |
" 168,\n", | |
" 617,\n", | |
" 10,\n", | |
" 5174,\n", | |
" 13,\n", | |
" 16475,\n", | |
" 104,\n", | |
" 5,\n", | |
" 631,\n", | |
" 730,\n", | |
" 24,\n", | |
" 7627,\n", | |
" 8295,\n", | |
" 7627,\n", | |
" 8295,\n", | |
" 72,\n", | |
" 18,\n", | |
" 99,\n", | |
" 51,\n", | |
" 131,\n", | |
" 4,\n", | |
" 1851,\n", | |
" 3,\n", | |
" 758,\n", | |
" 6,\n", | |
" 12,\n", | |
" 24,\n", | |
" 8,\n", | |
" 1313,\n", | |
" 4,\n", | |
" 790,\n", | |
" 93,\n", | |
" 304,\n", | |
" 4,\n", | |
" 929,\n", | |
" 131,\n", | |
" 70,\n", | |
" 67,\n", | |
" 282,\n", | |
" 1027,\n", | |
" 7,\n", | |
" 930,\n", | |
" 19,\n", | |
" 358,\n", | |
" 1799,\n", | |
" 9250,\n", | |
" 78,\n", | |
" 430,\n", | |
" 474,\n", | |
" 1200,\n", | |
" 5721,\n", | |
" 7,\n", | |
" 468,\n", | |
" 21,\n", | |
" 78,\n", | |
" 2703,\n", | |
" 3,\n", | |
" 26695,\n", | |
" 13321,\n", | |
" 35429,\n", | |
" 8,\n", | |
" 45,\n", | |
" 5,\n", | |
" 501,\n", | |
" 72,\n", | |
" 8,\n", | |
" 30,\n", | |
" 38,\n", | |
" 37,\n", | |
" 16907,\n", | |
" 3,\n", | |
" 9,\n", | |
" 8,\n", | |
" 5,\n", | |
" 24,\n", | |
" 10,\n", | |
" 72,\n", | |
" 1,\n", | |
" 638,\n", | |
" 626,\n", | |
" 8,\n", | |
" 10,\n", | |
" 425,\n", | |
" 21189,\n", | |
" 20,\n", | |
" 105,\n", | |
" 1241,\n", | |
" 3,\n", | |
" 12,\n", | |
" 8,\n", | |
" 36,\n", | |
" 6,\n", | |
" 1,\n", | |
" 1037,\n", | |
" 152,\n", | |
" 12,\n", | |
" 24,\n", | |
" 70,\n", | |
" 727,\n", | |
" 35,\n", | |
" 19175,\n", | |
" 3]" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"proper_train_encoded[0]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"['master', 'cinéaste', 'alain', 'resnais', 'likes', 'to', 'work', 'with', 'those', 'actors', 'who', 'are', 'a', 'part', 'of', 'his', 'family', '.', 'in', 'this', 'film', 'too', 'we', 'see', 'resnais', \"'\", 'family', 'members', 'like', 'pierre', 'arditi', ',', 'sabine', 'azema', ',', 'andré', 'dussolier', 'and', 'fanny', 'ardant', 'dealing', 'with', 'serious', 'themes', 'like', 'death', ',', 'religion', ',', 'suicide', ',', 'love', 'and', 'their', 'overall', 'implications', 'on', 'our', 'daily', 'lives', '.', 'the', 'formal', 'nature', 'of', 'relationship', 'shared', 'by', 'these', 'people', 'is', 'evident', 'as', 'even', 'friends', ',', 'they', 'address', 'each', 'other', 'using', 'a', 'formal', 'you', '.', 'in', '1984,while', 'making', \"l'amour\", 'à', 'mort', ',', 'resnais', 'dealt', 'with', 'time', ',', 'memory', 'and', 'space', 'to', 'unravel', 'the', 'mysteries', 'of', 'a', 'fundamental', 'question', 'of', 'human', 'existence', ':', 'is', 'love', 'stronger', 'than', 'death', '?', 'it', 'was', '16', 'years', 'ago', 'in', '1968', 'that', 'resnais', 'made', 'a', 'somewhat', 'similar', 'film', 'je', \"t'aime\", 'je', \"t'aime\", 'which', 'was', 'also', 'about', 'love', 'and', 'memories', '.', 'message', 'of', 'this', 'film', 'is', 'loud', 'and', 'clear', ':', 'true', 'and', 'deep', 'love', 'can', 'even', 'put', 'science', 'to', 'shame', 'as', 'dead', 'lovers', 'regain', 'their', 'lost', 'lives', 'leaving', 'doctors', 'to', 'care', 'for', 'their', 'reputation', '.', \"l'amour\", 'à', 'mort', 'is', 'like', 'a', 'game', 'which', 'is', 'not', 'at', 'all', 'didactic', '.', 'it', 'is', 'a', 'film', 'in', 'which', 'the', 'musical', 'score', 'is', 'in', 'perfect', 'tandem', 'with', 'its', 'images', '.', 'this', 'is', 'one', 'of', 'the', 'reasons', 'why', 'this', 'film', 'can', 'easily', 'be', 'grasped', '.']\n" | |
] | |
} | |
], | |
"source": [ | |
"# we can recover the (lowercased) tokens of any review from its encoded version\n", | |
"print([vocab.lookup_token(word_idx) for word_idx in proper_train_encoded[0]])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"proper_train_labels = [label for text, label in proper_train_dataset]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"val_encoded = [vocab(tokenizer(text)) for text, label in val_dataset]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"val_labels = [label for text, label in val_dataset]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"proper_train_dataset_encoded = list(zip(proper_train_encoded, proper_train_labels))\n", | |
"val_dataset_encoded = list(zip(val_encoded, val_labels))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Setting up a recurrent neural net for sentiment analysis that uses pre-trained word embeddings" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We first load pre-trained GloVe embeddings and will then keep only the embeddings of tokens that we encountered in the proper training data. Note that these embeddings are for \"uncased\" words (i.e., we should convert the tokens to lowercase first)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from torchtext.vocab import GloVe\n", | |
"pretrained_embedding = GloVe(name='6B', dim=100)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torchtext.vocab.vectors.GloVe" | |
] | |
}, | |
"execution_count": 25, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"type(pretrained_embedding)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([ 0.2309, 0.2828, 0.6318, -0.5941, -0.5860, 0.6326, 0.2440, -0.1411,\n", | |
" 0.0608, -0.7898, -0.2910, 0.1429, 0.7227, 0.2043, 0.1407, 0.9876,\n", | |
" 0.5253, 0.0975, 0.8822, 0.5122, 0.4020, 0.2117, -0.0131, -0.7162,\n", | |
" 0.5539, 1.1452, -0.8804, -0.5022, -0.2281, 0.0239, 0.1072, 0.0837,\n", | |
" 0.5501, 0.5848, 0.7582, 0.4571, -0.2800, 0.2522, 0.6896, -0.6097,\n", | |
" 0.1958, 0.0442, -0.3114, -0.6883, -0.2272, 0.4618, -0.7716, 0.1021,\n", | |
" 0.5564, 0.0674, -0.5721, 0.2374, 0.4717, 0.8277, -0.2926, -1.3422,\n", | |
" -0.0993, 0.2814, 0.4160, 0.1058, 0.6220, 0.8950, -0.2345, 0.5135,\n", | |
" 0.9938, 1.1846, -0.1636, 0.2065, 0.7385, 0.2406, -0.9647, 0.1348,\n", | |
" -0.0072, 0.3302, -0.1236, 0.2719, -0.4095, 0.0219, -0.6069, 0.4076,\n", | |
" 0.1957, -0.4180, 0.1864, -0.0327, -0.7857, -0.1385, 0.0440, -0.0844,\n", | |
" 0.0491, 0.2410, 0.4527, -0.1868, 0.4618, 0.0891, -0.1819, -0.0152,\n", | |
" -0.7368, -0.1453, 0.1510, -0.7149])" | |
] | |
}, | |
"execution_count": 26, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pretrained_embedding['cat']\n", | |
"# note that if you ask for a word embedding for a word that GloVe does not keep\n", | |
"# track of, you'll get all zeros" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"embedding_matrix = torch.zeros(len(vocab), pretrained_embedding.dim)\n", | |
"for i, token in enumerate(vocab.lookup_tokens(range(len(vocab)))):\n", | |
" embedding_matrix[i] = pretrained_embedding[token]" | |
] | |
}, | |
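{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"The loop that fills in the embedding matrix can be pictured with plain Python lists. This is a sketch under stated assumptions rather than the notebook's own code: `toy_pretrained` is a made-up 3-dimensional embedding table standing in for GloVe, and `lookup_embedding` is a hypothetical helper that mimics the all-zeros fallback for words that the pre-trained table does not contain.\n", | |
"\n", | |
"```python\n", | |
"embedding_dim = 3\n", | |
"\n", | |
"# made-up pre-trained vectors standing in for GloVe (values are arbitrary)\n", | |
"toy_pretrained = {\n", | |
"    'film': [0.1, 0.2, 0.3],\n", | |
"    'good': [0.4, 0.5, 0.6],\n", | |
"}\n", | |
"\n", | |
"def lookup_embedding(token):\n", | |
"    # words without a pre-trained vector get an all-zeros vector\n", | |
"    return toy_pretrained.get(token, [0.0] * embedding_dim)\n", | |
"\n", | |
"toy_index_to_token = ['<unk>', 'film', 'good', 'resnais']\n", | |
"\n", | |
"# row i of the matrix holds the embedding of the token with vocabulary index i\n", | |
"toy_embedding_matrix = [lookup_embedding(token) for token in toy_index_to_token]\n", | |
"\n", | |
"# count the rows that received no pre-trained vector\n", | |
"num_missing = sum(1 for row in toy_embedding_matrix if not any(row))\n", | |
"```\n", | |
"\n", | |
"Here rows 0 and 3 stay all zeros, so the model would receive no pre-trained signal for those two tokens." | |
] | |
}, | |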
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"The next code cell constructs a PyTorch recurrent neural net model for sentiment analysis. Unfortunately, the code involved doesn't readily work with `nn.Sequential`, which we have previously been using. In short, an input batch of time series can have varying lengths, and accounting for varying lengths properly is not easy to do with `nn.Sequential`. We will instead build the PyTorch model using another standard approach: creating a Python class that inherits from the `nn.Module` class.\n", | |
"\n", | |
"To illustrate how this works, here is how we created the multilayer perceptron model for MNIST digits in the previous demo:\n", | |
"\n", | |
"```\n", | |
"deeper_model = nn.Sequential(nn.Flatten(),\n", | |
" nn.Linear(in_features=784, out_features=512),\n", | |
" nn.ReLU(),\n", | |
" nn.Linear(in_features=512, out_features=10))\n", | |
"```\n", | |
"\n", | |
"An alternative way to code the same model is as follows:\n", | |
"\n", | |
"```\n", | |
"class DeeperModel(nn.Module):\n", | |
" def __init__(self, num_in_features, num_intermediate_features, num_out_features):\n", | |
" super().__init__()\n", | |
" self.flatten = nn.Flatten()\n", | |
" self.linear1 = nn.Linear(num_in_features, num_intermediate_features)\n", | |
" self.relu = nn.ReLU()\n", | |
" self.linear2 = nn.Linear(num_intermediate_features, num_out_features)\n", | |
"\n", | |
" def forward(self, inputs):\n", | |
" flatten_output = self.flatten(inputs)\n", | |
" linear1_output = self.linear1(flatten_output)\n", | |
" relu_output = self.relu(linear1_output)\n", | |
" linear2_output = self.linear2(relu_output)\n", | |
" return linear2_output\n", | |
"\n", | |
"deeper_model = DeeperModel(784, 512, 10)\n", | |
"```\n", | |
"\n", | |
"**Importantly, in the above code, the `forward` function specifies how the neural net processes input data. Note that the only argument it takes (aside from `self`) is the input data (`inputs`), which for MNIST digits we saw will be of the format (batch size, 1, 28, 28). Consequently, the example input data batch supplied to the `summary` function only needs this 4D table.**\n", | |
"\n", | |
"The code below creates the PyTorch neural net model corresponding to the architecture:\n", | |
"\n", | |
"1. `Embedding` layer (for every time series: convert the word index at every time step into a 100-dimensional GloVe word embedding)\n", | |
"2. `LSTM` layer with 32 output nodes (for every time series: put the time series through the LSTM's `for` loop and grab only the last time step's output, which has 32 numbers)\n", | |
"3. `Linear` layer with 2 output nodes (now every input data point to the linear layer is just a 1D table of 32 numbers, which this linear layer converts to 2 output numbers corresponding to the two classes)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class EmbeddingLSTMLinearModel(nn.Module):\n", | |
" def __init__(self, embedding_matrix, num_lstm_output_nodes, num_final_output_nodes):\n", | |
" super().__init__()\n", | |
" self.embedding_layer = nn.Embedding.from_pretrained(embedding_matrix)\n", | |
" self.lstm_layer = nn.LSTM(embedding_matrix.shape[1], num_lstm_output_nodes)\n", | |
" self.linear_layer = nn.Linear(num_lstm_output_nodes, num_final_output_nodes)\n", | |
"\n", | |
" def forward(self, text_encodings, lengths):\n", | |
" embeddings = self.embedding_layer(text_encodings)\n", | |
"\n", | |
" rnn_last_time_step_outputs = \\\n", | |
" UDA_get_rnn_last_time_step_outputs(embeddings, lengths, self.lstm_layer)\n", | |
"\n", | |
" return self.linear_layer(rnn_last_time_step_outputs)" | |
] | |
}, | |
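{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sanity check on this architecture, the number of trainable parameters can be worked out by hand and compared against PyTorch. The sketch below assumes 100-dimensional word embeddings and a single-layer, unidirectional `nn.LSTM` with PyTorch's default settings (four gates, each with two bias vectors). The variable names are chosen for illustration only.\n",
"\n",
"```\n",
"import torch.nn as nn\n",
"\n",
"embedding_dim, num_lstm_output_nodes, num_classes = 100, 32, 2\n",
"\n",
"# each of the LSTM's 4 gates has input weights, recurrent weights, and 2 bias vectors\n",
"lstm_params_by_hand = 4 * (embedding_dim * num_lstm_output_nodes\n",
"                           + num_lstm_output_nodes * num_lstm_output_nodes\n",
"                           + 2 * num_lstm_output_nodes)\n",
"# the linear layer has one weight per (input, output) pair plus one bias per output\n",
"linear_params_by_hand = num_lstm_output_nodes * num_classes + num_classes\n",
"\n",
"lstm_layer = nn.LSTM(embedding_dim, num_lstm_output_nodes)\n",
"linear_layer = nn.Linear(num_lstm_output_nodes, num_classes)\n",
"\n",
"print(lstm_params_by_hand, sum(p.numel() for p in lstm_layer.parameters()))      # 17152 17152\n",
"print(linear_params_by_hand, sum(p.numel() for p in linear_layer.parameters()))  # 66 66\n",
"```\n",
"\n",
"These totals (17,152 + 66 = 17,218) match the trainable parameter count that `summary` reports further below."
]
},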
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Notice that the `forward` function takes _two_ inputs (aside from `self`): `text_encodings` and `lengths`.** Keep in mind that `forward` takes in a batch of input time series, and these input time series could have different lengths. These lengths are precisely what is stored in `lengths` as a 1D table of integers. Meanwhile, `text_encodings` is a 2D table where the number of rows is the maximum number of time steps across the time series in the input batch, and the number of columns is the number of time series in the input batch. See the lecture slides for how `text_encodings` gets filled in: basically, we pad all the time series in the batch to be the same length as the longest time series, and the LSTM layer ignores the padded entries since it is given the correct length to use per time series. **When we give this neural net model an example data batch using the `summary` function, we need to specify two inputs corresponding to `text_encodings` and `lengths`.**"
] | |
}, | |
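{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the two inputs concrete, here is a minimal sketch of how a batch of variable-length time series could be turned into a padded 2D table plus a 1D table of lengths. It uses PyTorch's `torch.nn.utils.rnn.pad_sequence`; the word indices are made up, and this is an illustration rather than the exact code that the helper functions in `UDA_pytorch_utils` use.\n",
"\n",
"```\n",
"import torch\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"\n",
"# three hypothetical encoded reviews (1D tables of word indices) of lengths 3, 2, 5\n",
"encoded_reviews = [torch.tensor([4, 9, 2]),\n",
"                   torch.tensor([7, 1]),\n",
"                   torch.tensor([3, 8, 5, 6, 2])]\n",
"\n",
"lengths = torch.tensor([len(review) for review in encoded_reviews])\n",
"\n",
"# by default, pad_sequence pads with 0 and returns shape (max length, batch size)\n",
"text_encodings = pad_sequence(encoded_reviews)\n",
"\n",
"print(text_encodings.shape)  # torch.Size([5, 3])\n",
"print(lengths)               # tensor([3, 2, 5])\n",
"print(text_encodings[:, 1])  # tensor([7, 1, 0, 0, 0])\n",
"```\n",
"\n",
"Each column holds one time series, which is why the number of columns equals the batch size."
]
},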
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"simple_lstm_model = EmbeddingLSTMLinearModel(embedding_matrix, 32, 2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"==========================================================================================\n", | |
"Layer (type:depth-idx) Output Shape Param #\n", | |
"==========================================================================================\n", | |
"EmbeddingLSTMLinearModel [5, 2] --\n", | |
"├─Embedding: 1-1 [7, 5, 100] (9,066,700)\n", | |
"├─LSTM: 1-2 [18, 32] 17,152\n", | |
"├─Linear: 1-3 [5, 2] 66\n", | |
"==========================================================================================\n", | |
"Total params: 9,083,918\n", | |
"Trainable params: 17,218\n", | |
"Non-trainable params: 9,066,700\n", | |
"Total mult-adds (Units.MEGABYTES): 73.35\n", | |
"==========================================================================================\n", | |
"Input size (MB): 0.00\n", | |
"Forward/backward pass size (MB): 0.03\n", | |
"Params size (MB): 36.34\n", | |
"Estimated Total Size (MB): 36.37\n", | |
"==========================================================================================" | |
] | |
}, | |
"execution_count": 30, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# example where there are 5 input time series of lengths 3, 2, 5, 1, 7;\n", | |
"# we specify these time series using a 2D table that is padded and a\n", | |
"# 1D table of lengths (see lecture slides for details)\n", | |
"summary(simple_lstm_model,\n", | |
" input_data=[torch.zeros((7, 5), dtype=torch.long),\n", | |
" torch.tensor([3, 2, 5, 1, 7], dtype=torch.long)])\n", | |
"\n", | |
"# note: the LSTM's output is in a compressed format (called a \"packed sequence\")\n", | |
"# that appears to put all the outputs of all the time series together (the output\n", | |
"# shape in this case appears to be 18 by 32) but it actually does keep track of\n", | |
"# which of the time steps correspond to which input time series (i.e., it knows\n", | |
"# that 3 of the rows correspond to the 0-th time series, 2 of the rows correspond\n", | |
"# to the 1st time series, 5 of the rows correspond to the 2nd time series, 1 row\n", | |
"# corresponds to the 3rd time series, and 7 rows correspond to the 4th time\n", | |
"# series); my helper code automatically maps these rows back to the correct\n", | |
"# format so that the final linear layer recognizes that there are 5 input data\n", | |
"# points" | |
] | |
}, | |
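{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following sketch reproduces the 18-by-32 shape mentioned in the comment above using PyTorch's `pack_padded_sequence` directly. It is only an illustration, under the assumption that the helper relies on packed sequences as the comment describes; the actual implementation of `UDA_get_rnn_last_time_step_outputs` is not shown in this demo and may differ.\n",
"\n",
"```\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn.utils.rnn import pack_padded_sequence\n",
"\n",
"lengths = torch.tensor([3, 2, 5, 1, 7])\n",
"embeddings = torch.zeros((7, 5, 100))  # (max length, batch size, embedding dimension)\n",
"lstm_layer = nn.LSTM(100, 32)\n",
"\n",
"# enforce_sorted=False lets the time series appear in any order of length\n",
"packed_embeddings = pack_padded_sequence(embeddings, lengths, enforce_sorted=False)\n",
"packed_outputs, (last_hidden, last_cell) = lstm_layer(packed_embeddings)\n",
"\n",
"print(packed_outputs.data.shape)  # torch.Size([18, 32]) since 3 + 2 + 5 + 1 + 7 = 18\n",
"# for each time series, the hidden state at its own last time step\n",
"print(last_hidden[-1].shape)      # torch.Size([5, 32])\n",
"```\n",
"\n",
"The second shape, 5 by 32, is what a final linear layer needs: one row of 32 numbers per input time series."
]
},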
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"os.makedirs('./saved_model_checkpoints', exist_ok=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 0%| | 0/30 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.6021\n", | |
" Validation accuracy: 0.6048\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 3%|████▋ | 1/30 [20:28<9:53:47, 1228.52s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.7117\n", | |
" Validation accuracy: 0.7176\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 7%|█████████▍ | 2/30 [38:56<9:00:10, 1157.51s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.7664\n", | |
" Validation accuracy: 0.7738\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 10%|██████████████ | 3/30 [57:23<8:30:31, 1134.49s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.7790\n", | |
" Validation accuracy: 0.7804\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 13%|██████████████████▌ | 4/30 [1:15:49<8:06:42, 1123.18s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.8116\n", | |
" Validation accuracy: 0.8112\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 17%|███████████████████████▏ | 5/30 [1:34:01<7:43:22, 1112.11s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.8180\n", | |
" Validation accuracy: 0.8110\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 20%|███████████████████████████▊ | 6/30 [1:52:29<7:24:11, 1110.47s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.8469\n", | |
" Validation accuracy: 0.8390\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 23%|████████████████████████████████▍ | 7/30 [2:10:50<7:04:30, 1107.43s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.8613\n", | |
" Validation accuracy: 0.8534\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 27%|█████████████████████████████████████ | 8/30 [2:29:11<6:45:22, 1105.59s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.8655\n", | |
" Validation accuracy: 0.8602\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 30%|█████████████████████████████████████████▋ | 9/30 [2:47:35<6:26:42, 1104.88s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.8715\n", | |
" Validation accuracy: 0.8658\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 33%|██████████████████████████████████████████████ | 10/30 [3:05:52<6:07:29, 1102.49s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.8804\n", | |
" Validation accuracy: 0.8658\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 37%|██████████████████████████████████████████████████▌ | 11/30 [3:24:09<5:48:38, 1100.98s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.8810\n", | |
" Validation accuracy: 0.8672\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 40%|███████████████████████████████████████████████████████▏ | 12/30 [3:42:40<5:31:11, 1103.96s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.8847\n", | |
" Validation accuracy: 0.8686\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 43%|███████████████████████████████████████████████████████████▊ | 13/30 [4:01:07<5:13:00, 1104.72s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.8923\n", | |
" Validation accuracy: 0.8734\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 47%|████████████████████████████████████████████████████████████████▍ | 14/30 [4:19:31<4:54:34, 1104.64s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.8914\n", | |
" Validation accuracy: 0.8688\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 50%|█████████████████████████████████████████████████████████████████████ | 15/30 [4:37:58<4:36:20, 1105.38s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.9000\n", | |
" Validation accuracy: 0.8746\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 53%|█████████████████████████████████████████████████████████████████████████▌ | 16/30 [4:56:28<4:18:14, 1106.73s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.9010\n", | |
" Validation accuracy: 0.8794\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 57%|██████████████████████████████████████████████████████████████████████████████▏ | 17/30 [5:14:37<3:58:38, 1101.40s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.9050\n", | |
" Validation accuracy: 0.8784\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 60%|██████████████████████████████████████████████████████████████████████████████████▊ | 18/30 [5:32:52<3:39:53, 1099.50s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.9061\n", | |
" Validation accuracy: 0.8762\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 63%|███████████████████████████████████████████████████████████████████████████████████████▍ | 19/30 [5:51:04<3:21:08, 1097.15s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.9087\n", | |
" Validation accuracy: 0.8752\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 67%|████████████████████████████████████████████████████████████████████████████████████████████ | 20/30 [6:09:24<3:03:01, 1098.16s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.8938\n", | |
" Validation accuracy: 0.8650\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 70%|████████████████████████████████████████████████████████████████████████████████████████████████▌ | 21/30 [6:27:40<2:44:35, 1097.31s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.9158\n", | |
" Validation accuracy: 0.8804\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 73%|█████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 22/30 [6:45:55<2:26:13, 1096.72s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.8961\n", | |
" Validation accuracy: 0.8608\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 77%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 23/30 [7:04:17<2:08:07, 1098.23s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.9204\n", | |
" Validation accuracy: 0.8756\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 80%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 24/30 [7:22:33<1:49:46, 1097.75s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.9193\n", | |
" Validation accuracy: 0.8804\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 83%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 25/30 [7:40:50<1:31:26, 1097.28s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.9045\n", | |
" Validation accuracy: 0.8628\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 87%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 26/30 [7:59:06<1:13:07, 1096.93s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.9193\n", | |
" Validation accuracy: 0.8720\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 90%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 27/30 [8:17:31<54:58, 1099.40s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.9250\n", | |
" Validation accuracy: 0.8742\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 28/30 [8:35:59<36:44, 1102.06s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.8967\n", | |
" Validation accuracy: 0.8512\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 97%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 29/30 [8:54:23<18:22, 1102.75s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Proper training accuracy: 0.9267\n", | |
" Validation accuracy: 0.8710\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [9:12:57<00:00, 1105.93s/it]\n" | |
] | |
}, | |
{ | |
"data": { | |
"image/png": "", | |
"text/plain": [ | |
"<Figure size 640x480 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
],
"source": [
"num_epochs = 30 # during optimization, how many times we look at training data\n",
"batch_size = 128 # during optimization, how many training data points to use at each step\n",
"learning_rate = 0.001 # during optimization, how much we nudge our solution at each step\n", | |
"\n", | |
"proper_train_accuracies, val_accuracies = \\\n", | |
" UDA_pytorch_classifier_fit(simple_lstm_model,\n", | |
" torch.optim.Adam(simple_lstm_model.parameters(),\n", | |
" lr=learning_rate),\n", | |
" nn.CrossEntropyLoss(), # includes softmax\n", | |
" proper_train_dataset_encoded, val_dataset_encoded,\n", | |
" num_epochs, batch_size,\n", | |
" rnn=True,\n", | |
" save_epoch_checkpoint_prefix='./saved_model_checkpoints/imdb_lstm')\n", | |
"\n", | |
"UDA_plot_train_val_accuracy_vs_epoch(proper_train_accuracies, val_accuracies)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"The model at the end of epoch 22 achieved the highest validation accuracy: 0.880400\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<All keys matched successfully>" | |
] | |
}, | |
"execution_count": 33, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"best_epoch_idx = np.argmax(val_accuracies)\n", | |
"print('The model at the end of epoch %d achieved the highest validation accuracy: %f'\n", | |
" % (best_epoch_idx + 1, val_accuracies[best_epoch_idx]))\n", | |
"simple_lstm_model.load_state_dict(torch.load('./saved_model_checkpoints/imdb_lstm_epoch%d.pt' % (best_epoch_idx + 1)))" | |
] | |
}, | |
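{
"cell_type": "markdown",
"metadata": {},
"source": [
"One detail worth noting is how ties are handled: `np.argmax` returns the index of the first occurrence of the maximum. In the training log above, epochs 22 and 25 both reach a validation accuracy of 0.8804, so the earlier one is selected. The sketch below uses a shortened, hand-copied list of accuracies to show this; the variable names are hypothetical.\n",
"\n",
"```\n",
"import numpy as np\n",
"\n",
"# validation accuracies for epochs 21 through 26, copied from the training log above\n",
"val_accuracies_subset = [0.8650, 0.8804, 0.8608, 0.8756, 0.8804, 0.8628]\n",
"first_epoch_in_subset = 21\n",
"\n",
"best_idx = np.argmax(val_accuracies_subset)  # index of the first maximum\n",
"print(first_epoch_in_subset + best_idx)      # 22\n",
"```"
]
},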
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Finally, evaluate on test data"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"test_dataset = []\n", | |
"\n", | |
"for filename in sorted(glob('./data/aclImdb/test/pos/*.txt')):\n", | |
" with open(filename, 'r', encoding='utf-8') as f:\n", | |
" test_dataset.append((f.read(), 1)) # 1 means `positive` sentiment\n", | |
"\n", | |
"for filename in sorted(glob('./data/aclImdb/test/neg/*.txt')):\n", | |
" with open(filename, 'r', encoding='utf-8') as f:\n", | |
" test_dataset.append((f.read(), 0)) # 0 means `negative` sentiment" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"test_encoded = [vocab(tokenizer(text)) for text, label in test_dataset]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 36, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"test_labels = [label for text, label in test_dataset]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 37, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"predicted_test_labels = UDA_pytorch_classifier_predict(simple_lstm_model,\n", | |
" test_encoded,\n", | |
" rnn=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 38, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Test accuracy: 0.88348\n" | |
] | |
} | |
], | |
"source": [ | |
"print('Test accuracy:', UDA_compute_accuracy(predicted_test_labels, test_labels))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 39, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([1])" | |
] | |
}, | |
"execution_count": 39, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"UDA_pytorch_classifier_predict(simple_lstm_model,\n", | |
" [vocab(tokenizer('this movie rocks'))],\n", | |
" rnn=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 40, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([0])" | |
] | |
}, | |
"execution_count": 40, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"UDA_pytorch_classifier_predict(simple_lstm_model,\n", | |
" [vocab(tokenizer('this movie sucks'))],\n", | |
" rnn=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 41, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([0])" | |
] | |
}, | |
"execution_count": 41, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"UDA_pytorch_classifier_predict(simple_lstm_model,\n", | |
" [vocab(tokenizer('this sucks'))],\n", | |
" rnn=True)" | |
] | |
} | |
], | |
"metadata": { | |
"anaconda-cloud": {}, | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 1 | |
} |