Last active
April 13, 2024 11:46
-
-
Save jamescalam/55daf50c8da9eb3a7c18de058bc139a3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "narrative-warner", | |
"metadata": {}, | |
"source": [ | |
"# Fine-Tuning With SQuAD 2.0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "theoretical-confirmation", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"import requests\n", | |
"import json" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "promising-stocks", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# create the download directory, including any missing parent directories;\n", | |
"# exist_ok=True makes re-running this cell a no-op\n", | |
"os.makedirs('../data/benchmarks/squad', exist_ok=True)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "beautiful-composition", | |
"metadata": {}, | |
"source": [ | |
"---\n", | |
"# Get and Prepare Data\n", | |
"\n", | |
"## Download SQuAD data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "laughing-novel", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# base URL of the SQuAD 2.0 files; the files themselves are downloaded in the next cell\n", | |
"# (a separate request for train-v2.0.json here would just download it twice)\n", | |
"url = 'https://rajpurkar.github.io/SQuAD-explorer/dataset/'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "progressive-spice", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"for file in ['train-v2.0.json', 'dev-v2.0.json']:\n", | |
"    # stream the response (so the body is not held in memory at once) and close\n", | |
"    # the connection when done\n", | |
"    with requests.get(f'{url}{file}', stream=True) as res:\n", | |
"        # fail loudly on HTTP errors instead of saving an error page as the dataset\n", | |
"        res.raise_for_status()\n", | |
"        # write to file in 1 MiB chunks (4-byte chunks made this extremely slow)\n", | |
"        with open(f'../data/benchmarks/squad/{file}', 'wb') as f:\n", | |
"            for chunk in res.iter_content(chunk_size=1024 * 1024):\n", | |
"                f.write(chunk)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "fleet-rolling", | |
"metadata": {}, | |
"source": [ | |
"## Read" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "touched-terry", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def read_squad(path):\n", | |
"    \"\"\"Load a SQuAD JSON file into parallel context/question/answer lists.\n", | |
"\n", | |
"    One entry is produced per answer. Questions that carry a 'plausible_answers'\n", | |
"    key (SQuAD 2.0) use those answers instead of the 'answers' list.\n", | |
"\n", | |
"    Args:\n", | |
"        path: path to a SQuAD-format JSON file.\n", | |
"\n", | |
"    Returns:\n", | |
"        (contexts, questions, answers): three lists of equal length; each answer\n", | |
"        is the answer dict exactly as stored in the file.\n", | |
"    \"\"\"\n", | |
"    with open(path, 'rb') as f:\n", | |
"        squad_dict = json.load(f)\n", | |
"\n", | |
"    # initialize lists for contexts, questions, and answers\n", | |
"    contexts = []\n", | |
"    questions = []\n", | |
"    answers = []\n", | |
"    # iterate through all data in squad data\n", | |
"    for group in squad_dict['data']:\n", | |
"        for passage in group['paragraphs']:\n", | |
"            context = passage['context']\n", | |
"            for qa in passage['qas']:\n", | |
"                question = qa['question']\n", | |
"                # questions with 'plausible_answers' (presumably the unanswerable\n", | |
"                # SQuAD 2.0 ones) are read from that key so they are not dropped\n", | |
"                access = 'plausible_answers' if 'plausible_answers' in qa else 'answers'\n", | |
"                for answer in qa[access]:\n", | |
"                    # append data to lists\n", | |
"                    contexts.append(context)\n", | |
"                    questions.append(question)\n", | |
"                    answers.append(answer)\n", | |
"    # return formatted data lists\n", | |
"    return contexts, questions, answers" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "little-treasury", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# parse both downloaded splits into parallel context/question/answer lists\n", | |
"squad_dir = '../data/benchmarks/squad'\n", | |
"train_contexts, train_questions, train_answers = read_squad(f'{squad_dir}/train-v2.0.json')\n", | |
"val_contexts, val_questions, val_answers = read_squad(f'{squad_dir}/dev-v2.0.json')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "handed-zealand", | |
"metadata": {}, | |
"source": [ | |
"## Prepare" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "precious-windows", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def add_end_idx(answers, contexts):\n", | |
"    \"\"\"Add an exclusive character index 'answer_end' to every answer dict, in place.\n", | |
"\n", | |
"    Some SQuAD 'answer_start' values are off by one or two characters. If the gold\n", | |
"    text is found 1 or 2 characters earlier, 'answer_start' is corrected as well.\n", | |
"    If it cannot be located, the unshifted end index is kept, so every answer\n", | |
"    still gets an 'answer_end' key (add_token_positions reads it unconditionally).\n", | |
"    \"\"\"\n", | |
"    # loop through each answer-context pair\n", | |
"    for answer, context in zip(answers, contexts):\n", | |
"        # gold_text refers to the answer we are expecting to find in context\n", | |
"        gold_text = answer['text']\n", | |
"        # we already know the start index\n", | |
"        start_idx = answer['answer_start']\n", | |
"        # and ideally this would be the end index...\n", | |
"        end_idx = start_idx + len(gold_text)\n", | |
"        # default to the unshifted end so 'answer_end' is always present\n", | |
"        answer['answer_end'] = end_idx\n", | |
"\n", | |
"        # ...however, sometimes squad answers are off by a character or two\n", | |
"        if context[start_idx:end_idx] != gold_text:\n", | |
"            for n in [1, 2]:\n", | |
"                # a negative start would wrap around to the end of the context\n", | |
"                if start_idx - n >= 0 and context[start_idx-n:end_idx-n] == gold_text:\n", | |
"                    # this means the answer is off by 'n' characters\n", | |
"                    answer['answer_start'] = start_idx - n\n", | |
"                    answer['answer_end'] = end_idx - n\n", | |
"                    # keep the smallest shift that matches\n", | |
"                    break" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "detailed-karen", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# add 'answer_end' to each answer dict of both splits (mutates the lists in place)\n", | |
"for split_answers, split_contexts in ((train_answers, train_contexts), (val_answers, val_contexts)):\n", | |
"    add_end_idx(split_answers, split_contexts)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "cheap-pharmacology", | |
"metadata": {}, | |
"source": [ | |
"## Encode" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "voluntary-effect", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from transformers import DistilBertTokenizerFast\n", | |
"# a fast tokenizer is used; add_token_positions relies on its char_to_token mapping\n", | |
"tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')\n", | |
"\n", | |
"# contexts are passed as the first sequence and questions as the second, so the\n", | |
"# character offsets stored in the answer dicts index into the first sequence\n", | |
"# NOTE(review): padding=True presumably pads every example to the longest sequence\n", | |
"# in the list; on the full training set this may dominate memory use — confirm\n", | |
"train_encodings = tokenizer(train_contexts, train_questions, truncation=True, padding=True)\n", | |
"val_encodings = tokenizer(val_contexts, val_questions, truncation=True, padding=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "behind-technician", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def add_token_positions(encodings, answers):\n", | |
"    \"\"\"Add token-level 'start_positions'/'end_positions' lists to `encodings`, in place.\n", | |
"\n", | |
"    Character offsets from each answer dict ('answer_start' and the exclusive\n", | |
"    'answer_end') are mapped to token indices with `encodings.char_to_token`.\n", | |
"    Positions that cannot be mapped (truncated away) are set to\n", | |
"    `tokenizer.model_max_length`, using the global `tokenizer` from the encode cell.\n", | |
"    \"\"\"\n", | |
"    # initialize lists to contain the token indices of answer start/end\n", | |
"    start_positions = []\n", | |
"    end_positions = []\n", | |
"    for i, answer in enumerate(answers):\n", | |
"        answer_start = answer['answer_start']\n", | |
"        answer_end = answer['answer_end']\n", | |
"        start = encodings.char_to_token(i, answer_start)\n", | |
"        # answer_end is exclusive: the last answer character is at answer_end - 1\n", | |
"        # (the character at answer_end can belong to a following token, e.g. punctuation)\n", | |
"        end = encodings.char_to_token(i, answer_end - 1)\n", | |
"\n", | |
"        # if start position is None, the answer passage has been truncated\n", | |
"        if start is None:\n", | |
"            start = tokenizer.model_max_length\n", | |
"        # the last character may map to no token (e.g. whitespace) or be truncated:\n", | |
"        # walk back towards the answer start, never past it, until one maps\n", | |
"        char_idx = answer_end - 2\n", | |
"        while end is None and char_idx >= answer_start:\n", | |
"            end = encodings.char_to_token(i, char_idx)\n", | |
"            char_idx -= 1\n", | |
"        # no character of the answer survived truncation\n", | |
"        if end is None:\n", | |
"            end = tokenizer.model_max_length\n", | |
"        start_positions.append(start)\n", | |
"        end_positions.append(end)\n", | |
"    # update our encodings object with the new token-based start/end positions\n", | |
"    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})\n", | |
"\n", | |
"# apply function to both splits (adds the position lists to each encodings object)\n", | |
"for split_encodings, split_answers in ((train_encodings, train_answers), (val_encodings, val_answers)):\n", | |
"    add_token_positions(split_encodings, split_answers)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "specified-daughter", | |
"metadata": {}, | |
"source": [ | |
"---\n", | |
"\n", | |
"# PyTorch Fine-tuning" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "recognized-proceeding", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"\n", | |
"class SquadDataset(torch.utils.data.Dataset):\n", | |
"    \"\"\"Map-style dataset over tokenizer output (a BatchEncoding or a plain dict of lists).\n", | |
"\n", | |
"    Each item is a dict mapping every encoding key (e.g. 'input_ids', 'attention_mask',\n", | |
"    'start_positions', 'end_positions') to a tensor holding that example's values.\n", | |
"    \"\"\"\n", | |
"    def __init__(self, encodings):\n", | |
"        self.encodings = encodings\n", | |
"\n", | |
"    def __getitem__(self, idx):\n", | |
"        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n", | |
"\n", | |
"    def __len__(self):\n", | |
"        # item access works for plain dicts as well as BatchEncoding\n", | |
"        return len(self.encodings['input_ids'])\n", | |
"\n", | |
"# wrap both encoded splits as torch datasets\n", | |
"train_dataset, val_dataset = (SquadDataset(enc) for enc in (train_encodings, val_encodings))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "alive-qatar", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from transformers import DistilBertForQuestionAnswering\n", | |
"# load DistilBERT with a question-answering head; the evaluation cell reads its\n", | |
"# 'start_logits' / 'end_logits' outputs\n", | |
"model = DistilBertForQuestionAnswering.from_pretrained(\"distilbert-base-uncased\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "spectacular-course", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from torch.utils.data import DataLoader\n", | |
"from tqdm import tqdm\n", | |
"\n", | |
"# training hyperparameters (lower BATCH_SIZE if the runtime runs out of memory)\n", | |
"BATCH_SIZE = 16\n", | |
"EPOCHS = 3\n", | |
"LEARNING_RATE = 5e-5\n", | |
"# 0.0 matches the default of the transformers.AdamW used previously; raise to regularize\n", | |
"WEIGHT_DECAY = 0.0\n", | |
"\n", | |
"# setup GPU/CPU\n", | |
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n", | |
"# move model over to detected device\n", | |
"model.to(device)\n", | |
"# torch.optim.AdamW replaces the deprecated transformers.AdamW\n", | |
"optim = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)\n", | |
"\n", | |
"# initialize data loader for training data\n", | |
"train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n", | |
"\n", | |
"for epoch in range(EPOCHS):\n", | |
"    # set model to train mode at the start of every epoch\n", | |
"    model.train()\n", | |
"    # setup loop (we use tqdm for the progress bar)\n", | |
"    loop = tqdm(train_loader, leave=True, desc=f'Epoch {epoch}')\n", | |
"    for batch in loop:\n", | |
"        # clear gradients accumulated by the previous step\n", | |
"        optim.zero_grad()\n", | |
"        # pull all the tensor batches required for training\n", | |
"        input_ids = batch['input_ids'].to(device)\n", | |
"        attention_mask = batch['attention_mask'].to(device)\n", | |
"        start_positions = batch['start_positions'].to(device)\n", | |
"        end_positions = batch['end_positions'].to(device)\n", | |
"        # train model on batch and return outputs (incl. loss)\n", | |
"        outputs = model(input_ids, attention_mask=attention_mask,\n", | |
"                        start_positions=start_positions,\n", | |
"                        end_positions=end_positions)\n", | |
"        # extract loss\n", | |
"        loss = outputs[0]\n", | |
"        # backpropagate: compute gradients for every parameter that requires grad\n", | |
"        loss.backward()\n", | |
"        # update parameters\n", | |
"        optim.step()\n", | |
"        # print relevant info to progress bar\n", | |
"        loop.set_postfix(loss=loss.item())" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "proper-recruitment", | |
"metadata": {}, | |
"source": [ | |
"## Save Model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "severe-brooks", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# output directory, relative to the notebook's working directory\n", | |
"model_path = 'models/distilbert-custom'\n", | |
"model.save_pretrained(model_path)\n", | |
"# save the tokenizer next to the weights, presumably so both can be reloaded from model_path\n", | |
"tokenizer.save_pretrained(model_path)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "otherwise-religion", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# switch model out of training mode\n", | |
"model.eval()\n", | |
"\n", | |
"val_loader = DataLoader(val_dataset, batch_size=16)\n", | |
"\n", | |
"# running counts of exactly-matched start/end positions and of positions scored\n", | |
"correct = 0\n", | |
"total = 0\n", | |
"\n", | |
"# initialize loop for progress bar\n", | |
"loop = tqdm(val_loader)\n", | |
"# loop through batches\n", | |
"for batch in loop:\n", | |
"    # we don't need to calculate gradients as we're not training\n", | |
"    with torch.no_grad():\n", | |
"        # pull batched items from loader\n", | |
"        input_ids = batch['input_ids'].to(device)\n", | |
"        attention_mask = batch['attention_mask'].to(device)\n", | |
"        start_true = batch['start_positions'].to(device)\n", | |
"        end_true = batch['end_positions'].to(device)\n", | |
"        # make predictions\n", | |
"        outputs = model(input_ids, attention_mask=attention_mask)\n", | |
"        # pull preds out\n", | |
"        start_pred = torch.argmax(outputs['start_logits'], dim=1)\n", | |
"        end_pred = torch.argmax(outputs['end_logits'], dim=1)\n", | |
"        # accumulate raw counts rather than per-batch means, so a smaller final\n", | |
"        # batch is not over-weighted in the overall average\n", | |
"        correct += (start_pred == start_true).sum().item()\n", | |
"        correct += (end_pred == end_true).sum().item()\n", | |
"        total += start_pred.numel() + end_pred.numel()\n", | |
"# average accuracy over every start and end position (nan if there was no data)\n", | |
"acc = correct / total if total else float('nan')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 96, | |
"id": "pressed-request", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"T/F\tstart\tend\n", | |
"\n", | |
"true\t194\t196\n", | |
"pred\t187\t196\n", | |
"\n", | |
"true\t194\t196\n", | |
"pred\t187\t196\n", | |
"\n", | |
"true\t194\t196\n", | |
"pred\t187\t196\n", | |
"\n", | |
"true\t20\t21\n", | |
"pred\t20\t32\n", | |
"\n", | |
"true\t20\t21\n", | |
"pred\t20\t32\n", | |
"\n", | |
"true\t20\t21\n", | |
"pred\t20\t32\n", | |
"\n", | |
"true\t13\t14\n", | |
"pred\t20\t32\n", | |
"\n", | |
"true\t10\t14\n", | |
"pred\t20\t32\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"# compare true vs predicted token positions for the last validation batch\n", | |
"print(\"T/F\\tstart\\tend\\n\")\n", | |
"for s_true, e_true, s_pred, e_pred in zip(start_true, end_true, start_pred, end_pred):\n", | |
"    row_true = f\"true\\t{s_true}\\t{e_true}\\n\"\n", | |
"    row_pred = f\"pred\\t{s_pred}\\t{e_pred}\\n\"\n", | |
"    print(row_true + row_pred)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "NLP", | |
"language": "python", | |
"name": "env" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Try reducing the batch size.
Is a full fine-tuning happening here? I mean are all the weights and biases of the base model getting updated here?
Yes. Every parameter is passed to the optimizer (`model.parameters()`), so all of the model's weights are updated on the given dataset.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi there and thank you for your awesome example!
While trying to run this in Colab (free tier), my session keeps crashing on the fine-tuning cell because it runs out of RAM.
Any ideas on how to solve this issue?
Thank you in advance, and again thank you for your great example!