Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save stas00/2c711f43422f8469f188bf302914bfa5 to your computer and use it in GitHub Desktop.
Save stas00/2c711f43422f8469f188bf302914bfa5 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
"_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5"
},
"outputs": [],
"source": [
"#import os\n",
"#os.environ['CUDA_LAUNCH_BLOCKING'] = \"1\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
"_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5"
},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda', index=0)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"from torchtext.data import Field, TabularDataset, BucketIterator, Iterator\n",
"\n",
"import torch.nn as nn\n",
"from transformers import BertTokenizer, BertForSequenceClassification\n",
"\n",
"import torch.optim as optim\n",
"\n",
"from sklearn.metrics import accuracy_score, classification_report, confusion_matrix\n",
"import seaborn as sns\n",
"\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"device"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>title</th>\n",
" <th>text</th>\n",
" <th>label</th>\n",
" <th>titletext</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>You Can Smell Hillary’s Fear</td>\n",
" <td>Daniel Greenfield, a Shillman Journalism Fello...</td>\n",
" <td>FAKE</td>\n",
" <td>You Can Smell Hillary’s Fear Daniel Greenfield...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Watch The Exact Moment Paul Ryan Committed Pol...</td>\n",
" <td>Google Pinterest Digg Linkedin Reddit Stumbleu...</td>\n",
" <td>FAKE</td>\n",
" <td>Watch The Exact Moment Paul Ryan Committed Pol...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Kerry to go to Paris in gesture of sympathy</td>\n",
" <td>U.S. Secretary of State John F. Kerry said Mon...</td>\n",
" <td>REAL</td>\n",
" <td>Kerry to go to Paris in gesture of sympathy U....</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Bernie supporters on Twitter erupt in anger ag...</td>\n",
" <td>— Kaydee King (@KaydeeKing) November 9, 2016 T...</td>\n",
" <td>FAKE</td>\n",
" <td>Bernie supporters on Twitter erupt in anger ag...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>The Battle of New York: Why This Primary Matters</td>\n",
" <td>It's primary day in New York and front-runners...</td>\n",
" <td>REAL</td>\n",
" <td>The Battle of New York: Why This Primary Matte...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" title \\\n",
"0 You Can Smell Hillary’s Fear \n",
"1 Watch The Exact Moment Paul Ryan Committed Pol... \n",
"2 Kerry to go to Paris in gesture of sympathy \n",
"3 Bernie supporters on Twitter erupt in anger ag... \n",
"4 The Battle of New York: Why This Primary Matters \n",
"\n",
" text label \\\n",
"0 Daniel Greenfield, a Shillman Journalism Fello... FAKE \n",
"1 Google Pinterest Digg Linkedin Reddit Stumbleu... FAKE \n",
"2 U.S. Secretary of State John F. Kerry said Mon... REAL \n",
"3 — Kaydee King (@KaydeeKing) November 9, 2016 T... FAKE \n",
"4 It's primary day in New York and front-runners... REAL \n",
"\n",
" titletext \n",
"0 You Can Smell Hillary’s Fear Daniel Greenfield... \n",
"1 Watch The Exact Moment Paul Ryan Committed Pol... \n",
"2 Kerry to go to Paris in gesture of sympathy U.... \n",
"3 Bernie supporters on Twitter erupt in anger ag... \n",
"4 The Battle of New York: Why This Primary Matte... "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"news = pd.read_csv(\"./real-and-fake-news-dataset/news.csv\")\n",
"news.drop('Unnamed: 0', axis=1, inplace=True)\n",
"news['titletext'] = news['title'] + \" \" + news['text']\n",
"news.head()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# slice\n",
"news = news.head(500)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"news['titletext'] = news['titletext'].str.slice(0,128)\n",
"news['title'] = news['title'].str.slice(0,128)\n",
"news['text'] = news['text'].str.slice(0,128)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"news['label'] = news['label'].astype('category').cat.codes"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>title</th>\n",
" <th>text</th>\n",
" <th>label</th>\n",
" <th>titletext</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>You Can Smell Hillary’s Fear</td>\n",
" <td>Daniel Greenfield, a Shillman Journalism Fello...</td>\n",
" <td>0</td>\n",
" <td>You Can Smell Hillary’s Fear Daniel Greenfield...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Watch The Exact Moment Paul Ryan Committed Pol...</td>\n",
" <td>Google Pinterest Digg Linkedin Reddit Stumbleu...</td>\n",
" <td>0</td>\n",
" <td>Watch The Exact Moment Paul Ryan Committed Pol...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Kerry to go to Paris in gesture of sympathy</td>\n",
" <td>U.S. Secretary of State John F. Kerry said Mon...</td>\n",
" <td>1</td>\n",
" <td>Kerry to go to Paris in gesture of sympathy U....</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Bernie supporters on Twitter erupt in anger ag...</td>\n",
" <td>— Kaydee King (@KaydeeKing) November 9, 2016 T...</td>\n",
" <td>0</td>\n",
" <td>Bernie supporters on Twitter erupt in anger ag...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>The Battle of New York: Why This Primary Matters</td>\n",
" <td>It's primary day in New York and front-runners...</td>\n",
" <td>1</td>\n",
" <td>The Battle of New York: Why This Primary Matte...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" title \\\n",
"0 You Can Smell Hillary’s Fear \n",
"1 Watch The Exact Moment Paul Ryan Committed Pol... \n",
"2 Kerry to go to Paris in gesture of sympathy \n",
"3 Bernie supporters on Twitter erupt in anger ag... \n",
"4 The Battle of New York: Why This Primary Matters \n",
"\n",
" text label \\\n",
"0 Daniel Greenfield, a Shillman Journalism Fello... 0 \n",
"1 Google Pinterest Digg Linkedin Reddit Stumbleu... 0 \n",
"2 U.S. Secretary of State John F. Kerry said Mon... 1 \n",
"3 — Kaydee King (@KaydeeKing) November 9, 2016 T... 0 \n",
"4 It's primary day in New York and front-runners... 1 \n",
"\n",
" titletext \n",
"0 You Can Smell Hillary’s Fear Daniel Greenfield... \n",
"1 Watch The Exact Moment Paul Ryan Committed Pol... \n",
"2 Kerry to go to Paris in gesture of sympathy U.... \n",
"3 Bernie supporters on Twitter erupt in anger ag... \n",
"4 The Battle of New York: Why This Primary Matte... "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"news.head()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"X_train,X_test,y_train,y_test = train_test_split(news[['title','text','titletext']],news['label'],stratify=news['label'],test_size=0.3)\n",
"X_test,X_valid,y_test,y_valid = train_test_split(X_test,y_test,stratify=y_test,test_size=0.5)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Size: (350, 3) (350,)\n",
"Test Size: (75, 3) (75,)\n",
"Valid Size: (75, 3) (75,)\n"
]
}
],
"source": [
"print(f\"Train Size: {X_train.shape} {y_train.shape}\")\n",
"print(f\"Test Size: {X_test.shape} {y_test.shape}\")\n",
"print(f\"Valid Size: {X_valid.shape} {y_valid.shape}\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"X_train['label'] = y_train.values\n",
"X_test['label'] = y_test.values\n",
"X_valid['label'] = y_valid.values"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"X_train.to_csv(\"./real-and-fake-news-dataset/train.csv\", index=False)\n",
"X_test.to_csv(\"./real-and-fake-news-dataset/test.csv\", index=False)\n",
"X_valid.to_csv(\"./real-and-fake-news-dataset/valid.csv\", index=False)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0",
"_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a"
},
"outputs": [],
"source": [
"tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
"# Model parameter\n",
"MAX_SEQ_LEN = 128\n",
"PAD_INDEX = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)\n",
"UNK_INDEX = tokenizer.convert_tokens_to_ids(tokenizer.unk_token)\n",
"\n",
"label_field = Field(sequential=False, use_vocab=False, batch_first=True, dtype=torch.int)\n",
"text_field = Field(use_vocab=False, tokenize=tokenizer.encode, lower=False, include_lengths=False, batch_first=True, fix_length=MAX_SEQ_LEN, pad_token=PAD_INDEX, unk_token=UNK_INDEX)\n",
"fields = [('title', text_field), ('text', text_field), ('titletext', text_field), ('label', label_field),]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note:\n",
"<span style='color:Red'>In order to use BERT tokenizer with TorchText, we have to set use_vocab=False and tokenize=tokenizer.encode. This will let TorchText know that we will not be building our own vocabulary using our dataset from scratch, but instead, use the pre-trained BERT tokenizer and its corresponding word-to-index mapping."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 428 ms, sys: 0 ns, total: 428 ms\n",
"Wall time: 427 ms\n"
]
}
],
"source": [
"%%time\n",
"\n",
"# TabularDataset\n",
"\n",
"train, valid, test = TabularDataset.splits(path=\"./real-and-fake-news-dataset\", train='train.csv', validation='valid.csv', test='test.csv', format='CSV', fields=fields, skip_header=True)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# Iterators\n",
"\n",
"train_iter = BucketIterator(train, batch_size=16, sort_key=lambda x: len(x.text),\n",
" device=device, train=True, sort=True, sort_within_batch=True)\n",
"valid_iter = BucketIterator(valid, batch_size=16, sort_key=lambda x: len(x.text),\n",
" device=device, train=True, sort=True, sort_within_batch=True)\n",
"test_iter = Iterator(test, batch_size=16, device=device, train=False, shuffle=False, sort=False)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"x = next(iter(train_iter))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\n",
"[torchtext.data.batch.Batch of size 16]\n",
"\t[.title]:[torch.cuda.LongTensor of size 16x128 (GPU 0)]\n",
"\t[.text]:[torch.cuda.LongTensor of size 16x128 (GPU 0)]\n",
"\t[.titletext]:[torch.cuda.LongTensor of size 16x128 (GPU 0)]\n",
"\t[.label]:[torch.cuda.IntTensor of size 16 (GPU 0)]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 101, 2076, 1996, ..., 0, 0, 0],\n",
" [ 101, 1996, 4883, ..., 0, 0, 0],\n",
" [ 101, 2343, 13857, ..., 0, 0, 0],\n",
" ...,\n",
" [ 101, 18520, 7207, ..., 0, 0, 0],\n",
" [ 101, 1996, 3784, ..., 0, 0, 0],\n",
" [ 101, 102, 0, ..., 0, 0, 0]], device='cuda:0')"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"tensor([[ 101, 8096, 24731, ..., 0, 0, 0],\n",
" [ 101, 15802, 2584, ..., 0, 0, 0],\n",
" [ 101, 8112, 14616, ..., 0, 0, 0],\n",
" ...,\n",
" [ 101, 18520, 17727, ..., 0, 0, 0],\n",
" [ 101, 8112, 1010, ..., 0, 0, 0],\n",
" [ 101, 1996, 8115, ..., 0, 0, 0]], device='cuda:0')"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"tensor([[ 101, 8096, 24731, ..., 0, 0, 0],\n",
" [ 101, 15802, 2584, ..., 0, 0, 0],\n",
" [ 101, 8112, 14616, ..., 0, 0, 0],\n",
" ...,\n",
" [ 101, 18520, 17727, ..., 0, 0, 0],\n",
" [ 101, 8112, 1010, ..., 0, 0, 0],\n",
" [ 101, 1996, 8115, ..., 0, 0, 0]], device='cuda:0')"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0], device='cuda:0',\n",
" dtype=torch.int32)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x.text\n",
"x.titletext\n",
"x.title\n",
"x.label"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"class BERT(nn.Module):\n",
" def __init__(self):\n",
" super(BERT,self).__init__()\n",
" options_name = 'bert-base-uncased'\n",
" self.encoder = BertForSequenceClassification.from_pretrained(options_name)\n",
" \n",
" \n",
" def forward(self, text, label):\n",
" loss, text_fea = self.encoder(text, labels=label)[:2]\n",
"\n",
" return loss, text_fea"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Functions to Save, Load Checkpoint and Metrics"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# Save and Load Functions\n",
"\n",
"def save_checkpoint(save_path, model, valid_loss):\n",
" if save_path == None:\n",
" return\n",
" state_dict = {'model_state_dict': model.state_dict(),\n",
" 'valid_loss': valid_loss}\n",
" torch.save(state_dict, save_path)\n",
" print(f'Model saved to ==> {save_path}')\n",
"\n",
"def load_checkpoint(load_path, model):\n",
" if load_path==None:\n",
" return\n",
" \n",
" state_dict = torch.load(load_path, map_location=device)\n",
" print(f'Model loaded from <== {load_path}')\n",
" \n",
" model.load_state_dict(state_dict['model_state_dict'])\n",
" return state_dict['valid_loss']"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):\n",
" if save_path == None:\n",
" return\n",
" state_dict = {'train_loss_list': train_loss_list,\n",
" 'valid_loss_list': valid_loss_list,\n",
" 'global_steps_list': global_steps_list}\n",
" torch.save(state_dict, save_path)\n",
" print(f'Model saved to ==> {save_path}')\n",
"\n",
"\n",
"def load_metrics(load_path):\n",
" if load_path==None:\n",
" return\n",
" state_dict = torch.load(load_path, map_location=device)\n",
" print(f'Model loaded from <== {load_path}')\n",
" return state_dict['train_loss_list'], state_dict['valid_loss_list'], state_dict['global_steps_list']\n"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"def train(model, optimizer, critertion=nn.BCELoss(),train_loader=train_iter,valid_loader=valid_iter,num_epochs=5\n",
" ,eval_every = len(train_iter) // 2,file_path = \"\",best_valid_loss = float(\"Inf\")):\n",
" # initialize running values\n",
" running_loss = 0.0\n",
" valid_running_loss = 0.0\n",
" global_step = 0\n",
" train_loss_list = []\n",
" valid_loss_list = []\n",
" global_steps_list = []\n",
" \n",
" model.train()\n",
" for epoch in range(num_epochs):\n",
" for (title, text, titletext, labels), _ in train_loader:\n",
" labels = labels.type(torch.LongTensor) \n",
" labels = labels.to(device)\n",
" \n",
" titletext = titletext.type(torch.LongTensor) \n",
" titletext = titletext.to(device)\n",
" #print(labels.shape)\n",
" #print(titletext.shape)\n",
" \n",
" output = model(titletext, labels)\n",
" loss, _ = output\n",
" \n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" running_loss += loss.item()\n",
" global_step += 1\n",
" \n",
" if global_step % eval_every == 0:\n",
" model.eval()\n",
" with torch.no_grad(): \n",
" # validation loop\n",
" for (title, text, titletext, labels), _ in valid_loader:\n",
" labels = labels.type(torch.LongTensor) \n",
" labels = labels.to(device)\n",
" titletext = titletext.type(torch.LongTensor) \n",
" titletext = titletext.to(device)\n",
" output = model(titletext, labels)\n",
" loss, _ = output\n",
" \n",
" valid_running_loss += loss.item()\n",
" \n",
" # evaluation\n",
" average_train_loss = running_loss / eval_every\n",
" average_valid_loss = valid_running_loss / len(valid_loader)\n",
" train_loss_list.append(average_train_loss)\n",
" valid_loss_list.append(average_valid_loss)\n",
" global_steps_list.append(global_step)\n",
" \n",
" # resetting running values\n",
" running_loss = 0.0 \n",
" valid_running_loss = 0.0\n",
" model.train()\n",
"\n",
" # print progress\n",
" print('Epoch [{}/{}], Step [{}/{}], Train Loss: {:.4f}, Valid Loss: {:.4f}'\n",
" .format(epoch+1, num_epochs, global_step, num_epochs*len(train_loader),\n",
" average_train_loss, average_valid_loss))\n",
" \n",
" # checkpoint\n",
" if best_valid_loss > average_valid_loss:\n",
" best_valid_loss = average_valid_loss\n",
" #save_checkpoint(file_path + model.pt, model, best_valid_loss)\n",
" #save_metrics(file_path + metrics.pt, train_loss_list, valid_loss_list, global_steps_list)\n",
" \n",
" save_metrics(file_path + 'metrics.pt', train_loss_list, valid_loss_list, global_steps_list)\n",
" print('Finished Training!')"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n",
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [1/5], Step [11/110], Train Loss: 0.6670, Valid Loss: 0.7580\n",
"Epoch [1/5], Step [22/110], Train Loss: 0.7145, Valid Loss: 0.7330\n",
"Epoch [2/5], Step [33/110], Train Loss: 0.7198, Valid Loss: 0.7107\n",
"Epoch [2/5], Step [44/110], Train Loss: 0.7288, Valid Loss: 0.6965\n",
"Epoch [3/5], Step [55/110], Train Loss: 0.6772, Valid Loss: 0.7183\n",
"Epoch [3/5], Step [66/110], Train Loss: 0.7014, Valid Loss: 0.7165\n",
"Epoch [4/5], Step [77/110], Train Loss: 0.7007, Valid Loss: 0.6947\n",
"Epoch [4/5], Step [88/110], Train Loss: 0.7400, Valid Loss: 0.6920\n",
"Epoch [5/5], Step [99/110], Train Loss: 0.7089, Valid Loss: 0.6913\n",
"Epoch [5/5], Step [110/110], Train Loss: 0.7222, Valid Loss: 0.6897\n",
"Model saved to ==> metrics.pt\n",
"Finished Training!\n"
]
}
],
"source": [
"model = BERT().to(device)\n",
"optimizer = optim.Adam(model.parameters(), lr=2e-5)\n",
"\n",
"train(model=model, optimizer=optimizer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {
"height": "calc(100% - 180px)",
"left": "10px",
"top": "150px",
"width": "248px"
},
"toc_section_display": true,
"toc_window_display": true
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment