@jamescalam
Last active April 13, 2024 11:46
{
"cells": [
{
"cell_type": "markdown",
"id": "narrative-warner",
"metadata": {},
"source": [
"# Fine-Tuning With SQuAD 2.0"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "theoretical-confirmation",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import requests\n",
"import json"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "promising-stocks",
"metadata": {},
"outputs": [],
"source": [
"if not os.path.exists('../data/benchmarks/squad'):\n",
" os.mkdir('../data/benchmarks/squad')"
]
},
{
"cell_type": "markdown",
"id": "beautiful-composition",
"metadata": {},
"source": [
"---\n",
"# Get and Prepare Data\n",
"\n",
"## Download SQuAD data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "laughing-novel",
"metadata": {},
"outputs": [],
"source": [
"url = 'https://rajpurkar.github.io/SQuAD-explorer/dataset/'\n",
"res = requests.get(f'{url}train-v2.0.json')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "progressive-spice",
"metadata": {},
"outputs": [],
"source": [
"for file in ['train-v2.0.json', 'dev-v2.0.json']:\n",
" res = requests.get(f'{url}{file}')\n",
" # write to file\n",
" with open(f'../data/benchmarks/squad/{file}', 'wb') as f:\n",
" for chunk in res.iter_content(chunk_size=4):\n",
" f.write(chunk)"
]
},
{
"cell_type": "markdown",
"id": "fleet-rolling",
"metadata": {},
"source": [
"## Read"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "touched-terry",
"metadata": {},
"outputs": [],
"source": [
"def read_squad(path):\n",
" with open(path, 'rb') as f:\n",
" squad_dict = json.load(f)\n",
"\n",
" # initialize lists for contexts, questions, and answers\n",
" contexts = []\n",
" questions = []\n",
" answers = []\n",
" # iterate through all data in squad data\n",
" for group in squad_dict['data']:\n",
" for passage in group['paragraphs']:\n",
" context = passage['context']\n",
" for qa in passage['qas']:\n",
" question = qa['question']\n",
" if 'plausible_answers' in qa.keys():\n",
" access = 'plausible_answers'\n",
" else:\n",
" access = 'answers'\n",
" for answer in qa['answers']:\n",
" # append data to lists\n",
" contexts.append(context)\n",
" questions.append(question)\n",
" answers.append(answer)\n",
" # return formatted data lists\n",
" return contexts, questions, answers"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "little-treasury",
"metadata": {},
"outputs": [],
"source": [
"train_contexts, train_questions, train_answers = read_squad('../data/benchmarks/squad/train-v2.0.json')\n",
"val_contexts, val_questions, val_answers = read_squad('../data/benchmarks/squad/dev-v2.0.json')"
]
},
{
"cell_type": "markdown",
"id": "handed-zealand",
"metadata": {},
"source": [
"## Prepare"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "precious-windows",
"metadata": {},
"outputs": [],
"source": [
"def add_end_idx(answers, contexts):\n",
" # loop through each answer-context pair\n",
" for answer, context in zip(answers, contexts):\n",
" # gold_text refers to the answer we are expecting to find in context\n",
" gold_text = answer['text']\n",
" # we already know the start index\n",
" start_idx = answer['answer_start']\n",
" # and ideally this would be the end index...\n",
" end_idx = start_idx + len(gold_text)\n",
"\n",
" # ...however, sometimes squad answers are off by a character or two\n",
" if context[start_idx:end_idx] == gold_text:\n",
" # if the answer is not off :)\n",
" answer['answer_end'] = end_idx\n",
" else:\n",
" for n in [1, 2]:\n",
" if context[start_idx-n:end_idx-n] == gold_text:\n",
" # this means the answer is off by 'n' tokens\n",
" answer['answer_start'] = start_idx - n\n",
" answer['answer_end'] = end_idx - n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "detailed-karen",
"metadata": {},
"outputs": [],
"source": [
"add_end_idx(train_answers, train_contexts)\n",
"add_end_idx(val_answers, val_contexts)"
]
},
{
"cell_type": "markdown",
"id": "cheap-pharmacology",
"metadata": {},
"source": [
"## Encode"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "voluntary-effect",
"metadata": {},
"outputs": [],
"source": [
"from transformers import DistilBertTokenizerFast\n",
"tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')\n",
"\n",
"train_encodings = tokenizer(train_contexts, train_questions, truncation=True, padding=True)\n",
"val_encodings = tokenizer(val_contexts, val_questions, truncation=True, padding=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "behind-technician",
"metadata": {},
"outputs": [],
"source": [
"def add_token_positions(encodings, answers):\n",
" # initialize lists to contain the token indices of answer start/end\n",
" start_positions = []\n",
" end_positions = []\n",
" for i in range(len(answers)):\n",
" # append start/end token position using char_to_token method\n",
" start_positions.append(encodings.char_to_token(i, answers[i]['answer_start']))\n",
" end_positions.append(encodings.char_to_token(i, answers[i]['answer_end']))\n",
"\n",
" # if start position is None, the answer passage has been truncated\n",
" if start_positions[-1] is None:\n",
" start_positions[-1] = tokenizer.model_max_length\n",
" # end position cannot be found, char_to_token found space, so shift one token forward\n",
" go_back = 1\n",
" while end_positions[-1] is None:\n",
" end_positions[-1] = encodings.char_to_token(i, answers[i]['answer_end']-go_back)\n",
" go_back +=1\n",
" # update our encodings object with the new token-based start/end positions\n",
" encodings.update({'start_positions': start_positions, 'end_positions': end_positions})\n",
"\n",
"# apply function to our data\n",
"add_token_positions(train_encodings, train_answers)\n",
"add_token_positions(val_encodings, val_answers)"
]
},
{
"cell_type": "markdown",
"id": "specified-daughter",
"metadata": {},
"source": [
"---\n",
"\n",
"# PyTorch Fine-tuning"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "recognized-proceeding",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"class SquadDataset(torch.utils.data.Dataset):\n",
" def __init__(self, encodings):\n",
" self.encodings = encodings\n",
"\n",
" def __getitem__(self, idx):\n",
" return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
"\n",
" def __len__(self):\n",
" return len(self.encodings.input_ids)\n",
"\n",
"train_dataset = SquadDataset(train_encodings)\n",
"val_dataset = SquadDataset(val_encodings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "alive-qatar",
"metadata": {},
"outputs": [],
"source": [
"from transformers import DistilBertForQuestionAnswering\n",
"model = DistilBertForQuestionAnswering.from_pretrained(\"distilbert-base-uncased\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "spectacular-course",
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"from transformers import AdamW\n",
"from tqdm import tqdm\n",
"\n",
"# setup GPU/CPU\n",
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
"# move model over to detected device\n",
"model.to(device)\n",
"# activate training mode of model\n",
"model.train()\n",
"# initialize adam optimizer with weight decay (reduces chance of overfitting)\n",
"optim = AdamW(model.parameters(), lr=5e-5)\n",
"\n",
"# initialize data loader for training data\n",
"train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)\n",
"\n",
"for epoch in range(3):\n",
" # set model to train mode\n",
" model.train()\n",
" # setup loop (we use tqdm for the progress bar)\n",
" loop = tqdm(train_loader, leave=True)\n",
" for batch in loop:\n",
" # initialize calculated gradients (from prev step)\n",
" optim.zero_grad()\n",
" # pull all the tensor batches required for training\n",
" input_ids = batch['input_ids'].to(device)\n",
" attention_mask = batch['attention_mask'].to(device)\n",
" start_positions = batch['start_positions'].to(device)\n",
" end_positions = batch['end_positions'].to(device)\n",
" # train model on batch and return outputs (incl. loss)\n",
" outputs = model(input_ids, attention_mask=attention_mask,\n",
" start_positions=start_positions,\n",
" end_positions=end_positions)\n",
" # extract loss\n",
" loss = outputs[0]\n",
" # calculate loss for every parameter that needs grad update\n",
" loss.backward()\n",
" # update parameters\n",
" optim.step()\n",
" # print relevant info to progress bar\n",
" loop.set_description(f'Epoch {epoch}')\n",
" loop.set_postfix(loss=loss.item())"
]
},
{
"cell_type": "markdown",
"id": "proper-recruitment",
"metadata": {},
"source": [
"## Save Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "severe-brooks",
"metadata": {},
"outputs": [],
"source": [
"model_path = 'models/distilbert-custom'\n",
"model.save_pretrained(model_path)\n",
"tokenizer.save_pretrained(model_path)"
]
},
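{
"cell_type": "markdown",
"id": "reload-sanity-note",
"metadata": {},
"source": [
"As a quick sanity check (a minimal sketch, assuming the save cell above ran successfully), the fine-tuned model and tokenizer can be reloaded from `model_path` with `from_pretrained` before moving on to evaluation."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "reload-sanity",
"metadata": {},
"outputs": [],
"source": [
"# reload the fine-tuned model and tokenizer from disk to confirm the save worked\n",
"model = DistilBertForQuestionAnswering.from_pretrained(model_path)\n",
"tokenizer = DistilBertTokenizerFast.from_pretrained(model_path)\n",
"# move the reloaded model back onto the detected device before evaluating\n",
"model.to(device)"
]
},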
{
"cell_type": "code",
"execution_count": null,
"id": "otherwise-religion",
"metadata": {},
"outputs": [],
"source": [
"# switch model out of training mode\n",
"model.eval()\n",
"\n",
"#val_sampler = SequentialSampler(val_dataset)\n",
"val_loader = DataLoader(val_dataset, batch_size=16)\n",
"\n",
"acc = []\n",
"\n",
"# initialize loop for progress bar\n",
"loop = tqdm(val_loader)\n",
"# loop through batches\n",
"for batch in loop:\n",
" # we don't need to calculate gradients as we're not training\n",
" with torch.no_grad():\n",
" # pull batched items from loader\n",
" input_ids = batch['input_ids'].to(device)\n",
" attention_mask = batch['attention_mask'].to(device)\n",
" start_true = batch['start_positions'].to(device)\n",
" end_true = batch['end_positions'].to(device)\n",
" # make predictions\n",
" outputs = model(input_ids, attention_mask=attention_mask)\n",
" # pull preds out\n",
" start_pred = torch.argmax(outputs['start_logits'], dim=1)\n",
" end_pred = torch.argmax(outputs['end_logits'], dim=1)\n",
" # calculate accuracy for both and append to accuracy list\n",
" acc.append(((start_pred == start_true).sum()/len(start_pred)).item())\n",
" acc.append(((end_pred == end_true).sum()/len(end_pred)).item())\n",
"# calculate average accuracy in total\n",
"acc = sum(acc)/len(acc)"
]
},
{
"cell_type": "code",
"execution_count": 96,
"id": "pressed-request",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"T/F\tstart\tend\n",
"\n",
"true\t194\t196\n",
"pred\t187\t196\n",
"\n",
"true\t194\t196\n",
"pred\t187\t196\n",
"\n",
"true\t194\t196\n",
"pred\t187\t196\n",
"\n",
"true\t20\t21\n",
"pred\t20\t32\n",
"\n",
"true\t20\t21\n",
"pred\t20\t32\n",
"\n",
"true\t20\t21\n",
"pred\t20\t32\n",
"\n",
"true\t13\t14\n",
"pred\t20\t32\n",
"\n",
"true\t10\t14\n",
"pred\t20\t32\n",
"\n"
]
}
],
"source": [
"print(\"T/F\\tstart\\tend\\n\")\n",
"for i in range(len(start_true)):\n",
" print(f\"true\\t{start_true[i]}\\t{end_true[i]}\\n\"\n",
" f\"pred\\t{start_pred[i]}\\t{end_pred[i]}\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "NLP",
"language": "python",
"name": "env"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@PiGnotus

Hi there, and thank you for your awesome example!
While trying to run this in Colab (free edition), my session keeps crashing on the fine-tuning block because I run out of RAM.
Any ideas on how to solve this issue?
Thanks in advance, and again thank you for your great example!

@k-praveen-trellis

Reduce the batch size; a smaller batch needs less memory per training step.
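
For example (a rough sketch against the training loop in the notebook, assuming the same train_dataset, model, optim, device, and tqdm import), you can lower batch_size and use gradient accumulation so the effective batch size stays at 16:

accum_steps = 4  # 4 batches of 4 examples = effective batch of 16
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

for epoch in range(3):
    model.train()
    optim.zero_grad()
    loop = tqdm(train_loader, leave=True)
    for step, batch in enumerate(loop):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask,
                        start_positions=start_positions,
                        end_positions=end_positions)
        # scale the loss so the accumulated gradients match a full-size batch
        loss = outputs.loss / accum_steps
        loss.backward()
        # only step the optimizer every accum_steps batches
        if (step + 1) % accum_steps == 0:
            optim.step()
            optim.zero_grad()
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item() * accum_steps)

If the crash is system RAM rather than GPU memory, tokenizing only a slice of train_contexts / train_questions first also helps confirm where the limit is.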

@ash-rulz

ash-rulz commented Dec 5, 2023

Is full fine-tuning happening here? I mean, are all the weights and biases of the base model being updated?

@dibyendubiswas1998

Yes, all of the base model's weights and biases are updated on the given dataset.
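
You can verify this with a small check on the same model object from the notebook (nothing in the notebook freezes any layers, so every parameter should require gradients):

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f'{trainable:,} of {total:,} parameters are trainable')

If you only wanted to train the QA head instead, you could freeze the DistilBERT body first, e.g. `for param in model.distilbert.parameters(): param.requires_grad = False`, but the notebook as written updates everything.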
