ULMFiT for Airline Tweet Sentiment Analysis
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"# ULMFiT for Airline Tweet Sentiment Analysis"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook demonstrates how to train a ULMFiT sentiment classifier on the \"Twitter US Airline Sentiment\" dataset, available at https://www.kaggle.com/crowdflower/twitter-airline-sentiment#Tweets.csv"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Environment Setup "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
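"# Run these in a terminal before starting Jupyter (conda activate does not persist in a notebook shell)\n",
"# ! conda create -n fastai\n",
"# ! conda activate fastai\n",
"# ! conda install jupyter notebook\n",
"# ! conda install pytorch torchvision cudatoolkit=10.0 -c pytorch\n",
"# ! conda install -c pytorch -c fastai fastai=1.0\n",
"# ! conda install nltk\n",
"# ! conda install pandas\n",
"# import nltk; nltk.download('wordnet'); nltk.download('punkt')  # data for wordnet and nltk.word_tokenize"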
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"Author: Manoj Pravakar Saha\n",
"Email: hello@manojsaha.com\n",
"License: Apache License 2.0\n",
"\"\"\"\n",
"\n",
"import re\n",
"import os\n",
"from functools import partial\n",
"from collections import Counter\n",
"import string\n",
"import pandas as pd\n",
"import nltk\n",
"from nltk.corpus import wordnet\n",
"from fastai.text import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Pre-processing Data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this step, we'll pre-process the data before feeding it into the model. For brevity, I jump straight to pre-processing and skip Exploratory Data Analysis (EDA). For better model performance, we should perform EDA to understand the data before moving on to the model.\n",
"\n",
"For pre-processing I am using a subset of techniques discussed in the paper titled \"A Comparison of Pre-processing Techniques for Twitter Sentiment Analysis\". The code is available at https://github.com/Deffro/text-preprocessing-techniques. I am using the provided code with some minor modifications."
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>tweet_id</th>\n",
" <th>airline_sentiment</th>\n",
" <th>airline_sentiment_confidence</th>\n",
" <th>negativereason</th>\n",
" <th>negativereason_confidence</th>\n",
" <th>airline</th>\n",
" <th>airline_sentiment_gold</th>\n",
" <th>name</th>\n",
" <th>negativereason_gold</th>\n",
" <th>retweet_count</th>\n",
" <th>text</th>\n",
" <th>tweet_coord</th>\n",
" <th>tweet_created</th>\n",
" <th>tweet_location</th>\n",
" <th>user_timezone</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>570306133677760513</td>\n",
" <td>neutral</td>\n",
" <td>1.0000</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>Virgin America</td>\n",
" <td>NaN</td>\n",
" <td>cairdin</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>@VirginAmerica What @dhepburn said.</td>\n",
" <td>NaN</td>\n",
" <td>2015-02-24 11:35:52 -0800</td>\n",
" <td>NaN</td>\n",
" <td>Eastern Time (US &amp; Canada)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>570301130888122368</td>\n",
" <td>positive</td>\n",
" <td>0.3486</td>\n",
" <td>NaN</td>\n",
" <td>0.0000</td>\n",
" <td>Virgin America</td>\n",
" <td>NaN</td>\n",
" <td>jnardino</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>@VirginAmerica plus you've added commercials t...</td>\n",
" <td>NaN</td>\n",
" <td>2015-02-24 11:15:59 -0800</td>\n",
" <td>NaN</td>\n",
" <td>Pacific Time (US &amp; Canada)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>570301083672813571</td>\n",
" <td>neutral</td>\n",
" <td>0.6837</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>Virgin America</td>\n",
" <td>NaN</td>\n",
" <td>yvonnalynn</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>@VirginAmerica I didn't today... Must mean I n...</td>\n",
" <td>NaN</td>\n",
" <td>2015-02-24 11:15:48 -0800</td>\n",
" <td>Lets Play</td>\n",
" <td>Central Time (US &amp; Canada)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>570301031407624196</td>\n",
" <td>negative</td>\n",
" <td>1.0000</td>\n",
" <td>Bad Flight</td>\n",
" <td>0.7033</td>\n",
" <td>Virgin America</td>\n",
" <td>NaN</td>\n",
" <td>jnardino</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>@VirginAmerica it's really aggressive to blast...</td>\n",
" <td>NaN</td>\n",
" <td>2015-02-24 11:15:36 -0800</td>\n",
" <td>NaN</td>\n",
" <td>Pacific Time (US &amp; Canada)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>570300817074462722</td>\n",
" <td>negative</td>\n",
" <td>1.0000</td>\n",
" <td>Can't Tell</td>\n",
" <td>1.0000</td>\n",
" <td>Virgin America</td>\n",
" <td>NaN</td>\n",
" <td>jnardino</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>@VirginAmerica and it's a really big bad thing...</td>\n",
" <td>NaN</td>\n",
" <td>2015-02-24 11:14:45 -0800</td>\n",
" <td>NaN</td>\n",
" <td>Pacific Time (US &amp; Canada)</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" tweet_id airline_sentiment airline_sentiment_confidence \\\n",
"0 570306133677760513 neutral 1.0000 \n",
"1 570301130888122368 positive 0.3486 \n",
"2 570301083672813571 neutral 0.6837 \n",
"3 570301031407624196 negative 1.0000 \n",
"4 570300817074462722 negative 1.0000 \n",
"\n",
" negativereason negativereason_confidence airline \\\n",
"0 NaN NaN Virgin America \n",
"1 NaN 0.0000 Virgin America \n",
"2 NaN NaN Virgin America \n",
"3 Bad Flight 0.7033 Virgin America \n",
"4 Can't Tell 1.0000 Virgin America \n",
"\n",
" airline_sentiment_gold name negativereason_gold retweet_count \\\n",
"0 NaN cairdin NaN 0 \n",
"1 NaN jnardino NaN 0 \n",
"2 NaN yvonnalynn NaN 0 \n",
"3 NaN jnardino NaN 0 \n",
"4 NaN jnardino NaN 0 \n",
"\n",
" text tweet_coord \\\n",
"0 @VirginAmerica What @dhepburn said. NaN \n",
"1 @VirginAmerica plus you've added commercials t... NaN \n",
"2 @VirginAmerica I didn't today... Must mean I n... NaN \n",
"3 @VirginAmerica it's really aggressive to blast... NaN \n",
"4 @VirginAmerica and it's a really big bad thing... NaN \n",
"\n",
" tweet_created tweet_location user_timezone \n",
"0 2015-02-24 11:35:52 -0800 NaN Eastern Time (US & Canada) \n",
"1 2015-02-24 11:15:59 -0800 NaN Pacific Time (US & Canada) \n",
"2 2015-02-24 11:15:48 -0800 Lets Play Central Time (US & Canada) \n",
"3 2015-02-24 11:15:36 -0800 NaN Pacific Time (US & Canada) \n",
"4 2015-02-24 11:14:45 -0800 NaN Pacific Time (US & Canada) "
]
},
"execution_count": 97,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Import tweets from csv file and view the first few lines\n",
"df = pd.read_csv('Tweets.csv', sep=',')\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['tweet_id', 'airline_sentiment', 'airline_sentiment_confidence', 'negativereason', 'negativereason_confidence', 'airline', 'airline_sentiment_gold', 'name', 'negativereason_gold', 'retweet_count', 'text', 'tweet_coord', 'tweet_created', 'tweet_location', 'user_timezone']\n"
]
}
],
"source": [
"# get the feature names\n",
"features = df.columns.tolist()\n",
"print(features)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(14640, 15)"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Number of tweets and features\n",
"df.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We have 14,640 tweets in the dataset and 15 columns. Since I am using ULMFiT, I will only use the text (or rather its contextual embeddings) as input for fine-tuning the language model and the supervised text classifier.\n",
"\n",
"\n",
"### Pre-processing techniques\n",
"I am using the following tweet pre-processing techniques.\n",
"1. Remove unicode strings\n",
"2. Replace URLs with empty string\n",
"3. Replace user mentions with empty string\n",
"4. Remove the hash symbol (#) in front of hashtags\n",
"5. Replace slang and abbreviations\n",
"6. Replace contractions\n",
"7. Remove numbers\n",
"8. Remove punctuation marks and special characters\n",
"9. Remove emoticons\n",
"10. Lowercase text\n",
"11. Replace negations\n",
"\n",
"I tested spell correction but omitted it, since the implementation is not very efficient and takes too long. For the same reason, I have not applied stopword removal here."
]
},
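{
"cell_type": "markdown",
"metadata": {},
"source": [
"The steps listed above compose naturally as a chain of string-to-string functions. The sketch below illustrates this with Python's `re` module on a made-up tweet. The helper names and the simplified patterns are hypothetical and are not the functions defined in the next cell; they only show how a few of the steps (lowercasing, URL and mention removal, stripping `#`, removing digits and punctuation) fit together.\n",
"\n",
"```python\n",
"import re\n",
"\n",
"# Hypothetical, simplified versions of a few of the steps listed above\n",
"def strip_urls(text):\n",
"    return re.sub(r'(www[.][^ ]+)|(https?://[^ ]+)', '', text)\n",
"\n",
"def strip_mentions(text):\n",
"    return re.sub(r'@[^ ]+', '', text)\n",
"\n",
"def strip_hash_symbol(text):\n",
"    return text.replace('#', '')\n",
"\n",
"def strip_digits(text):\n",
"    return ''.join(ch for ch in text if not ch.isdigit())\n",
"\n",
"def strip_punctuation(text):\n",
"    # Replace anything other than ASCII letters, digits, underscore or space with a space\n",
"    return re.sub(r'[^A-Za-z0-9_ ]', ' ', text)\n",
"\n",
"def normalize_spaces(text):\n",
"    return ' '.join(text.split())\n",
"\n",
"# Applied in order, like the df.text.apply(...) chain used later\n",
"STEPS = [str.lower, strip_urls, strip_mentions, strip_hash_symbol,\n",
"         strip_digits, strip_punctuation, normalize_spaces]\n",
"\n",
"def preprocess(text):\n",
"    for step in STEPS:\n",
"        text = step(text)\n",
"    return text\n",
"\n",
"print(preprocess('@united Flight 1234 delayed AGAIN! #fail http://t.co/abc'))\n",
"# prints: flight delayed again fail\n",
"```\n",
"\n",
"Because `str.lower` runs first here, any later pattern that relies on upper-case letters would no longer match, so the order of the steps matters."
]
},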
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"# A subset of techniques for tweet pre-processing\n",
"# Originally published by Dimitrios Effrosynidis at \n",
"# https://github.com/Deffro/text-preprocessing-techniques\n",
"\n",
"def removeUnicode(text):\n",
"    \"\"\" Removes unicode strings like \"\\u002c\" and \"x96\" \"\"\"\n",
"    text = re.sub(r'(\\\\u[0-9A-Fa-f]+)', r'', text)\n",
"    text = re.sub(r'[^\\x00-\\x7f]', r'', text)\n",
"    return text\n",
"\n",
"def replaceURL(text):\n",
"    \"\"\" Removes URLs (www.* and http(s)://*) \"\"\"\n",
"    # text = re.sub(r'((www\\.[^\\s]+)|(https?://[^\\s]+))', 'url', text)\n",
"    text = re.sub(r'((www\\.[^\\s]+)|(https?://[^\\s]+))', '', text)\n",
"    return text\n",
"\n",
"def replaceAtUser(text):\n",
"    \"\"\" Removes user mentions such as \"@user\" \"\"\"\n",
"    # text = re.sub(r'@[^\\s]+', 'atUser', text)\n",
"    text = re.sub(r'@[^\\s]+', '', text)\n",
"    return text\n",
"\n",
"def removeHashtagInFrontOfWord(text):\n",
"    \"\"\" Removes the hashtag symbol in front of a word \"\"\"\n",
"    text = re.sub(r'#([^\\s]+)', r'\\1', text)\n",
"    return text\n",
"\n",
"def removeNumbers(text):\n",
"    \"\"\" Removes digits \"\"\"\n",
"    text = ''.join([i for i in text if not i.isdigit()])\n",
"    return text\n",
"\n",
"def removeEmoticons(text):\n",
"    \"\"\" Removes emoticons from text \"\"\"\n",
"    # ':-\\*' (kiss) is escaped; an unescaped ':-*' would match every colon\n",
"    text = re.sub(r':\\)|;\\)|:-\\)|\\(-:|:-D|=D|:P|xD|X-p|\\^\\^|:-\\*|\\^\\.\\^|\\^\\-\\^|\\^\\_\\^|\\,-\\)|\\)-:|:\\'\\(|:\\(|:-\\(|:\\S|T\\.T|\\.\\_\\.|:<|:-\\S|:-<|\\*\\-\\*|:O|=O|=\\-O|O\\.o|XO|O\\_O|:-\\@|=/|:/|X\\-\\(|>\\.<|>=\\(|D:', '', text)\n",
"    return text\n",
"\n",
"# Create a dictionary of slang terms and their equivalents, then replace them.\n",
"# slang.txt holds one tab-separated 'slang<TAB>replacement' pair per line.\n",
"with open('slang.txt', encoding='utf8', errors='ignore') as file:\n",
"    slang_map = dict(map(str.strip, line.partition('\\t')[::2])\n",
"                     for line in file if line.strip())\n",
"\n",
"slang_words = sorted(slang_map, key=len, reverse=True)  # longest first for regex\n",
"regex = re.compile(r\"\\b({})\\b\".format(\"|\".join(map(re.escape, slang_words))))\n",
"replaceSlang = partial(regex.sub, lambda m: slang_map[m.group(1)])\n",
"\n",
"def replaceElongated(word):\n",
"    \"\"\" Replaces an elongated word with its basic form, unless the word exists in the lexicon \"\"\"\n",
"    repeat_regexp = re.compile(r'(\\w*)(\\w)\\2(\\w*)')\n",
"    repl = r'\\1\\2\\3'\n",
"    if wordnet.synsets(word):\n",
"        return word\n",
"    repl_word = repeat_regexp.sub(repl, word)\n",
"    if repl_word != word:\n",
"        return replaceElongated(repl_word)\n",
"    else:\n",
"        return repl_word\n",
"\n",
"# Replaces contractions in a string with their expanded equivalents\n",
"contraction_patterns = [(r'won\\'t', 'will not'), (r'can\\'t', 'cannot'), (r'i\\'m', 'i am'), (r'ain\\'t', 'is not'), (r'(\\w+)\\'ll', r'\\g<1> will'), (r'(\\w+)n\\'t', r'\\g<1> not'),\n",
"                        (r'(\\w+)\\'ve', r'\\g<1> have'), (r'(\\w+)\\'s', r'\\g<1> is'), (r'(\\w+)\\'re', r'\\g<1> are'), (r'(\\w+)\\'d', r'\\g<1> would'), (r'&', 'and'), (r'dammit', 'damn it'), (r'dont', 'do not'), (r'wont', 'will not')]\n",
"# Compile the patterns once rather than on every call\n",
"compiled_contractions = [(re.compile(pattern), repl) for (pattern, repl) in contraction_patterns]\n",
"\n",
"def replaceContraction(text):\n",
"    \"\"\" Expands contractions using the patterns above, in order \"\"\"\n",
"    for (pattern, repl) in compiled_contractions:\n",
"        text = pattern.sub(repl, text)\n",
"    return text\n",
"\n",
"def lowercase(text):\n",
"    \"\"\" Make all characters lowercase \"\"\"\n",
"    return text.lower()"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"### Spell Correction begin ###\n",
"# Spell correction based on http://norvig.com/spell-correct.html\n",
"def words(text): return re.findall(r'\\w+', text.lower())\n",
"\n",
"with open('corporaForSpellCorrection.txt') as f:\n",
"    WORDS = Counter(words(f.read()))\n",
"\n",
"def P(word, N=sum(WORDS.values())):\n",
"    \"\"\" Probability of `word`. \"\"\"\n",
"    return WORDS[word] / N\n",
"\n",
"def spellCorrection(word):\n",
"    \"\"\" Most probable spelling correction for word. \"\"\"\n",
"    return max(candidates(word), key=P)\n",
"\n",
"def candidates(word):\n",
"    \"\"\" Generate possible spelling corrections for word. \"\"\"\n",
"    return (known([word]) or known(edits1(word)) or known(edits2(word)) or [word])\n",
"\n",
"def known(words):\n",
"    \"\"\" The subset of `words` that appear in the dictionary of WORDS. \"\"\"\n",
"    return set(w for w in words if w in WORDS)\n",
"\n",
"def edits1(word):\n",
"    \"\"\" All edits that are one edit away from `word`. \"\"\"\n",
"    letters = 'abcdefghijklmnopqrstuvwxyz'\n",
"    splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]\n",
"    deletes = [L + R[1:] for L, R in splits if R]\n",
"    transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R) > 1]\n",
"    replaces = [L + c + R[1:] for L, R in splits if R for c in letters]\n",
"    inserts = [L + c + R for L, R in splits for c in letters]\n",
"    return set(deletes + transposes + replaces + inserts)\n",
"\n",
"def edits2(word):\n",
"    \"\"\" All edits that are two edits away from `word`. \"\"\"\n",
"    return (e2 for e1 in edits1(word) for e2 in edits1(e1))\n",
"\n",
"### Spell Correction End ###"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"### Replace Negations Begin ###\n",
"\n",
"def replace(word, pos=None):\n",
"    \"\"\" Creates a set of all antonyms for the word and if there is only one antonym, it returns it \"\"\"\n",
"    antonyms = set()\n",
"    for syn in wordnet.synsets(word, pos=pos):\n",
"        for lemma in syn.lemmas():\n",
"            for antonym in lemma.antonyms():\n",
"                antonyms.add(antonym.name())\n",
"    if len(antonyms) == 1:\n",
"        return antonyms.pop()\n",
"    else:\n",
"        return None\n",
"\n",
"def replaceNegations(text):\n",
"    \"\"\" Given a list of tokens, replaces 'not' and the next word with that word's antonym when replace() finds exactly one \"\"\"\n",
"    i, l = 0, len(text)\n",
"    words = []\n",
"    while i < l:\n",
"        word = text[i]\n",
"        if word == 'not' and i+1 < l:\n",
"            ant = replace(text[i+1])\n",
"            if ant:\n",
"                words.append(ant)\n",
"                i += 2\n",
"                continue\n",
"        words.append(word)\n",
"        i += 1\n",
"    return words\n",
"\n",
"### Replace Negations End ###"
]
},
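{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what `replaceNegations` does without downloading WordNet, here is a minimal sketch that swaps the WordNet lookup for a tiny hand-written antonym table. The `ANTONYMS` dictionary and the function name `replace_negations_sketch` are hypothetical; the function above queries `wordnet.synsets` and only substitutes when exactly one antonym is found.\n",
"\n",
"```python\n",
"# Hypothetical antonym table standing in for WordNet lookups\n",
"ANTONYMS = {'good': 'bad', 'happy': 'unhappy', 'late': 'early'}\n",
"\n",
"def replace_negations_sketch(tokens, antonyms=ANTONYMS):\n",
"    out = []\n",
"    i = 0\n",
"    while i < len(tokens):\n",
"        word = tokens[i]\n",
"        # 'not' followed by a word with a known antonym: emit the antonym, skip both tokens\n",
"        if word == 'not' and i + 1 < len(tokens) and tokens[i + 1] in antonyms:\n",
"            out.append(antonyms[tokens[i + 1]])\n",
"            i += 2\n",
"            continue\n",
"        out.append(word)\n",
"        i += 1\n",
"    return out\n",
"\n",
"print(replace_negations_sketch('the flight was not late but not fun'.split()))\n",
"# prints: ['the', 'flight', 'was', 'early', 'but', 'not', 'fun']\n",
"```\n",
"\n",
"Since the contraction step runs earlier in the pipeline, forms such as `didn't` have already become `did not` by the time this step sees them."
]
},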
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Some more methods for pre-processing\n",
"# Author: Manoj Pravakar Saha\n",
"\n",
"def removeSpecialCharacters(text):\n",
"    \"\"\" Replaces punctuation and other non-word characters with spaces \"\"\"\n",
"    # translator = str.maketrans('', '', string.punctuation)\n",
"    # return text.translate(translator)\n",
"    return re.sub(r'[^\\w\\s]', ' ', text)\n",
"\n",
"def replaceNegationsText(text):\n",
"    \"\"\" Replace negations in the entire text string (not a single token) \"\"\"\n",
"    tokens = nltk.word_tokenize(text)\n",
"    tokens = replaceNegations(tokens)  # finds \"not\" and an antonym for the next word;\n",
"                                       # if found, replaces \"not\" and the next word\n",
"                                       # with the antonym\n",
"    onlyOneSentence = \" \".join(tokens)  # rebuild the sentence from the list of tokens\n",
"    return onlyOneSentence\n",
"\n",
"def spellCorrectionText(text):\n",
"    \"\"\" Correct misspelled words in entire text \"\"\"\n",
"    onlyOneSentenceTokens = []  # corrected tokens of the sentence\n",
"    tokens = nltk.word_tokenize(text)\n",
"    for token in tokens:\n",
"        final_word = spellCorrection(token)\n",
"        onlyOneSentenceTokens.append(final_word)\n",
"    return \" \".join(onlyOneSentenceTokens)"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [],
"source": [
"# Pre-processing techniques applied sequentially\n",
"df.text = df.text.apply(removeUnicode) # Remove Unicode characters\n",
"df.text = df.text.apply(lowercase) # Lowercase the text\n",
"df.text = df.text.apply(replaceURL) # Replace URLs with empty string\n",
"df.text = df.text.apply(replaceAtUser) # Replace @user with empty string\n",
"df.text = df.text.apply(removeHashtagInFrontOfWord) # Remove the '#' in front of words\n",
"df.text = df.text.apply(replaceSlang) # Replace slang and abbreviations\n",
"df.text = df.text.apply(replaceContraction) # Replace contractions with equivalent words\n",
"df.text = df.text.apply(removeNumbers) # Remove numbers from text\n",
"df.text = df.text.apply(removeEmoticons) # Remove emoticons from text\n",
"df.text = df.text.apply(removeSpecialCharacters) # Remove special characters\n",
"df.text = df.text.apply(replaceNegationsText) # Replace negations with antonyms\n"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>labels</th>\n",
" <th>text</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>neutral</td>\n",
" <td>what said</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>positive</td>\n",
" <td>plus you have added commercials to the experie...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>neutral</td>\n",
" <td>i did not today must mean i need to take anoth...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>negative</td>\n",
" <td>it is really aggressive to blast obnoxious ent...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>negative</td>\n",
" <td>and it is a really big bad thing about it</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" labels text\n",
"0 neutral what said\n",
"1 positive plus you have added commercials to the experie...\n",
"2 neutral i did not today must mean i need to take anoth...\n",
"3 negative it is really aggressive to blast obnoxious ent...\n",
"4 negative and it is a really big bad thing about it"
]
},
"execution_count": 99,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create new dataframe with text and labels\n",
"df = df[['airline_sentiment', 'text']]\n",
"df = df.rename(columns={'airline_sentiment':'labels'})\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"# Change labels into integers \n",
"# df.loc[df['labels'] == 'positive', 'labels'] = 0\n",
"# df.loc[df['labels'] == 'neutral', 'labels'] = 1\n",
"# df.loc[df['labels'] == 'negative', 'labels'] = 2"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [],
"source": [
"# Divide data into training and test sets\n",
"test_df = df.sample(frac=0.2) # Randomly select 20% as test set\n",
"train_df = df.drop(test_df.index) # Keep the rest as training set"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trainset-sample size: 11712 \n",
"Testset-sample size: 2928\n"
]
}
],
"source": [
"# Print the number of samples in each set\n",
"print('Trainset-sample size: {} \\nTestset-sample size: {}'.\\\n",
" format(train_df.shape[0], test_df.shape[0]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have split the data into training and test sets in an 80/20 ratio, we should verify that both sets have a similar distribution of sentiments."
]
},
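{
"cell_type": "markdown",
"metadata": {},
"source": [
"A plain random 20% sample only approximately preserves the label proportions, and without a seed it changes on every run. A stratified split with a fixed seed makes the proportions match by construction. The sketch below assumes pandas 1.1 or later (for `DataFrameGroupBy.sample`); the helper name `stratified_split`, the toy dataframe and the seed are illustrative, and the notebook itself uses the unstratified `df.sample(frac=0.2)` above.\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"def stratified_split(df, label_col='labels', frac=0.2, seed=42):\n",
"    # Sample the same fraction from every label group, so class ratios match\n",
"    test = df.groupby(label_col, group_keys=False).sample(frac=frac, random_state=seed)\n",
"    train = df.drop(test.index)\n",
"    return train, test\n",
"\n",
"# Toy data: 60 negative, 25 neutral, 15 positive rows (placeholder text)\n",
"toy = pd.DataFrame({\n",
"    'labels': ['negative'] * 60 + ['neutral'] * 25 + ['positive'] * 15,\n",
"    'text': ['tweet'] * 100,\n",
"})\n",
"train, test = stratified_split(toy)\n",
"print(test['labels'].value_counts())\n",
"print(train['labels'].value_counts(normalize=True).round(2))\n",
"```\n",
"\n",
"With 60/25/15 rows per class, the test set gets 12/5/3 rows, so both sets keep the same 60%/25%/15% split."
]
},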
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Train Train_% Test Test_%\n",
"negative 7333 62.61 1845 63.01\n",
"neutral 2497 21.32 602 20.56\n",
"positive 1882 16.07 481 16.43\n"
]
}
],
"source": [
"def column_value_counts(df, target_column, new_column):\n",
"    '''\n",
"    Get value counts of each categorical variable. Store this data in \n",
"    a dataframe. Also add a column with relative percentage of each \n",
"    categorical variable.\n",
"    \n",
"    :param df: A Pandas dataframe\n",
"    :param target_column: Name of the column in the original dataframe (string)\n",
"    :param new_column: Name of the new column where the frequency counts are stored \n",
"    :type df: pandas.core.frame.DataFrame\n",
"    :type target_column: str\n",
"    :type new_column: str\n",
"    :return: A Pandas dataframe containing the frequency counts\n",
"    :rtype: pandas.core.frame.DataFrame\n",
"    '''\n",
"    df_value_counts = df[target_column].value_counts()\n",
"    df = pd.DataFrame(df_value_counts)\n",
"    df.columns = [new_column]\n",
"    df[new_column+'_%'] = 100*df[new_column] / df[new_column].sum()\n",
"    return df\n",
"\n",
"# Get frequency distribution of labels in each set\n",
"df_train = column_value_counts(train_df, 'labels', 'Train')\n",
"df_test = column_value_counts(test_df, 'labels', 'Test')\n",
"\n",
"label_count = pd.concat([df_train, df_test], axis=1)  # Merge dataframes by index\n",
"label_count = label_count.fillna(0)  # Replace NaN with 0 (zero)\n",
"label_count = label_count.round(2)  # Round to two decimal places\n",
"print(label_count.sort_values(by=['Train'], ascending=False))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The table above shows that the distribution of negative, neutral and positive tweets is similar in both sets, so we can now save them for later use."
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"# Save training and test sets into CSV files\n",
"train_df.to_csv('train.csv', header=False, index=False, encoding='utf-8')\n",
"test_df.to_csv('test.csv', header=False, index=False, encoding='utf-8')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Language Model and Supervised Classifier"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To complete this part, I drew on three sources:\n",
"- https://docs.fast.ai/text.html#Fine-tuning-a-language-model\n",
"- https://www.analyticsvidhya.com/blog/2018/11/tutorial-text-classification-ulmfit-fastai-library/\n",
"- https://github.com/estorrs/twitter-celebrity-tweet-sentiment/blob/master/celebrity-twitter-sentiment.ipynb\n",
"\n",
"I have used fastai version 1.0 for this demo. The last example above is based on version 0.7. However, it helped me understand some of the issues related to retraining the ULMFiT model for a new dataset."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before we can retrain and fine-tune the language model on our dataset, we need to download the ULMFiT weights pretrained on WikiText-103, a corpus of Wikipedia articles, using the following command."
]
},
{
"cell_type": "code",
"execution_count": 114,
"metadata": {},
"outputs": [],
"source": [
"# ! wget -nH -r -np -P models http://files.fast.ai/models/wt103/"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We also need the AWD-LSTM weights (`lstm_wt103.pth`) and the matching vocabulary (`itos_wt103.pkl`) pretrained on the same corpus."
]
},
{
"cell_type": "code",
"execution_count": 115,
"metadata": {},
"outputs": [],
"source": [
"# fastai v1 looks for pretrained_fnames in <path>/models/\n",
"# ! wget -P models http://files.fast.ai/models/wt103_v1/lstm_wt103.pth\n",
"# ! wget -P models http://files.fast.ai/models/wt103_v1/itos_wt103.pkl"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have downloaded the pretrained models, we can reload the training and test sets."
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"# Load training set in Pandas dataframe (saved above as UTF-8)\n",
"train_df = pd.read_csv('train.csv', header=None, encoding='utf-8')\n",
"\n",
"# Load test set in Pandas dataframe; it serves as the validation set below\n",
"val_df = pd.read_csv('test.csv', header=None, encoding='utf-8')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have both the pretrained model and the datasets, we can prepare the data for the language model and the classifier. We need two different data objects, one for each model. The DataBunch classes in fastai v1 make this easy: basic pre-processing such as tokenization and numericalization is handled internally."
]
},
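{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, the DataBunch classes tokenize each text, build a vocabulary and map tokens to integer ids. The sketch below is a simplified, hypothetical illustration of that idea in plain Python, not fastai's implementation: it splits on whitespace, keeps tokens seen at least `min_freq` times, and maps everything else to an unknown token. The special token names `xxunk` and `xxpad` follow fastai's naming, while fastai's real tokenizer adds further rules and markers; the function names are made up for this example.\n",
"\n",
"```python\n",
"from collections import Counter\n",
"\n",
"UNK, PAD = 'xxunk', 'xxpad'\n",
"\n",
"def build_vocab(texts, min_freq=2, max_vocab=60000):\n",
"    counts = Counter(tok for t in texts for tok in t.split())\n",
"    # Keep frequent tokens, most common first, after the special tokens\n",
"    kept = [tok for tok, c in counts.most_common(max_vocab) if c >= min_freq]\n",
"    itos = [UNK, PAD] + kept                         # index -> string\n",
"    stoi = {tok: i for i, tok in enumerate(itos)}    # string -> index\n",
"    return itos, stoi\n",
"\n",
"def numericalize(text, stoi):\n",
"    # Tokens missing from the vocabulary map to the id of UNK (0)\n",
"    return [stoi.get(tok, stoi[UNK]) for tok in text.split()]\n",
"\n",
"texts = ['flight delayed again', 'flight was great', 'delayed bags']\n",
"itos, stoi = build_vocab(texts)\n",
"print(itos)                                      # ['xxunk', 'xxpad', 'flight', 'delayed']\n",
"print(numericalize('flight lost bags', stoi))    # [2, 0, 0]\n",
"```\n",
"\n",
"Sharing `vocab=data_lm.train_ds.vocab` with the classifier data matters for the same reason: the token ids must mean the same thing in both models."
]
},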
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
"# Prepare data for language model\n",
"data_lm = TextLMDataBunch.from_df(path=\"\", train_df=train_df, valid_df=val_df)\n",
"\n",
"# Prepare data for classifier model, reusing the language model's vocabulary\n",
"# I am using a batch size of 16\n",
"data_clas = TextClasDataBunch.from_df(path=\"\", train_df=train_df, valid_df=val_df,\n",
"                                      vocab=data_lm.train_ds.vocab, bs=16)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Language model\n",
"With the data ready, we can now re-train and fine-tune the language model. In fastai v1, `language_model_learner` with `arch=AWD_LSTM` starts from pretrained WikiText-103 weights by default; here I point it at the downloaded files through `pretrained_fnames`. Starting from a pretrained language model rather than random weights is the core idea of ULMFiT, and it is likely the main reason the approach works well on a small dataset like this one."
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
"# Initialize the learner object with the AWD_LSTM model\n",
"# drop_mult=0.5 scales all of AWD_LSTM's default dropout probabilities by 0.5\n",
"learn = language_model_learner(data_lm, arch=AWD_LSTM, \n",
" pretrained_fnames=['lstm_wt103', 'itos_wt103'], drop_mult=0.5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Fast.ai provides two methods to train the model: `fit()` and `fit_one_cycle()`. I tested both and use `fit_one_cycle()` for re-training and fine-tuning. For background on the one-cycle learning-rate policy, see Leslie Smith's paper: https://arxiv.org/abs/1803.09820"
]
},
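{
"cell_type": "markdown",
"metadata": {},
"source": [
"The one-cycle policy varies the learning rate within a single training run. As a rough sketch, the hypothetical function below warms the rate up from `max_lr / div_factor` to `max_lr` with cosine interpolation over the first `pct_start` fraction of steps, then anneals it down to `max_lr / final_div`. The defaults (`div_factor=25`, `pct_start=0.3`, `final_div = div_factor * 1e4`) are my assumption of fastai v1's behaviour, so check `help(learn.fit_one_cycle)` for your installed version; fastai also cycles momentum in the opposite direction, which this sketch omits.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def cos_interp(start, end, pct):\n",
"    # Cosine interpolation from start (pct=0) to end (pct=1)\n",
"    return end + (start - end) / 2 * (1 + math.cos(math.pi * pct))\n",
"\n",
"def one_cycle_lrs(n_steps, max_lr, div_factor=25.0, pct_start=0.3, final_div=None):\n",
"    if final_div is None:\n",
"        final_div = div_factor * 1e4\n",
"    warmup = max(1, round(n_steps * pct_start))\n",
"    lrs = []\n",
"    for step in range(n_steps):\n",
"        if step < warmup:\n",
"            # Phase 1: rise from max_lr / div_factor towards max_lr\n",
"            lrs.append(cos_interp(max_lr / div_factor, max_lr, step / warmup))\n",
"        else:\n",
"            # Phase 2: anneal from max_lr down to max_lr / final_div\n",
"            pct = (step - warmup) / max(1, n_steps - warmup - 1)\n",
"            lrs.append(cos_interp(max_lr, max_lr / final_div, pct))\n",
"    return lrs\n",
"\n",
"lrs = one_cycle_lrs(100, 1e-2)\n",
"print(f'start={lrs[0]:.1e} peak={max(lrs):.1e} end={lrs[-1]:.1e}')\n",
"```\n",
"\n",
"The short warm-up avoids large updates while the new data is unfamiliar, and the long annealing phase lets training settle at very small learning rates."
]
},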
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:08 <p><table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>5.448481</td>\n",
" <td>4.554461</td>\n",
" <td>0.189872</td>\n",
" <td>00:08</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# train for one epoch with a maximum learning rate of 1e-2; the pretrained\n",
"# learner starts frozen, so only the last layer group is updated here\n",
"learn.fit_one_cycle(1, 1e-2)\n",
"#learn.fit(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start fine-tuning now. I'll use gradual unfreezing: train the last layer group first, then unfreeze earlier groups one at a time, and finally fine-tune all layers."
]
},
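{
"cell_type": "markdown",
"metadata": {},
"source": [
"`freeze_to(n)` works on the model's layer groups: in fastai v1, groups before index `n` are frozen and groups from `n` onwards are trainable, so `freeze_to(-1)` trains only the last group and `freeze_to(-2)` the last two, while `unfreeze()` amounts to `freeze_to(0)`. The pure-Python sketch below mimics that slicing on a hypothetical list of four group names; it is a simplified view that does not touch any real parameters.\n",
"\n",
"```python\n",
"def freeze_to(groups, n):\n",
"    # groups[:n] frozen, groups[n:] trainable; negative n counts from the end\n",
"    trainable = set(range(len(groups))[n:])\n",
"    return [(name, i in trainable) for i, name in enumerate(groups)]\n",
"\n",
"groups = ['group_0', 'group_1', 'group_2', 'group_3']  # hypothetical names\n",
"for n in (-1, -2, -3, 0):\n",
"    flags = freeze_to(groups, n)\n",
"    print(n, [name for name, trainable in flags if trainable])\n",
"```\n",
"\n",
"Unfreezing from the top down lets the task-specific layers adapt first, before the more general lower layers are disturbed."
]
},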
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:08 <p><table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>4.660953</td>\n",
" <td>4.218702</td>\n",
" <td>0.224540</td>\n",
" <td>00:08</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# train only the last layer group (the same as the frozen state)\n",
"learn.freeze_to(-1)\n",
"learn.fit_one_cycle(1, 1e-2)"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:08 <p><table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>4.254465</td>\n",
" <td>3.933921</td>\n",
" <td>0.262137</td>\n",
" <td>00:08</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# unfreeze one more layer group: the last two groups are now trainable\n",
"learn.freeze_to(-2)\n",
"learn.fit_one_cycle(1, 1e-2)"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:10 <p><table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>3.981843</td>\n",
" <td>3.774588</td>\n",
" <td>0.281864</td>\n",
" <td>00:10</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# unfreeze one more layer group: the last three groups are now trainable\n",
"learn.freeze_to(-3)\n",
"learn.fit_one_cycle(1, 1e-2)"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:11 <p><table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>3.805899</td>\n",
" <td>3.721997</td>\n",
" <td>0.287291</td>\n",
" <td>00:11</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# unfreeze all layers\n",
"learn.unfreeze()\n",
"learn.fit_one_cycle(1, 1e-2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We are done with the fine-tuning for now. We can now save the model for future use."
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
"# Save the language model\n",
"learn.save_encoder('tweet_lm')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Classifier model\n",
"We have fine-tuned the language model. Now we can use it to build our sentiment classifier. I am using an AWD_LSTM-based classifier, matching the language model. A classifier with a different architecture would need a language model fine-tuned with that same architecture, because `load_encoder` expects encoder weights of the same shape."
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [],
"source": [
"# Initialize classifier model using the fine-tuned language model\n",
"# I am using the AWD_LSTM model; drop_mult=0.5 scales its default dropout probabilities by 0.5\n",
"learn_c = text_classifier_learner(data_clas, arch=AWD_LSTM, drop_mult=0.5)\n",
"learn_c.load_encoder('tweet_lm')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now repeat a similar process of training and gradual unfreezing for the classifier, as we did for the language model."
]
},
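{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `slice(start, stop)` passed to `fit_one_cycle` requests discriminative learning rates: each layer group gets its own rate, with the smallest for the earliest groups. My understanding is that fastai v1 spreads the rates geometrically from `start` to `stop` across the layer groups; the sketch below reproduces that idea with a hypothetical helper, and the group count of 5 is an assumption rather than something read from `learn_c`.\n",
"\n",
"```python\n",
"def discriminative_lrs(start, stop, n_groups):\n",
"    # Geometric progression: each group's rate is a constant multiple of the previous one\n",
"    if n_groups == 1:\n",
"        return [stop]\n",
"    step = (stop / start) ** (1 / (n_groups - 1))\n",
"    return [start * step ** i for i in range(n_groups)]\n",
"\n",
"# Rates for slice(2e-3/100, 2e-3), as used in the final unfreeze() cell\n",
"for i, lr in enumerate(discriminative_lrs(2e-3 / 100, 2e-3, n_groups=5)):\n",
"    print(f'group {i}: {lr:.2e}')\n",
"```\n",
"\n",
"Lower rates for the early groups keep the general language features learned on WikiText-103 largely intact, while the classifier head adapts faster."
]
},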
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:16 <p><table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.640346</td>\n",
" <td>0.539703</td>\n",
" <td>0.776639</td>\n",
" <td>00:16</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Train only the classifier head; the loaded encoder starts out frozen\n",
"learn_c.fit_one_cycle(1, 1e-2)"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:16 <p><table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.649272</td>\n",
" <td>0.545790</td>\n",
" <td>0.771516</td>\n",
" <td>00:16</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
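],
"source": [
"# freeze_to(-1) is equivalent to freeze(): only the last layer group (the\n",
"# classifier head) is trainable, so only the top rate of the slice (5e-3) has an effect\n",
"learn_c.freeze_to(-1)\n",
"learn_c.fit_one_cycle(1, slice(5e-3/2., 5e-3))"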
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:32 <p><table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.583233</td>\n",
" <td>0.498302</td>\n",
" <td>0.803962</td>\n",
" <td>00:32</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
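],
"source": [
"# Unfreeze all layer groups and train with discriminative learning rates,\n",
"# from 2e-3/100 for the lowest group up to 2e-3 for the classifier head\n",
"learn_c.unfreeze()\n",
"learn_c.fit_one_cycle(1, slice(2e-3/100, 2e-3))"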
]
},
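{
"cell_type": "markdown",
"metadata": {},
"source": [
"Passing `slice(lo, hi)` as the learning rate assigns each layer group its own rate (discriminative fine-tuning). fastai v1 is assumed here to space these rates geometrically, from `lo` for the first group to `hi` for the head. The sketch below reproduces that spacing for the `slice(2e-3/100, 2e-3)` call above, assuming five layer groups; `geometric_lrs` is a hypothetical helper, not a fastai function. In the earlier `freeze_to(-1)` cell only the head is trainable, so only the last rate of its slice matters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def geometric_lrs(lo, hi, n_groups):\n",
"    # Spread learning rates geometrically from lo (first group) to hi (last group)\n",
"    if n_groups == 1:\n",
"        return [hi]\n",
"    step = (hi / lo) ** (1 / (n_groups - 1))\n",
"    return [lo * step ** i for i in range(n_groups)]\n",
"\n",
"# slice(2e-3/100, 2e-3) spread over an assumed five layer groups\n",
"lrs = geometric_lrs(2e-3 / 100, 2e-3, n_groups=5)\n",
"for i, lr in enumerate(lrs):\n",
"    print(f'group {i}: lr = {lr:.2e}')"
]
},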
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our classifier is now trained and can be used to predict sentiment classes. Use the predict() method for a single tweet and the get_preds() method for batch prediction over a whole dataset."
]
},
{
"cell_type": "code",
"execution_count": 110,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Category negative, tensor(0), tensor([0.9728, 0.0239, 0.0033]))"
]
},
"execution_count": 110,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Example of prediction on a single tweet\n",
"learn_c.predict('your ticket prices are bad')"
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[tensor([[0.9796, 0.0153, 0.0051],\n",
" [0.8600, 0.1064, 0.0336],\n",
" [0.7035, 0.2479, 0.0486],\n",
" ...,\n",
" [0.0692, 0.1977, 0.7331],\n",
" [0.3106, 0.5750, 0.1144],\n",
" [0.1814, 0.6975, 0.1211]]), tensor([0, 0, 1, ..., 2, 1, 1])]"
]
},
"execution_count": 105,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Example of batch prediction on the validation set\n",
"# Returns [class probabilities, true labels]; take the argmax of each\n",
"# probability row to get the predicted class index\n",
"learn_c.get_preds(ds_type=DatasetType.Valid)"
]
},
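{
"cell_type": "markdown",
"metadata": {},
"source": [
"`get_preds()` returns two tensors: the class probabilities and the true labels. The sketch below turns probabilities into predicted labels with an argmax and scores them, using the six probability rows and six labels visible in the output above as sample data. The class order `['negative', 'neutral', 'positive']` is an assumption (index 0 being negative matches the predict() output; check `data_clas.classes`). With the real tensors, the equivalent step is `preds.argmax(dim=1)` on the probabilities returned by `learn_c.get_preds(ds_type=DatasetType.Valid)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sample data: the six probability rows and labels visible in the get_preds() output above\n",
"probs = [\n",
"    [0.9796, 0.0153, 0.0051],\n",
"    [0.8600, 0.1064, 0.0336],\n",
"    [0.7035, 0.2479, 0.0486],\n",
"    [0.0692, 0.1977, 0.7331],\n",
"    [0.3106, 0.5750, 0.1144],\n",
"    [0.1814, 0.6975, 0.1211],\n",
"]\n",
"targets = [0, 0, 1, 2, 1, 1]\n",
"classes = ['negative', 'neutral', 'positive']  # assumed order; check data_clas.classes\n",
"\n",
"# Predicted class = index of the highest probability in each row (argmax)\n",
"pred_idx = [max(range(len(row)), key=row.__getitem__) for row in probs]\n",
"pred_labels = [classes[i] for i in pred_idx]\n",
"sample_accuracy = sum(p == t for p, t in zip(pred_idx, targets)) / len(targets)\n",
"\n",
"print(pred_labels)\n",
"print(f'accuracy on these 6 rows: {sample_accuracy:.3f}')"
]
},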
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Remarks"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The classifier model reached about 80% validation accuracy (0.804 after the final epoch). This result might be improved by exploring the following:\n",
"- Test how different pre-processing techniques affect accuracy.\n",
"- Gradual unfreezing improved accuracy noticeably in this run: validation accuracy rose from about 77% to about 80% after all layers were unfrozen. This should be explored further.\n",
"- I have not systematically tuned two other prominent features of ULMFiT: discriminative fine-tuning and slanted triangular learning rates (the notebook only uses fastai's slice() learning-rate ranges and fit_one_cycle() schedule). I believe both the language and classifier models could be improved considerably by experimenting with these two."
]
},
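{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a starting point for experimenting with slanted triangular learning rates (STLR), the sketch below implements the schedule as described in the ULMFiT paper: the rate rises linearly for the first `cut_frac` of the `T` training iterations, then decays linearly, and `ratio` sets how much smaller the lowest rate is than `lr_max`. `cut_frac=0.1` and `ratio=32` follow the paper; `lr_max=0.01` and `T=1000` are hypothetical values, and `stlr` is an illustrative helper, not a fastai API."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def stlr(t, T, lr_max=0.01, cut_frac=0.1, ratio=32):\n",
"    # Slanted triangular learning rate at iteration t out of T iterations.\n",
"    # Assumes T * cut_frac >= 1 so that cut is non-zero.\n",
"    cut = math.floor(T * cut_frac)\n",
"    if t < cut:\n",
"        p = t / cut  # linear warm-up phase\n",
"    else:\n",
"        p = 1 - (t - cut) / (cut * (1 / cut_frac - 1))  # linear decay phase\n",
"    return lr_max * (1 + p * (ratio - 1)) / ratio\n",
"\n",
"T = 1000  # hypothetical total number of training iterations\n",
"for t in (0, 50, 100, 550, 1000):\n",
"    print(f't={t:4d}: lr={stlr(t, T):.5f}')"
]
},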
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### END"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}