@ThomasDelteil
Created October 23, 2019 18:21
FineTuningBert-sentiment-analysis
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fine-tuning BERT for sentiment classification of movie reviews\n",
"\n",
"![](support/single-classification-bert.png)\n",
"\n",
"![](support/godfather.png)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import gluonnlp as nlp\n",
"import mxnet as mx\n",
"from mxnet import gluon, nd\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We are going to use the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/) of movie reviews (25,000 for training, 25,000 for testing) and try to predict whether a review is positive or negative."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def transform_label(data):\n",
"    \"\"\"\n",
"    Transform the 1-10 rating into a binary label: 1 (positive) or 0 (negative)\n",
"    \"\"\"\n",
"    text, label = data\n",
"    return text, 1 if label >= 5 else 0"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"train_dataset = nlp.data.IMDB('train')\n",
"test_dataset = nlp.data.IMDB('test')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"k = {i+1:0 for i in range(10)}\n",
"for elem in train_dataset:\n",
" k[elem[1]] += 1"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Distribution of the ratings\n"
]
},
{
"data": {
"text/plain": [
"{1: 5100,\n",
" 2: 2284,\n",
" 3: 2420,\n",
" 4: 2696,\n",
" 5: 0,\n",
" 6: 0,\n",
" 7: 2496,\n",
" 8: 3009,\n",
" 9: 2263,\n",
" 10: 4732}"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(\"Distribution of the ratings\")\n",
"k"
]
},
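{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check on the label split. It uses the rating counts printed above, copied by hand into the cell below so that the cell runs on its own. Ratings 5 and 6 never occur, so the `label >= 5` threshold in `transform_label` sends ratings 1-4 to class 0 and ratings 7-10 to class 1. The two classes come out exactly balanced. As a result, always predicting one class, or guessing at random, scores about 50%."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rating counts copied from the distribution printed above\n",
"rating_counts = {1: 5100, 2: 2284, 3: 2420, 4: 2696, 5: 0, 6: 0,\n",
"                 7: 2496, 8: 3009, 9: 2263, 10: 4732}\n",
"\n",
"# Same threshold as transform_label: ratings >= 5 are positive (1), the rest negative (0)\n",
"negative = sum(count for rating, count in rating_counts.items() if rating < 5)\n",
"positive = sum(count for rating, count in rating_counts.items() if rating >= 5)\n",
"print(negative, positive)  # 12500 12500"
]
},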
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Positive Review:\n",
"I went and saw this movie last night after being coaxed to by a few friends of mine. I'll admit that I was reluctant to see it because from what I knew of Ashton Kutcher he was only able to do comedy. I was wrong. Kutcher played the character of Jake Fischer very well, and Kevin Costner played Ben Randall with such professionalism. The sign of a good movie is that it can toy with our emotions. This one did exactly that. The entire theater (which was sold out) was overcome by laughter during the first half of the movie, and were moved to tears during the second half. While exiting the theater I not only saw many women in tears, but many full grown men as well, trying desperately not to let anyone see them crying. This movie was great, and I suggest that you go see it before you judge.\n",
"\n",
"Negative Review:\n",
"This is an example of why the majority of action films are the same. Generic and boring, there's really nothing worth watching here. A complete waste of the then barely-tapped talents of Ice-T and Ice Cube, who've each proven many times over that they are capable of acting, and acting well. Don't bother with this one, go see New Jack City, Ricochet or watch New York Undercover for Ice-T, or Boyz n the Hood, Higher Learning or Friday for Ice Cube and see the real deal. Ice-T's horribly cliched dialogue alone makes this film grate at the teeth, and I'm still wondering what the heck Bill Paxton was doing in this film? And why the heck does he always play the exact same character? From Aliens onward, every film I've seen with Bill Paxton has him playing the exact same irritating character, and at least in Aliens his character died, which made it somewhat gratifying...<br /><br />Overall, this is second-rate action trash. There are countless better films to see, and if you really want to see this one, watch Judgement Night, which is practically a carbon copy but has better acting and a better script. The only thing that made this at all worth watching was a decent hand on the camera - the cinematography was almost refreshing, which comes close to making up for the horrible film itself - but not quite. 4/10.\n"
]
}
],
"source": [
"print(\"Positive Review:\\n{}\".format(test_dataset[0][0]))\n",
"print()\n",
"print(\"Negative Review:\\n{}\".format(test_dataset[12501][0]))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"train_dataset = train_dataset.transform(transform_label)\n",
"test_dataset = test_dataset.transform(transform_label)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"There are 25000 training examples and 25000 test examples\n"
]
}
],
"source": [
"print(\"There are {} training examples and {} test examples\".format(len(train_dataset), len(test_dataset)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Sklearn TF-IDF baseline\n",
"\n",
"Let's use sklearn to build a TF-IDF + logistic regression pipeline with word n-grams (n = 1, 2, 3) as a baseline."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1-gram Accuracy train:0.9356, test:0.88468\n",
"2-gram Accuracy train:0.9498, test:0.88632\n",
"3-gram Accuracy train:0.94928, test:0.88728\n"
]
}
],
"source": [
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.pipeline import Pipeline\n",
"\n",
"# The raw texts and binary labels do not depend on n, so extract them once\n",
"train_x = [elem[0] for elem in train_dataset]\n",
"test_x = [elem[0] for elem in test_dataset]\n",
"train_y = np.array([elem[1] for elem in train_dataset])\n",
"test_y = np.array([elem[1] for elem in test_dataset])\n",
"\n",
"for n_gram in [1, 2, 3]:\n",
"    # TF-IDF over word 1..n-grams, followed by a logistic regression classifier\n",
"    text_clf = Pipeline([\n",
"        ('tfidf', TfidfVectorizer(sublinear_tf=True, min_df=2+n_gram, norm='l2', encoding='latin-1', ngram_range=(1, n_gram), stop_words='english')),\n",
"        # solver='liblinear' is the pre-0.22 default; setting it explicitly silences the FutureWarning\n",
"        ('clf', LogisticRegression(solver='liblinear')),\n",
"    ])\n",
"    text_clf.fit(train_x, train_y)\n",
"\n",
"    train_y_hat = text_clf.predict(train_x)\n",
"    test_y_hat = text_clf.predict(test_x)\n",
"    # .mean() of the boolean matches is a fraction in [0, 1], not a percentage\n",
"    print(\"{}-gram Accuracy train:{}, test:{}\".format(n_gram, (train_y_hat == train_y).mean(), (test_y_hat == test_y).mean()))"
]
},
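{
"cell_type": "markdown",
"metadata": {},
"source": [
"The toy example below illustrates what `ngram_range=(1, n_gram)` and `stop_words='english'` do. The two sentences are made up for illustration and are not IMDB reviews. With bigrams enabled, the vectorizer adds word pairs such as 'not great' to the vocabulary. However, scikit-learn's built-in English stop-word list contains 'not', and stop words are dropped before the n-grams are formed, so that negation never reaches the classifier in the baseline above. Single-character tokens such as 'a' are dropped by the default `token_pattern` in both cases."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"\n",
"docs = ['a great movie', 'not a great movie']  # made-up example sentences\n",
"\n",
"# Unigrams + bigrams, no stop-word removal\n",
"print(sorted(TfidfVectorizer(ngram_range=(1, 2)).fit(docs).vocabulary_))\n",
"# ['great', 'great movie', 'movie', 'not', 'not great']\n",
"\n",
"# Same settings plus the English stop-word list used in the baseline\n",
"print(sorted(TfidfVectorizer(ngram_range=(1, 2), stop_words='english').fit(docs).vocabulary_))\n",
"# ['great', 'great movie', 'movie']"
]
},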
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fine-tuning BERT\n",
"\n",
"We download a pre-trained BERT model and fine-tune it on the same dataset."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
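"# bert_24_1024_16 is BERT-large (24 layers, 1024 hidden units, 16 attention heads), despite the variable name\n",
"bert_base, vocabulary = nlp.model.get_model('bert_24_1024_16',\n",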
" dataset_name='book_corpus_wiki_en_uncased',\n",
" pretrained=True, use_pooler=True,\n",
" use_decoder=False, use_classifier=False, ctx=ctx)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 8"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
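"The reviews must be tokenized the same way as the text BERT was pre-trained on. For that we use `BERTTokenizer` with the model's vocabulary, and `BERTSentenceTransform` to add the special tokens and truncate each review to `max_len` tokens."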
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# use the vocabulary from pre-trained model for tokenization\n",
"bert_tokenizer = nlp.data.BERTTokenizer(vocabulary)\n",
"max_len = 128\n",
"transform = nlp.data.BERTSentenceTransform(bert_tokenizer, max_len, pad=False, pair=False)"
]
},
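{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below is a simplified, hypothetical re-implementation of what `BERTSentenceTransform` does for a single sentence (`pair=False`). It is written in plain Python so that its output is easy to read. It is a sketch of our understanding, not the GluonNLP code: the real transform first runs `BERTTokenizer` and then maps the resulting word pieces to vocabulary ids. The helper name `sketch_sentence_transform` is made up for illustration. The key consequence for this notebook is that only the first `max_len - 2` word pieces of each review are kept. With `max_len = 128`, the model therefore only sees the beginning of long reviews."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def sketch_sentence_transform(tokens, max_len=128):\n",
"    # Hypothetical helper: keep room for the two special tokens\n",
"    tokens = tokens[:max_len - 2]\n",
"    # A single sentence is wrapped as [CLS] ... [SEP]\n",
"    wrapped = ['[CLS]'] + tokens + ['[SEP]']\n",
"    valid_length = len(wrapped)\n",
"    # With one sentence, every position belongs to segment 0\n",
"    segment_ids = [0] * valid_length\n",
"    return wrapped, valid_length, segment_ids\n",
"\n",
"print(sketch_sentence_transform(['this', 'movie', 'was', 'great'], max_len=5))\n",
"# (['[CLS]', 'this', 'movie', 'was', '[SEP]'], 5, [0, 0, 0, 0, 0])"
]
},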
{
"cell_type": "markdown",
"metadata": {},
"source": [
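"We create a custom network for classification on top of BERT. It takes the pooler output, i.e. the final hidden state of the `[CLS]` token passed through a dense layer with a tanh non-linearity, and feeds it to a new `Dense` layer with one output per class."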
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
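"class BERTTextClassifier(gluon.nn.Block):\n",
"    def __init__(self, bert, num_classes):\n",
"        super(BERTTextClassifier, self).__init__()\n",
"        self.bert = bert\n",
"        with self.name_scope():\n",
"            # New, randomly initialized classification layer on top of the pooler\n",
"            self.classifier = gluon.nn.Dense(num_classes)\n",
"\n",
"    def forward(self, inputs, token_types, valid_length):\n",
"        # BERTModel takes (inputs, token_types, valid_length); with use_pooler=True\n",
"        # it returns (sequence_output, pooled_output)\n",
"        _, pooler = self.bert(inputs, token_types, valid_length)\n",
"        return self.classifier(pooler)"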
]
},
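{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the shapes concrete, here is a NumPy sketch of the classification path using random weights. It assumes the pooler is a dense layer with a tanh activation applied to the final `[CLS]` hidden state, and it uses the 1024 hidden units of `bert_24_1024_16`. All weight values are random stand-ins, not the pre-trained parameters, and the variable names are made up for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.RandomState(0)\n",
"hidden, num_classes = 1024, 2\n",
"\n",
"# Stand-in for BERT's final hidden state at the [CLS] position\n",
"h_cls = rng.randn(hidden)\n",
"\n",
"# Pooler: dense layer + tanh (random stand-in weights)\n",
"W_pool, b_pool = rng.randn(hidden, hidden) * 0.02, np.zeros(hidden)\n",
"pooled = np.tanh(W_pool @ h_cls + b_pool)\n",
"\n",
"# Classification head, the counterpart of gluon.nn.Dense(num_classes)\n",
"W_cls, b_cls = rng.randn(num_classes, hidden) * 0.02, np.zeros(num_classes)\n",
"logits = W_cls @ pooled + b_cls\n",
"\n",
"print(pooled.shape, logits.shape)  # (1024,) (2,)"
]
},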
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"net = BERTTextClassifier(bert_base, 2)\n",
"net.classifier.initialize(ctx=ctx)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Data Loading**:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"def transform_fn(text, label):\n",
" data, length, token_type = transform([text])\n",
" return data.astype('float32'), length.astype('float32'), token_type.astype('float32'), label"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ubuntu/tutorials/gluon-nlp/src/gluonnlp/data/batchify/batchify.py:228: UserWarning: Padding value 0 is used in data.batchify.Pad(). Please check whether this is intended (e.g. value of padding index in the vocabulary).\n",
" 'Padding value 0 is used in data.batchify.Pad(). '\n"
]
}
],
"source": [
"batchify_fn = nlp.data.batchify.Tuple(\n",
" nlp.data.batchify.Pad(axis=0), nlp.data.batchify.Stack(),\n",
" nlp.data.batchify.Pad(axis=0), nlp.data.batchify.Stack(np.float32))"
]
},
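{
"cell_type": "markdown",
"metadata": {},
"source": [
"`Pad(axis=0)` pads every sequence in a batch to the length of the longest one, and `Stack()` stacks the fixed-size fields (lengths and labels). The NumPy sketch below mimics the padding with a hypothetical `pad_batch` helper and made-up token ids. The valid lengths are passed to BERT alongside the padded ids. Our understanding is that this masks the padded positions out of the attention, so the padding value flagged in the warning above should not change the `[CLS]` output used by the classifier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def pad_batch(seqs, pad_val=0):\n",
"    # Hypothetical helper: right-pad each sequence to the longest length in the batch\n",
"    max_len = max(len(s) for s in seqs)\n",
"    out = np.full((len(seqs), max_len), pad_val, dtype=np.float32)\n",
"    for i, s in enumerate(seqs):\n",
"        out[i, :len(s)] = s\n",
"    return out\n",
"\n",
"batch = [[5, 6, 7], [8, 9]]  # made-up token ids of two sequences\n",
"print(pad_batch(batch))\n",
"# [[5. 6. 7.]\n",
"#  [8. 9. 0.]]\n",
"print(np.array([len(s) for s in batch]))  # valid lengths: [3 2]"
]
},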
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"train_data = gluon.data.DataLoader(train_dataset.transform(transform_fn), batchify_fn=batchify_fn, shuffle=True, batch_size=batch_size, num_workers=8, thread_pool=True)\n",
"test_data = gluon.data.DataLoader(test_dataset.transform(transform_fn), batchify_fn=batchify_fn, shuffle=True, batch_size=batch_size*4, num_workers=8, thread_pool=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Training**"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"trainer = gluon.Trainer(net.collect_params(), 'bertadam', {'learning_rate':0.000005, 'wd':0.001, 'epsilon':1e-6})\n",
"loss_fn = gluon.loss.SoftmaxCELoss()\n",
"net.hybridize(static_alloc=True, static_shape=True)\n",
"\n",
"num_epoch = 3"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Training loop"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Batch 0 Accuracy 0.5 Loss 0.8113940954208374\n",
"Batch 50 Accuracy 0.5465686274509803 Loss 0.6971047345329734\n",
"Batch 100 Accuracy 0.5160891089108911 Loss 0.6957627853544632\n",
"Batch 150 Accuracy 0.5099337748344371 Loss 0.6951458659393108\n",
"Batch 200 Accuracy 0.5055970149253731 Loss 0.6950178763166589\n",
"Batch 250 Accuracy 0.5079681274900398 Loss 0.6945380100690987\n",
"Batch 300 Accuracy 0.5095514950166113 Loss 0.694422420869238\n",
"Batch 350 Accuracy 0.5092592592592593 Loss 0.694160309272614\n",
"Batch 400 Accuracy 0.5105985037406484 Loss 0.6939230750029224\n",
"Batch 450 Accuracy 0.5124722838137472 Loss 0.6937596898385532\n",
"Batch 500 Accuracy 0.5104790419161677 Loss 0.693802093079466\n",
"Batch 550 Accuracy 0.5081669691470054 Loss 0.693887755572255\n",
"Batch 600 Accuracy 0.5056156405990017 Loss 0.6939957756766068\n",
"Batch 650 Accuracy 0.5032642089093702 Loss 0.6940928750690044\n",
"Batch 700 Accuracy 0.5030313837375179 Loss 0.6940981093554967\n",
"Batch 750 Accuracy 0.5058255659121171 Loss 0.6938340838517394\n",
"Batch 800 Accuracy 0.5032771535580525 Loss 0.6940455442659567\n",
"Batch 850 Accuracy 0.5020564042303173 Loss 0.694036140845329\n",
"Batch 900 Accuracy 0.5006936736958935 Loss 0.694033742347912\n",
"Batch 950 Accuracy 0.49763406940063093 Loss 0.6941702889844571\n",
"Batch 1000 Accuracy 0.498001998001998 Loss 0.694186062960477\n",
"Batch 1050 Accuracy 0.4978591817316841 Loss 0.6940937713710157\n",
"Batch 1100 Accuracy 0.49977293369663944 Loss 0.6940059263418199\n",
"Batch 1150 Accuracy 0.5018462206776716 Loss 0.693903494668359\n",
"Batch 1200 Accuracy 0.5007285595337219 Loss 0.694016430399797\n",
"Batch 1250 Accuracy 0.5007993605115907 Loss 0.6939503085031974\n",
"Batch 1300 Accuracy 0.5015372790161414 Loss 0.6939097235149645\n",
"Batch 1350 Accuracy 0.5004626202812731 Loss 0.693972889005829\n",
"Batch 1400 Accuracy 0.49937544610992146 Loss 0.694006844302061\n",
"Batch 1450 Accuracy 0.49922467263955894 Loss 0.693951176249677\n",
"Batch 1500 Accuracy 0.4999167221852099 Loss 0.6938975090824867\n",
"Batch 1550 Accuracy 0.5000805931656995 Loss 0.6938770916445035\n",
"Batch 1600 Accuracy 0.49960961898813244 Loss 0.6939126943365475\n",
"Batch 1650 Accuracy 0.49977286493034523 Loss 0.6939044400174137\n",
"Batch 1700 Accuracy 0.5001469723691946 Loss 0.6938736347644768\n",
"Batch 1750 Accuracy 0.499571673329526 Loss 0.693924063481582\n",
"Batch 1800 Accuracy 0.5 Loss 0.6938802580380691\n",
"Batch 1850 Accuracy 0.5018233387358185 Loss 0.6937092967061723\n",
"Batch 1900 Accuracy 0.5018411362440821 Loss 0.6937858826357838\n",
"Batch 1950 Accuracy 0.500961045617632 Loss 0.6938298243488916\n",
"Batch 2000 Accuracy 0.5006871564217891 Loss 0.6938200211417729\n",
"Batch 2050 Accuracy 0.5004266211604096 Loss 0.6938326578498294\n",
"Batch 2100 Accuracy 0.5015468824369348 Loss 0.6937640488346323\n",
"Batch 2150 Accuracy 0.502092050209205 Loss 0.6937241671751511\n",
"Batch 2200 Accuracy 0.5019309404815993 Loss 0.6937525955886529\n",
"Batch 2250 Accuracy 0.5022767658818303 Loss 0.6937186988074745\n",
"Batch 2300 Accuracy 0.5017383746197306 Loss 0.6937674962142818\n",
"Batch 2350 Accuracy 0.501807741386644 Loss 0.6937409031396214\n",
"Batch 2400 Accuracy 0.5009371095376927 Loss 0.6937702755687735\n",
"Batch 2450 Accuracy 0.5005099959200326 Loss 0.6937742447279172\n",
"Batch 2500 Accuracy 0.5001499400239904 Loss 0.6937726374294032\n",
"Batch 2550 Accuracy 0.49931399451195607 Loss 0.6937834389393865\n",
"Batch 2600 Accuracy 0.4987985390234525 Loss 0.6937976266640234\n",
"Batch 2650 Accuracy 0.4994813278008299 Loss 0.6937510959160458\n",
"Batch 2700 Accuracy 0.4991669751943725 Loss 0.6937807684075342\n",
"Batch 2750 Accuracy 0.4989094874591058 Loss 0.6937700033794529\n",
"Batch 2800 Accuracy 0.4988843270260621 Loss 0.6937682081315825\n",
"Batch 2850 Accuracy 0.49899158190108733 Loss 0.6937547440865047\n",
"Batch 2900 Accuracy 0.49909513960703206 Loss 0.693739429830231\n",
"Batch 2950 Accuracy 0.4988986784140969 Loss 0.6937331145030287\n",
"Batch 3000 Accuracy 0.4996251249583472 Loss 0.6936987800941353\n",
"Batch 3050 Accuracy 0.4996312684365782 Loss 0.6936925777409046\n",
"Batch 3100 Accuracy 0.4997178329571106 Loss 0.6936764980046759\n",
"Epoch 0, Accuracy ('accuracy', 0.49972), Loss 0.6936875\n",
"Batch 0 Accuracy 0.5 Loss 0.6856123208999634\n",
"Batch 50 Accuracy 0.4583333333333333 Loss 0.6955141553691789\n",
"Batch 100 Accuracy 0.48267326732673266 Loss 0.6941103038221302\n",
"Batch 150 Accuracy 0.48096026490066224 Loss 0.6942238712942363\n",
"Batch 200 Accuracy 0.4906716417910448 Loss 0.6936642471237562\n",
"Batch 250 Accuracy 0.4895418326693227 Loss 0.693647954568445\n",
"Batch 300 Accuracy 0.4908637873754153 Loss 0.693672357603561\n",
"Batch 350 Accuracy 0.4878917378917379 Loss 0.693731019979189\n",
"Batch 400 Accuracy 0.4878428927680798 Loss 0.6938717050148068\n",
"Batch 450 Accuracy 0.49029933481152993 Loss 0.6938042143760393\n",
"Batch 500 Accuracy 0.4880239520958084 Loss 0.6938397984304828\n",
"Batch 550 Accuracy 0.48593466424682397 Loss 0.6938430038424456\n",
"Batch 600 Accuracy 0.4902246256239601 Loss 0.6936775499493032\n",
"Batch 650 Accuracy 0.4911674347158218 Loss 0.6936897242673531\n",
"Batch 700 Accuracy 0.4957203994293866 Loss 0.6934668177714871\n",
"Batch 750 Accuracy 0.4933422103861518 Loss 0.693654391800516\n",
"Batch 800 Accuracy 0.493601747815231 Loss 0.693650301624922\n",
"Batch 850 Accuracy 0.49441833137485314 Loss 0.693529140234834\n",
"Batch 900 Accuracy 0.4911209766925638 Loss 0.693599993063263\n",
"Batch 950 Accuracy 0.49263932702418506 Loss 0.6935356132114222\n",
"Batch 1000 Accuracy 0.4926323676323676 Loss 0.6935521753636988\n",
"Batch 1050 Accuracy 0.4910799238820171 Loss 0.6936605437158063\n",
"Batch 1100 Accuracy 0.49364214350590374 Loss 0.6935811194368756\n",
"Batch 1150 Accuracy 0.4946785403996525 Loss 0.6936019247868701\n",
"Batch 1200 Accuracy 0.4929225645295587 Loss 0.6937399680767329\n",
"Batch 1250 Accuracy 0.4920063948840927 Loss 0.6937450039968026\n",
"Batch 1300 Accuracy 0.49279400461183703 Loss 0.693691901662183\n",
"Batch 1350 Accuracy 0.4932457438934123 Loss 0.6936686167798158\n",
"Batch 1400 Accuracy 0.4942897930049964 Loss 0.6936067394662295\n",
"Batch 1450 Accuracy 0.49353893866299103 Loss 0.6936395225814094\n",
"Batch 1500 Accuracy 0.4942538307794803 Loss 0.6936633715543804\n",
"Batch 1550 Accuracy 0.4951644100580271 Loss 0.6936317704555529\n",
"Batch 1600 Accuracy 0.49508119925046845 Loss 0.6936721372872423\n",
"Batch 1650 Accuracy 0.4950030284675954 Loss 0.6937146432323213\n",
"Batch 1700 Accuracy 0.4949294532627866 Loss 0.6937298200047766\n",
"Batch 1750 Accuracy 0.4957881210736722 Loss 0.6937139435947316\n",
"Batch 1800 Accuracy 0.4973625763464742 Loss 0.693623239367886\n",
"Batch 1850 Accuracy 0.4974338195569962 Loss 0.69365897814948\n",
"Batch 1900 Accuracy 0.49730405049973697 Loss 0.693665546094161\n",
"Batch 1950 Accuracy 0.4973731419784726 Loss 0.6936632055676576\n",
"Batch 2000 Accuracy 0.4976886556721639 Loss 0.6936287715517241\n",
"Batch 2050 Accuracy 0.4976231106777182 Loss 0.6936190496823196\n",
"Batch 2100 Accuracy 0.4982151356496906 Loss 0.6936090933260948\n",
"Batch 2150 Accuracy 0.49953509995351 Loss 0.6935010242329149\n",
"Batch 2200 Accuracy 0.5000567923671059 Loss 0.6934935358750284\n",
"Batch 2250 Accuracy 0.49916703687250114 Loss 0.6935682665343181\n",
"Batch 2300 Accuracy 0.49853324641460234 Loss 0.6935827196395589\n",
"Batch 2350 Accuracy 0.4979795831561038 Loss 0.6935873157366547\n",
"Batch 2400 Accuracy 0.4982299042065806 Loss 0.69355236913005\n",
"Batch 2450 Accuracy 0.4983170134638923 Loss 0.6935666106468024\n",
"Batch 2500 Accuracy 0.49810075969612155 Loss 0.6935709602877599\n",
"Batch 2550 Accuracy 0.4976969815758526 Loss 0.693574852309756\n",
"Batch 2600 Accuracy 0.49774125336409075 Loss 0.6935637641622212\n",
"Batch 2650 Accuracy 0.49872689551112787 Loss 0.6934963184588363\n",
"Batch 2700 Accuracy 0.4980562754535357 Loss 0.6935707044266013\n",
"Batch 2750 Accuracy 0.4969556524900036 Loss 0.6935945522650854\n",
"Batch 2800 Accuracy 0.49629596572652623 Loss 0.6936222990254597\n",
"Batch 2850 Accuracy 0.49622939319537007 Loss 0.6936271504241933\n",
"Batch 2900 Accuracy 0.49599276111685625 Loss 0.6936211045221475\n",
"Batch 2950 Accuracy 0.4959335818366655 Loss 0.6936311479424347\n",
"Batch 3000 Accuracy 0.4960846384538487 Loss 0.6936152304700517\n",
"Batch 3050 Accuracy 0.4963126843657817 Loss 0.6936248809304326\n",
"Batch 3100 Accuracy 0.4965736859077717 Loss 0.6936086330518381\n",
"Epoch 1, Accuracy ('accuracy', 0.49672), Loss 0.69359515625\n",
"Batch 0 Accuracy 0.25 Loss 0.7104008793830872\n",
"Batch 50 Accuracy 0.5073529411764706 Loss 0.6934510773303462\n",
"Batch 100 Accuracy 0.5247524752475248 Loss 0.6929469344639542\n",
"Batch 150 Accuracy 0.5091059602649006 Loss 0.6934733106600528\n",
"Batch 200 Accuracy 0.5105721393034826 Loss 0.6932625841738572\n",
"Batch 250 Accuracy 0.5134462151394422 Loss 0.6928929788657868\n",
"Batch 300 Accuracy 0.5157807308970099 Loss 0.6928477239767182\n",
"Batch 350 Accuracy 0.5138888888888888 Loss 0.693018073709602\n",
"Batch 400 Accuracy 0.5118453865336658 Loss 0.6929108198741427\n",
"Batch 450 Accuracy 0.5127494456762749 Loss 0.6929625111514343\n",
"Batch 500 Accuracy 0.5097305389221557 Loss 0.6931828481708459\n",
"Batch 550 Accuracy 0.5058983666061706 Loss 0.693242843267922\n",
"Batch 600 Accuracy 0.5056156405990017 Loss 0.6932559909915765\n",
"Batch 650 Accuracy 0.5036482334869432 Loss 0.6933147940218174\n",
"Batch 700 Accuracy 0.5030313837375179 Loss 0.6933085704155894\n",
"Batch 750 Accuracy 0.5021637816245007 Loss 0.6933090677274051\n",
"Batch 800 Accuracy 0.5015605493133583 Loss 0.6933596035960908\n",
"Batch 850 Accuracy 0.500587544065805 Loss 0.6934002563620005\n",
"Batch 900 Accuracy 0.5016648168701443 Loss 0.6933038946526429\n",
"Batch 950 Accuracy 0.5017087276550999 Loss 0.6933975620851407\n",
"Batch 1000 Accuracy 0.5013736263736264 Loss 0.6934364463661339\n",
"Batch 1050 Accuracy 0.5008325404376784 Loss 0.6934531054873335\n",
"Batch 1100 Accuracy 0.5011353315168029 Loss 0.6934898716011013\n",
"Batch 1150 Accuracy 0.5022806255430061 Loss 0.6934206222680006\n",
"Batch 1200 Accuracy 0.4997918401332223 Loss 0.6934522236515144\n",
"Batch 1250 Accuracy 0.5 Loss 0.693437096026304\n",
"Batch 1300 Accuracy 0.500576479631053 Loss 0.6934615537627306\n",
"Batch 1350 Accuracy 0.5013878608438194 Loss 0.6934563716546771\n",
"Batch 1400 Accuracy 0.5015167737330478 Loss 0.6935105470144094\n",
"Batch 1450 Accuracy 0.5031013094417643 Loss 0.693440895422338\n",
"Batch 1500 Accuracy 0.5027481678880746 Loss 0.6934464751363674\n",
"Batch 1550 Accuracy 0.5022566086395873 Loss 0.6934809730869197\n",
"Batch 1600 Accuracy 0.5021861336664585 Loss 0.6934927297636243\n",
"Batch 1650 Accuracy 0.500908540278619 Loss 0.6934966763751136\n",
"Batch 1700 Accuracy 0.5013962375073486 Loss 0.6934255412716784\n",
"Batch 1750 Accuracy 0.500642490005711 Loss 0.6934219787353655\n",
"Batch 1800 Accuracy 0.5002776235424764 Loss 0.6934382699715436\n",
"Batch 1850 Accuracy 0.4997974068071313 Loss 0.693457242284576\n",
"Batch 1900 Accuracy 0.5001315097317202 Loss 0.6934870961212849\n",
"Batch 1950 Accuracy 0.4998718605843157 Loss 0.6934887032090915\n",
"Batch 2000 Accuracy 0.5003748125937032 Loss 0.6934610087534357\n",
"Batch 2050 Accuracy 0.5001828376401756 Loss 0.6934643638011946\n",
"Batch 2100 Accuracy 0.49988100904331273 Loss 0.6934794117756425\n",
"Batch 2150 Accuracy 0.4996513249651325 Loss 0.6934828640748489\n",
"Batch 2200 Accuracy 0.4990345297592004 Loss 0.6935032416018287\n",
"Batch 2250 Accuracy 0.49888938249666814 Loss 0.6935142540815193\n",
"Batch 2300 Accuracy 0.4987505432420687 Loss 0.6935147613299923\n",
"Batch 2350 Accuracy 0.49829859634198215 Loss 0.6935221527242397\n",
"Batch 2400 Accuracy 0.49796959600166596 Loss 0.6935368624889369\n",
"Batch 2450 Accuracy 0.49775601795185637 Loss 0.693570096947037\n",
"Batch 2500 Accuracy 0.49670131947221113 Loss 0.6935967800379849\n",
"Batch 2550 Accuracy 0.496030968247746 Loss 0.6935988261343591\n",
"Batch 2600 Accuracy 0.494761630142253 Loss 0.6936276856407392\n",
"Batch 2650 Accuracy 0.49495473406261786 Loss 0.6936079821735902\n",
"Batch 2700 Accuracy 0.49453905960755273 Loss 0.6936315814021428\n",
"Batch 2750 Accuracy 0.49504725554343876 Loss 0.6936228179099646\n",
"Batch 2800 Accuracy 0.4949571581578008 Loss 0.6936225169303374\n",
"Batch 2850 Accuracy 0.4952648193616275 Loss 0.6936182017411654\n",
"Batch 2900 Accuracy 0.49551878662530163 Loss 0.6935980033151284\n",
"Batch 2950 Accuracy 0.49567943070145715 Loss 0.6935861420122416\n",
"Batch 3000 Accuracy 0.49579306897700764 Loss 0.6935903364243169\n",
"Batch 3050 Accuracy 0.49553425106522453 Loss 0.6935947934591118\n",
"Batch 3100 Accuracy 0.49568687520154786 Loss 0.6935833608362222\n",
"Epoch 2, Accuracy ('accuracy', 0.4954), Loss 0.693593359375\n"
]
}
],
"source": [
"for epoch in range(num_epoch):\n",
"    accuracy = mx.metric.Accuracy()\n",
"    running_loss = 0\n",
"    for i, (inputs, seq_len, token_types, label) in enumerate(train_data):\n",
"        # Move the batch to the training device\n",
"        inputs = inputs.as_in_context(ctx)\n",
"        seq_len = seq_len.as_in_context(ctx)\n",
"        token_types = token_types.as_in_context(ctx)\n",
"        label = label.as_in_context(ctx)\n",
"        # Record the forward pass so that autograd can compute gradients\n",
"        with mx.autograd.record():\n",
"            out = net(inputs, token_types, seq_len)\n",
"            loss = loss_fn(out, label.astype('float32'))\n",
"        loss.backward()\n",
"        running_loss += loss.mean()\n",
"        # Gradients are summed over the batch; step() normalizes by batch_size\n",
"        trainer.step(batch_size)\n",
"        accuracy.update(label, out.softmax())\n",
"\n",
"        if i % 50 == 0:\n",
"            print(\"Batch\", i, \"Accuracy\", accuracy.get()[1], \"Loss\", running_loss.asscalar()/(i+1))\n",
"    print(\"Epoch {}, Accuracy {}, Loss {}\".format(epoch, accuracy.get(), running_loss.asscalar()/(i+1)))"
]
},
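{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference when reading the training log above: with two classes, a model that assigns probability 0.5 to the correct label has a cross-entropy loss of ln 2, about 0.693. On this balanced dataset, such a model also reaches an accuracy of around 50%. The running loss and accuracy printed in this particular run stay close to those chance-level values across all three epochs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"# Cross-entropy -log(p) of a prediction that gives the true class probability p = 0.5\n",
"chance_loss = -math.log(0.5)\n",
"print(chance_loss)  # 0.6931471805599453"
]
},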
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Evaluation**"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9068627450980392\n",
"0.9077970297029703\n",
"0.9060430463576159\n",
"0.9040733830845771\n",
"0.9016434262948207\n",
"0.9011627906976745\n",
"0.9024216524216524\n",
"0.9017300498753117\n",
"0.9011917960088692\n",
"0.9009481037924152\n",
"0.9021098003629764\n",
"0.9019862728785357\n",
"0.901833717357911\n",
"0.9028619828815977\n",
"0.9029627163781625\n",
"Test Accuracy 0.9029731457800512\n"
]
}
],
"source": [
"accuracy = 0\n",
"for i, (inputs, seq_len, token_types, label) in enumerate(test_data):\n",
" inputs = inputs.as_in_context(ctx)\n",
" seq_len = seq_len.as_in_context(ctx)\n",
" token_types = token_types.as_in_context(ctx)\n",
" label = label.as_in_context(ctx)\n",
" out = net(inputs, token_types, seq_len)\n",
" accuracy += (out.argmax(axis=1).squeeze() == label).mean()\n",
" if i % 50 == 0 and i > 0:\n",
" print(accuracy.asscalar()/(i+1))\n",
"print(\"Test Accuracy {}\".format(accuracy.asscalar()/(i+1)))"
]
},
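{
"cell_type": "markdown",
"metadata": {},
"source": [
"The evaluation cell averages per-batch accuracies, which weights every batch equally regardless of its size. With 25,000 test reviews and a test batch size of 32, there are 781 full batches and one final batch of 8. The effect is therefore tiny here, but it can matter when batch sizes differ more. The toy numbers below are made up purely to show the difference between the mean of batch means and the overall fraction of correct predictions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Made-up example: three full batches of 32 and a final batch of 8\n",
"batch_sizes = np.array([32, 32, 32, 8])\n",
"num_correct = np.array([29, 30, 28, 4])\n",
"\n",
"mean_of_batch_means = (num_correct / batch_sizes).mean()  # what the evaluation loop computes\n",
"overall_accuracy = num_correct.sum() / batch_sizes.sum()  # fraction of all examples classified correctly\n",
"print(mean_of_batch_means, overall_accuracy)  # 0.8046875 0.875"
]
},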
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Final accuracies:\n",
"\n",
"| Model | Training Accuracy | Testing Accuracy |\n",
"|--------------|-------------------|------------------|\n",
"|TF-IDF 1-gram | 93.6% | 88.5% |\n",
"|TF-IDF 2-gram | 95.0% | 88.6% |\n",
"|TF-IDF 3-gram | 94.9% | 88.7% |\n",
"|BERT-large (bert_24_1024_16) | **97.0%** | **90.3%** | "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}