@haven-jeon
Last active February 25, 2023 08:36
BERT with Naver Sentiment Movie Corpus
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from mxnet.gluon import nn, rnn\n",
"from mxnet import gluon, autograd\n",
"import gluonnlp as nlp\n",
"from mxnet import nd \n",
"import mxnet as mx\n",
"import time\n",
"import itertools\n",
"import random"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 버트 로딩 "
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"ctx = mx.gpu()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"bert_base, vocabulary = nlp.model.get_model('bert_12_768_12',\n",
" dataset_name='wiki_multilingual_cased',\n",
" pretrained=True, ctx=ctx, use_pooler=True,\n",
" use_decoder=False, use_classifier=False)\n",
"#print(bert_base)"
]
},
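{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick shape check (a minimal sketch; `dummy_ids`, `dummy_segments` and `dummy_length` exist only for this demo): with `use_pooler=True` the model should return the encoded sequence of shape `(batch, seq_len, 768)` plus the pooled `[CLS]` vector of shape `(batch, 768)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check on a dummy batch; shapes assume the 768-dim bert_12_768_12 encoder.\n",
"dummy_ids = mx.nd.ones((1, 8), ctx=ctx)\n",
"dummy_segments = mx.nd.zeros((1, 8), ctx=ctx)\n",
"dummy_length = mx.nd.array([8], ctx=ctx)\n",
"seq_out, pooled_out = bert_base(dummy_ids, dummy_segments, dummy_length)\n",
"print(seq_out.shape, pooled_out.shape)  # expected: (1, 8, 768) (1, 768)"
]
},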
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(array([ 2, 8982, 9356, 47869, 9566, 3, 8935, 22333, 38851,\n",
" 3], dtype=int32),\n",
" array(10, dtype=int32),\n",
" array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1], dtype=int32))]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds = gluon.data.SimpleDataset([['나 보기가 역겨워', '김소월']])\n",
"\n",
"tok = nlp.data.BERTTokenizer(vocab=vocabulary, lower=False)\n",
"\n",
"trans = nlp.data.BERTSentenceTransform(tok, max_seq_length=10)\n",
"\n",
"list(ds.transform(trans))"
]
},
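{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what is behind the ids above, the sketch below prints the wordpiece tokens for the same sentence; the exact pieces depend on the `wiki_multilingual_cased` vocabulary, and ids 2 and 3 are presumably the `[CLS]`/`[SEP]` special tokens inserted by `BERTSentenceTransform`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Wordpiece tokens for the example sentence (vocabulary-dependent).\n",
"print(tok('나 보기가 역겨워'))\n",
"# Presumably the special tokens appearing as ids 2 and 3 in the arrays above.\n",
"print(vocabulary.idx_to_token[2], vocabulary.idx_to_token[3])"
]
},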
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"#데이터는 https://github.com/e9t/nsmc 다운받는다. \n",
"dataset_train = nlp.data.TSVDataset(\"ratings_train.txt\", field_indices=[1,2], num_discard_samples=1)\n",
"dataset_test = nlp.data.TSVDataset(\"ratings_test.txt\", field_indices=[1,2], num_discard_samples=1)"
]
},
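{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check on the loaded TSVs: NSMC ships 150,000 training and 50,000 test reviews, and after `field_indices=[1,2]` each record is a `[document, label]` pair."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Expect 150000 training and 50000 test examples.\n",
"print(len(dataset_train), len(dataset_test))\n",
"print(dataset_train[0])"
]
},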
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class BERTDataset(mx.gluon.data.Dataset):\n",
" def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,\n",
" pad, pair):\n",
" transform = nlp.data.BERTSentenceTransform(\n",
" bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)\n",
" sent_dataset = gluon.data.SimpleDataset([[\n",
" i[sent_idx],\n",
" ] for i in dataset])\n",
" self.sentences = sent_dataset.transform(transform)\n",
" self.labels = gluon.data.SimpleDataset(\n",
" [np.array(np.int32(i[label_idx])) for i in dataset])\n",
"\n",
" def __getitem__(self, i):\n",
" return (self.sentences[i] + (self.labels[i], ))\n",
"\n",
" def __len__(self):\n",
" return (len(self.labels))\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"bert_tokenizer = nlp.data.BERTTokenizer(vocabulary, lower=False)\n",
"max_len = 64"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"data_train = BERTDataset(dataset_train, 0, 1, bert_tokenizer, max_len, True, False)\n",
"data_test = BERTDataset(dataset_test, 0, 1, bert_tokenizer, max_len, True, False)"
]
},
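{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each item of `data_train` is a `(token_ids, valid_length, segment_ids, label)` tuple, with `token_ids` and `segment_ids` padded to `max_len = 64`; the sketch below (variable names used only for this demo) inspects the first training example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect one transformed example: padded ids, true length, segment ids, label.\n",
"token_ids, valid_length, segment_ids, label = data_train[0]\n",
"print(token_ids.shape, valid_length, segment_ids.shape, label)"
]
},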
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"class BERTClassifier(nn.Block):\n",
" def __init__(self,\n",
" bert,\n",
" num_classes=2,\n",
" dropout=None,\n",
" prefix=None,\n",
" params=None):\n",
" super(BERTClassifier, self).__init__(prefix=prefix, params=params)\n",
" self.bert = bert\n",
" with self.name_scope():\n",
" self.classifier = nn.HybridSequential(prefix=prefix)\n",
" if dropout:\n",
" self.classifier.add(nn.Dropout(rate=dropout))\n",
" self.classifier.add(nn.Dense(units=num_classes))\n",
"\n",
" def forward(self, inputs, token_types, valid_length=None):\n",
" _, pooler = self.bert(inputs, token_types, valid_length)\n",
" return self.classifier(pooler)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"model = BERTClassifier(bert_base, num_classes=2, dropout=0.3)\n",
"# 분류 레이어만 초기화 한다. \n",
"model.classifier.initialize(ctx=ctx)\n",
"model.hybridize()\n",
"\n",
"# softmax cross entropy loss for classification\n",
"loss_function = gluon.loss.SoftmaxCELoss()\n",
"\n",
"metric = mx.metric.Accuracy()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 64\n",
"lr = 5e-5\n",
"\n",
"train_dataloader = mx.gluon.data.DataLoader(data_train, batch_size=batch_size, num_workers=5)\n",
"test_dataloader = mx.gluon.data.DataLoader(data_test, batch_size=batch_size, num_workers=5)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"trainer = gluon.Trainer(model.collect_params(), 'bertadam',\n",
" {'learning_rate': lr, 'epsilon': 1e-9, 'wd':0.01})\n",
"\n",
"log_interval = 4\n",
"num_epochs = 4"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# LayerNorm과 Bias에는 Weight Decay를 적용하지 않는다. \n",
"for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():\n",
" v.wd_mult = 0.0\n",
"params = [\n",
" p for p in model.collect_params().values() if p.grad_req != 'null'\n",
"]\n"
]
},
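{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional check (illustration only; `no_decay` is introduced just for this demo): count how many parameters are excluded from weight decay and how many will receive gradient updates."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# LayerNorm/bias parameters matched by the regex get wd_mult = 0.0.\n",
"no_decay = list(model.collect_params('.*beta|.*gamma|.*bias').keys())\n",
"print(len(no_decay), 'parameters excluded from weight decay')\n",
"print(len(params), 'parameters will be updated')"
]
},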
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def evaluate_accuracy(model, data_iter, ctx=ctx):\n",
" acc = mx.metric.Accuracy()\n",
" i = 0\n",
" for i, (t,v,s, label) in enumerate(data_iter):\n",
" token_ids = t.as_in_context(ctx)\n",
" valid_length = v.as_in_context(ctx)\n",
" segment_ids = s.as_in_context(ctx)\n",
" label = label.as_in_context(ctx)\n",
" output = model(token_ids, segment_ids, valid_length.astype('float32'))\n",
" acc.update(preds=output, labels=label)\n",
" if i > 1000:\n",
" break\n",
" i += 1\n",
" return(acc.get()[1])"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"#learning rate warmup을 위한 준비 \n",
"step_size = batch_size \n",
"num_train_examples = len(data_train)\n",
"num_train_steps = int(num_train_examples / step_size * num_epochs)\n",
"warmup_ratio = 0.1\n",
"num_warmup_steps = int(num_train_steps * warmup_ratio)\n",
"step_num = 0"
]
},
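{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the schedule concrete, the sketch below evaluates the same warmup-then-linear-decay rule used in the training loop at a few representative steps; `scheduled_lr` is a helper written only for this illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustration of the lr schedule: linear warmup over the first 10% of steps,\n",
"# then linear decay towards zero. Mirrors the rule applied in the loop below.\n",
"def scheduled_lr(step):\n",
"    if step < num_warmup_steps:\n",
"        return lr * step / num_warmup_steps\n",
"    return lr - (step - num_warmup_steps) * lr / (num_train_steps - num_warmup_steps)\n",
"\n",
"for s in [1, num_warmup_steps // 2, num_warmup_steps, num_train_steps // 2, num_train_steps]:\n",
"    print(s, scheduled_lr(s))"
]
},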
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Epoch 1 Batch 50/2344] loss=8.4847, lr=0.0000026681, acc=0.556\n",
"[Epoch 1 Batch 100/2344] loss=7.6343, lr=0.0000053362, acc=0.612\n",
"[Epoch 1 Batch 150/2344] loss=6.7952, lr=0.0000080043, acc=0.648\n",
"[Epoch 1 Batch 200/2344] loss=6.1983, lr=0.0000106724, acc=0.676\n",
"[Epoch 1 Batch 250/2344] loss=6.3750, lr=0.0000133404, acc=0.691\n",
"[Epoch 1 Batch 300/2344] loss=5.9387, lr=0.0000160085, acc=0.705\n",
"[Epoch 1 Batch 350/2344] loss=5.7424, lr=0.0000186766, acc=0.717\n",
"[Epoch 1 Batch 400/2344] loss=5.7524, lr=0.0000213447, acc=0.726\n",
"[Epoch 1 Batch 450/2344] loss=5.2769, lr=0.0000240128, acc=0.735\n",
"[Epoch 1 Batch 500/2344] loss=5.5889, lr=0.0000266809, acc=0.740\n",
"[Epoch 1 Batch 550/2344] loss=5.3694, lr=0.0000293490, acc=0.745\n",
"[Epoch 1 Batch 600/2344] loss=5.1612, lr=0.0000320171, acc=0.751\n",
"[Epoch 1 Batch 650/2344] loss=5.1401, lr=0.0000346852, acc=0.756\n",
"[Epoch 1 Batch 700/2344] loss=5.1864, lr=0.0000373533, acc=0.760\n",
"[Epoch 1 Batch 750/2344] loss=5.3412, lr=0.0000400213, acc=0.762\n",
"[Epoch 1 Batch 800/2344] loss=5.3824, lr=0.0000426894, acc=0.765\n",
"[Epoch 1 Batch 850/2344] loss=5.0335, lr=0.0000453575, acc=0.768\n",
"[Epoch 1 Batch 900/2344] loss=5.2395, lr=0.0000480256, acc=0.770\n",
"[Epoch 1 Batch 950/2344] loss=5.4929, lr=0.0000499230, acc=0.772\n",
"[Epoch 1 Batch 1000/2344] loss=5.4369, lr=0.0000496267, acc=0.773\n",
"[Epoch 1 Batch 1050/2344] loss=4.9858, lr=0.0000493304, acc=0.776\n",
"[Epoch 1 Batch 1100/2344] loss=5.2293, lr=0.0000490341, acc=0.777\n",
"[Epoch 1 Batch 1150/2344] loss=4.8173, lr=0.0000487379, acc=0.780\n",
"[Epoch 1 Batch 1200/2344] loss=4.7950, lr=0.0000484416, acc=0.782\n",
"[Epoch 1 Batch 1250/2344] loss=5.1266, lr=0.0000481453, acc=0.783\n",
"[Epoch 1 Batch 1300/2344] loss=4.9907, lr=0.0000478490, acc=0.784\n",
"[Epoch 1 Batch 1350/2344] loss=5.1889, lr=0.0000475527, acc=0.786\n",
"[Epoch 1 Batch 1400/2344] loss=4.7023, lr=0.0000472565, acc=0.787\n",
"[Epoch 1 Batch 1450/2344] loss=4.7529, lr=0.0000469602, acc=0.788\n",
"[Epoch 1 Batch 1500/2344] loss=4.8886, lr=0.0000466639, acc=0.790\n",
"[Epoch 1 Batch 1550/2344] loss=4.5845, lr=0.0000463676, acc=0.791\n",
"[Epoch 1 Batch 1600/2344] loss=4.9012, lr=0.0000460713, acc=0.792\n",
"[Epoch 1 Batch 1650/2344] loss=4.7183, lr=0.0000457751, acc=0.793\n",
"[Epoch 1 Batch 1700/2344] loss=4.6612, lr=0.0000454788, acc=0.795\n",
"[Epoch 1 Batch 1750/2344] loss=4.6324, lr=0.0000451825, acc=0.796\n",
"[Epoch 1 Batch 1800/2344] loss=4.6862, lr=0.0000448862, acc=0.797\n",
"[Epoch 1 Batch 1850/2344] loss=4.6235, lr=0.0000445900, acc=0.798\n",
"[Epoch 1 Batch 1900/2344] loss=4.7111, lr=0.0000442937, acc=0.799\n",
"[Epoch 1 Batch 1950/2344] loss=4.7050, lr=0.0000439974, acc=0.799\n",
"[Epoch 1 Batch 2000/2344] loss=4.2282, lr=0.0000437011, acc=0.801\n",
"[Epoch 1 Batch 2050/2344] loss=4.5042, lr=0.0000434048, acc=0.802\n",
"[Epoch 1 Batch 2100/2344] loss=4.4780, lr=0.0000431086, acc=0.803\n",
"[Epoch 1 Batch 2150/2344] loss=4.7802, lr=0.0000428123, acc=0.803\n",
"[Epoch 1 Batch 2200/2344] loss=4.3908, lr=0.0000425160, acc=0.804\n",
"[Epoch 1 Batch 2250/2344] loss=4.5963, lr=0.0000422197, acc=0.805\n",
"[Epoch 1 Batch 2300/2344] loss=4.2460, lr=0.0000419234, acc=0.806\n",
"Test Acc : 0.84662\n",
"[Epoch 2 Batch 50/2344] loss=4.4179, lr=0.0000413664, acc=0.846\n",
"[Epoch 2 Batch 100/2344] loss=4.4742, lr=0.0000410702, acc=0.843\n",
"[Epoch 2 Batch 150/2344] loss=4.3172, lr=0.0000407739, acc=0.845\n",
"[Epoch 2 Batch 200/2344] loss=4.1374, lr=0.0000404776, acc=0.845\n",
"[Epoch 2 Batch 250/2344] loss=4.2287, lr=0.0000401813, acc=0.847\n",
"[Epoch 2 Batch 300/2344] loss=4.1257, lr=0.0000398850, acc=0.848\n",
"[Epoch 2 Batch 350/2344] loss=4.0771, lr=0.0000395888, acc=0.850\n",
"[Epoch 2 Batch 400/2344] loss=4.2243, lr=0.0000392925, acc=0.850\n",
"[Epoch 2 Batch 450/2344] loss=3.8879, lr=0.0000389962, acc=0.852\n",
"[Epoch 2 Batch 500/2344] loss=4.1791, lr=0.0000386999, acc=0.852\n",
"[Epoch 2 Batch 550/2344] loss=3.9588, lr=0.0000384037, acc=0.853\n",
"[Epoch 2 Batch 600/2344] loss=3.7946, lr=0.0000381074, acc=0.854\n",
"[Epoch 2 Batch 650/2344] loss=3.8225, lr=0.0000378111, acc=0.856\n",
"[Epoch 2 Batch 700/2344] loss=3.7906, lr=0.0000375148, acc=0.857\n",
"[Epoch 2 Batch 750/2344] loss=4.0131, lr=0.0000372185, acc=0.857\n",
"[Epoch 2 Batch 800/2344] loss=3.7381, lr=0.0000369223, acc=0.858\n",
"[Epoch 2 Batch 850/2344] loss=3.8436, lr=0.0000366260, acc=0.859\n",
"[Epoch 2 Batch 900/2344] loss=3.9942, lr=0.0000363297, acc=0.859\n",
"[Epoch 2 Batch 950/2344] loss=3.9729, lr=0.0000360334, acc=0.859\n",
"[Epoch 2 Batch 1000/2344] loss=3.8688, lr=0.0000357371, acc=0.860\n",
"[Epoch 2 Batch 1050/2344] loss=3.4891, lr=0.0000354409, acc=0.861\n",
"[Epoch 2 Batch 1100/2344] loss=3.9726, lr=0.0000351446, acc=0.861\n",
"[Epoch 2 Batch 1150/2344] loss=3.5700, lr=0.0000348483, acc=0.862\n",
"[Epoch 2 Batch 1200/2344] loss=3.8338, lr=0.0000345520, acc=0.862\n",
"[Epoch 2 Batch 1250/2344] loss=3.7580, lr=0.0000342557, acc=0.863\n",
"[Epoch 2 Batch 1300/2344] loss=3.9778, lr=0.0000339595, acc=0.863\n",
"[Epoch 2 Batch 1350/2344] loss=3.8423, lr=0.0000336632, acc=0.863\n",
"[Epoch 2 Batch 1400/2344] loss=3.6291, lr=0.0000333669, acc=0.864\n",
"[Epoch 2 Batch 1450/2344] loss=3.7309, lr=0.0000330706, acc=0.864\n",
"[Epoch 2 Batch 1500/2344] loss=3.7647, lr=0.0000327744, acc=0.864\n",
"[Epoch 2 Batch 1550/2344] loss=3.5611, lr=0.0000324781, acc=0.865\n",
"[Epoch 2 Batch 1600/2344] loss=3.7591, lr=0.0000321818, acc=0.865\n",
"[Epoch 2 Batch 1650/2344] loss=3.6662, lr=0.0000318855, acc=0.865\n",
"[Epoch 2 Batch 1700/2344] loss=3.6121, lr=0.0000315892, acc=0.866\n",
"[Epoch 2 Batch 1750/2344] loss=3.9010, lr=0.0000312930, acc=0.866\n",
"[Epoch 2 Batch 1800/2344] loss=3.5127, lr=0.0000309967, acc=0.866\n",
"[Epoch 2 Batch 1850/2344] loss=3.5795, lr=0.0000307004, acc=0.867\n",
"[Epoch 2 Batch 1900/2344] loss=3.7448, lr=0.0000304041, acc=0.867\n",
"[Epoch 2 Batch 1950/2344] loss=3.6231, lr=0.0000301078, acc=0.867\n",
"[Epoch 2 Batch 2000/2344] loss=3.1706, lr=0.0000298116, acc=0.868\n",
"[Epoch 2 Batch 2050/2344] loss=3.6858, lr=0.0000295153, acc=0.868\n",
"[Epoch 2 Batch 2100/2344] loss=3.4976, lr=0.0000292190, acc=0.868\n",
"[Epoch 2 Batch 2150/2344] loss=3.8168, lr=0.0000289227, acc=0.868\n",
"[Epoch 2 Batch 2200/2344] loss=3.4934, lr=0.0000286265, acc=0.869\n",
"[Epoch 2 Batch 2250/2344] loss=3.5244, lr=0.0000283302, acc=0.869\n",
"[Epoch 2 Batch 2300/2344] loss=3.1572, lr=0.0000280339, acc=0.869\n",
"Test Acc : 0.86312\n",
"[Epoch 3 Batch 50/2344] loss=3.4649, lr=0.0000274769, acc=0.885\n",
"[Epoch 3 Batch 100/2344] loss=3.4550, lr=0.0000271806, acc=0.885\n",
"[Epoch 3 Batch 150/2344] loss=3.3320, lr=0.0000268843, acc=0.886\n",
"[Epoch 3 Batch 200/2344] loss=3.2314, lr=0.0000265881, acc=0.887\n",
"[Epoch 3 Batch 250/2344] loss=3.4850, lr=0.0000262918, acc=0.886\n",
"[Epoch 3 Batch 300/2344] loss=3.2321, lr=0.0000259955, acc=0.887\n",
"[Epoch 3 Batch 350/2344] loss=3.1810, lr=0.0000256992, acc=0.888\n",
"[Epoch 3 Batch 400/2344] loss=3.3420, lr=0.0000254029, acc=0.888\n",
"[Epoch 3 Batch 450/2344] loss=3.0145, lr=0.0000251067, acc=0.888\n",
"[Epoch 3 Batch 500/2344] loss=3.3181, lr=0.0000248104, acc=0.888\n",
"[Epoch 3 Batch 550/2344] loss=3.1762, lr=0.0000245141, acc=0.889\n",
"[Epoch 3 Batch 600/2344] loss=2.9531, lr=0.0000242178, acc=0.890\n",
"[Epoch 3 Batch 650/2344] loss=3.0685, lr=0.0000239215, acc=0.891\n",
"[Epoch 3 Batch 700/2344] loss=2.8686, lr=0.0000236253, acc=0.891\n",
"[Epoch 3 Batch 750/2344] loss=3.2135, lr=0.0000233290, acc=0.892\n",
"[Epoch 3 Batch 800/2344] loss=2.9537, lr=0.0000230327, acc=0.892\n",
"[Epoch 3 Batch 850/2344] loss=2.9990, lr=0.0000227364, acc=0.893\n",
"[Epoch 3 Batch 900/2344] loss=3.0884, lr=0.0000224402, acc=0.893\n",
"[Epoch 3 Batch 950/2344] loss=3.0450, lr=0.0000221439, acc=0.893\n",
"[Epoch 3 Batch 1000/2344] loss=3.0148, lr=0.0000218476, acc=0.893\n",
"[Epoch 3 Batch 1050/2344] loss=2.6822, lr=0.0000215513, acc=0.894\n",
"[Epoch 3 Batch 1100/2344] loss=3.2307, lr=0.0000212550, acc=0.894\n",
"[Epoch 3 Batch 1150/2344] loss=2.8391, lr=0.0000209588, acc=0.894\n",
"[Epoch 3 Batch 1200/2344] loss=2.8417, lr=0.0000206625, acc=0.895\n",
"[Epoch 3 Batch 1250/2344] loss=3.1252, lr=0.0000203662, acc=0.895\n",
"[Epoch 3 Batch 1300/2344] loss=3.0921, lr=0.0000200699, acc=0.895\n",
"[Epoch 3 Batch 1350/2344] loss=3.1037, lr=0.0000197736, acc=0.895\n",
"[Epoch 3 Batch 1400/2344] loss=2.8025, lr=0.0000194774, acc=0.896\n",
"[Epoch 3 Batch 1450/2344] loss=2.9914, lr=0.0000191811, acc=0.896\n",
"[Epoch 3 Batch 1500/2344] loss=3.0012, lr=0.0000188848, acc=0.896\n",
"[Epoch 3 Batch 1550/2344] loss=2.7741, lr=0.0000185885, acc=0.897\n",
"[Epoch 3 Batch 1600/2344] loss=3.0045, lr=0.0000182922, acc=0.897\n",
"[Epoch 3 Batch 1650/2344] loss=2.7965, lr=0.0000179960, acc=0.898\n",
"[Epoch 3 Batch 1700/2344] loss=2.7917, lr=0.0000176997, acc=0.898\n",
"[Epoch 3 Batch 1750/2344] loss=2.9973, lr=0.0000174034, acc=0.898\n",
"[Epoch 3 Batch 1800/2344] loss=2.8359, lr=0.0000171071, acc=0.898\n",
"[Epoch 3 Batch 1850/2344] loss=2.6977, lr=0.0000168109, acc=0.899\n",
"[Epoch 3 Batch 1900/2344] loss=2.9400, lr=0.0000165146, acc=0.899\n",
"[Epoch 3 Batch 1950/2344] loss=2.9390, lr=0.0000162183, acc=0.899\n",
"[Epoch 3 Batch 2000/2344] loss=2.6226, lr=0.0000159220, acc=0.900\n",
"[Epoch 3 Batch 2050/2344] loss=2.9510, lr=0.0000156257, acc=0.900\n",
"[Epoch 3 Batch 2100/2344] loss=2.7680, lr=0.0000153295, acc=0.900\n",
"[Epoch 3 Batch 2150/2344] loss=3.0785, lr=0.0000150332, acc=0.900\n",
"[Epoch 3 Batch 2200/2344] loss=2.6541, lr=0.0000147369, acc=0.901\n",
"[Epoch 3 Batch 2250/2344] loss=2.8383, lr=0.0000144406, acc=0.901\n",
"[Epoch 3 Batch 2300/2344] loss=2.4659, lr=0.0000141443, acc=0.901\n",
"Test Acc : 0.86686\n",
"[Epoch 4 Batch 50/2344] loss=2.6485, lr=0.0000135873, acc=0.919\n",
"[Epoch 4 Batch 100/2344] loss=2.6904, lr=0.0000132911, acc=0.916\n",
"[Epoch 4 Batch 150/2344] loss=2.6460, lr=0.0000129948, acc=0.915\n",
"[Epoch 4 Batch 200/2344] loss=2.5486, lr=0.0000126985, acc=0.915\n",
"[Epoch 4 Batch 250/2344] loss=2.6714, lr=0.0000124022, acc=0.915\n",
"[Epoch 4 Batch 300/2344] loss=2.4965, lr=0.0000121059, acc=0.916\n",
"[Epoch 4 Batch 350/2344] loss=2.5397, lr=0.0000118097, acc=0.917\n",
"[Epoch 4 Batch 400/2344] loss=2.6806, lr=0.0000115134, acc=0.916\n",
"[Epoch 4 Batch 450/2344] loss=2.4205, lr=0.0000112171, acc=0.916\n",
"[Epoch 4 Batch 500/2344] loss=2.6458, lr=0.0000109208, acc=0.916\n",
"[Epoch 4 Batch 550/2344] loss=2.4386, lr=0.0000106246, acc=0.916\n",
"[Epoch 4 Batch 600/2344] loss=2.1655, lr=0.0000103283, acc=0.918\n",
"[Epoch 4 Batch 650/2344] loss=2.3844, lr=0.0000100320, acc=0.919\n",
"[Epoch 4 Batch 700/2344] loss=2.1457, lr=0.0000097357, acc=0.919\n",
"[Epoch 4 Batch 750/2344] loss=2.5092, lr=0.0000094394, acc=0.919\n",
"[Epoch 4 Batch 800/2344] loss=2.3039, lr=0.0000091432, acc=0.919\n",
"[Epoch 4 Batch 850/2344] loss=2.5299, lr=0.0000088469, acc=0.919\n",
"[Epoch 4 Batch 900/2344] loss=2.4091, lr=0.0000085506, acc=0.919\n",
"[Epoch 4 Batch 950/2344] loss=2.4068, lr=0.0000082543, acc=0.919\n",
"[Epoch 4 Batch 1000/2344] loss=2.2997, lr=0.0000079580, acc=0.920\n",
"[Epoch 4 Batch 1050/2344] loss=1.9540, lr=0.0000076618, acc=0.920\n",
"[Epoch 4 Batch 1100/2344] loss=2.6630, lr=0.0000073655, acc=0.920\n",
"[Epoch 4 Batch 1150/2344] loss=2.2043, lr=0.0000070692, acc=0.921\n",
"[Epoch 4 Batch 1200/2344] loss=2.3991, lr=0.0000067729, acc=0.921\n",
"[Epoch 4 Batch 1250/2344] loss=2.4776, lr=0.0000064767, acc=0.921\n",
"[Epoch 4 Batch 1300/2344] loss=2.3768, lr=0.0000061804, acc=0.921\n",
"[Epoch 4 Batch 1350/2344] loss=2.3717, lr=0.0000058841, acc=0.921\n",
"[Epoch 4 Batch 1400/2344] loss=2.2483, lr=0.0000055878, acc=0.922\n",
"[Epoch 4 Batch 1450/2344] loss=2.3297, lr=0.0000052915, acc=0.922\n",
"[Epoch 4 Batch 1500/2344] loss=2.2679, lr=0.0000049953, acc=0.922\n",
"[Epoch 4 Batch 1550/2344] loss=2.0294, lr=0.0000046990, acc=0.922\n",
"[Epoch 4 Batch 1600/2344] loss=2.5505, lr=0.0000044027, acc=0.923\n",
"[Epoch 4 Batch 1650/2344] loss=2.1041, lr=0.0000041064, acc=0.923\n",
"[Epoch 4 Batch 1700/2344] loss=2.1665, lr=0.0000038101, acc=0.923\n",
"[Epoch 4 Batch 1750/2344] loss=2.4083, lr=0.0000035139, acc=0.923\n",
"[Epoch 4 Batch 1800/2344] loss=2.3939, lr=0.0000032176, acc=0.923\n",
"[Epoch 4 Batch 1850/2344] loss=2.2320, lr=0.0000029213, acc=0.924\n",
"[Epoch 4 Batch 1900/2344] loss=2.2836, lr=0.0000026250, acc=0.924\n",
"[Epoch 4 Batch 1950/2344] loss=2.3638, lr=0.0000023288, acc=0.924\n",
"[Epoch 4 Batch 2000/2344] loss=2.0511, lr=0.0000020325, acc=0.924\n",
"[Epoch 4 Batch 2050/2344] loss=2.3468, lr=0.0000017362, acc=0.924\n",
"[Epoch 4 Batch 2100/2344] loss=2.1864, lr=0.0000014399, acc=0.924\n",
"[Epoch 4 Batch 2150/2344] loss=2.5553, lr=0.0000011436, acc=0.924\n",
"[Epoch 4 Batch 2200/2344] loss=2.2028, lr=0.0000008474, acc=0.924\n",
"[Epoch 4 Batch 2250/2344] loss=2.3454, lr=0.0000005511, acc=0.924\n",
"[Epoch 4 Batch 2300/2344] loss=2.0258, lr=0.0000002548, acc=0.925\n",
"Test Acc : 0.87136\n"
]
}
],
"source": [
"for epoch_id in range(num_epochs):\n",
" metric.reset()\n",
" step_loss = 0\n",
" for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(train_dataloader):\n",
" step_num += 1\n",
" if step_num < num_warmup_steps:\n",
" new_lr = lr * step_num / num_warmup_steps\n",
" else:\n",
" offset = (step_num - num_warmup_steps) * lr / (\n",
" num_train_steps - num_warmup_steps)\n",
" new_lr = lr - offset\n",
" trainer.set_learning_rate(new_lr)\n",
" with mx.autograd.record():\n",
" # load data to GPU\n",
" token_ids = token_ids.as_in_context(ctx)\n",
" valid_length = valid_length.as_in_context(ctx)\n",
" segment_ids = segment_ids.as_in_context(ctx)\n",
" label = label.as_in_context(ctx)\n",
"\n",
" # forward computation\n",
" out = model(token_ids, segment_ids, valid_length.astype('float32'))\n",
" ls = loss_function(out, label).mean()\n",
"\n",
" # backward computation\n",
" ls.backward()\n",
" trainer.allreduce_grads()\n",
" nlp.utils.clip_grad_global_norm(params, 1)\n",
" trainer.update(token_ids.shape[0])\n",
"\n",
" step_loss += ls.asscalar()\n",
" metric.update([label], [out])\n",
" if (batch_id + 1) % (50) == 0:\n",
" print('[Epoch {} Batch {}/{}] loss={:.4f}, lr={:.10f}, acc={:.3f}'\n",
" .format(epoch_id + 1, batch_id + 1, len(train_dataloader),\n",
" step_loss / log_interval,\n",
" trainer.learning_rate, metric.get()[1]))\n",
" step_loss = 0\n",
" test_acc = evaluate_accuracy(model, test_dataloader, ctx)\n",
" print('Test Acc : {}'.format(test_acc))"
]
}
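,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Closing sketch (hypothetical usage; `infer_transform` and the `sample` review text are introduced only for this demo): a new review can be classified by applying the same single-sentence `BERTSentenceTransform` and taking the argmax over the two logits, where label 1 corresponds to a positive review in NSMC."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical inference example with the fine-tuned classifier.\n",
"infer_transform = nlp.data.BERTSentenceTransform(\n",
"    bert_tokenizer, max_seq_length=max_len, pad=True, pair=False)\n",
"sample = '정말 재미있게 봤습니다'  # hypothetical review text\n",
"token_ids, valid_length, segment_ids = infer_transform([sample])\n",
"out = model(mx.nd.array([token_ids], ctx=ctx),\n",
"            mx.nd.array([segment_ids], ctx=ctx),\n",
"            mx.nd.array([valid_length], ctx=ctx).astype('float32'))\n",
"print('positive' if out.argmax(axis=1).asscalar() == 1 else 'negative')"
]
}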
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
},
"toc": {
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"toc_cell": false,
"toc_position": {},
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@haven-jeon (Author) commented:

pip install --pre --upgrade mxnet
#pip install --pre --upgrade mxnet-cu90
pip install -U https://github.com/dmlc/gluon-nlp/archive/master.zip
