Skip to content

Instantly share code, notes, and snippets.

@xhluca
Created October 8, 2022 04:57
Show Gist options
  • Save xhluca/28181468e3907145027969a1003ae929 to your computer and use it in GitHub Desktop.
Save xhluca/28181468e3907145027969a1003ae929 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "792293ef-bc9c-4955-a865-9a8084b7f05f",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"\n",
"import transformers as hft\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c7703674-614c-48b9-9a8d-02dc0f5c2f38",
"metadata": {},
"outputs": [],
"source": [
"class NQDataset(torch.utils.data.Dataset):\n",
"    \"\"\"Map-style dataset over a list of question records.\n",
"\n",
"    Each item is the tuple ``(question, title, text)``, where title and text\n",
"    come from the first entry of the record's ``positive_ctxs`` list.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, data: list):\n",
"        # Records are stored as given; every record must provide 'question'\n",
"        # and a non-empty 'positive_ctxs' list with 'title' and 'text' keys.\n",
"        self.data = data\n",
"\n",
"    def __len__(self):\n",
"        return len(self.data)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        record = self.data[idx]\n",
"        # Only the first positive context is used; any others are ignored.\n",
"        positive = record['positive_ctxs'][0]\n",
"        return (\n",
"            record['question'],\n",
"            positive['title'],\n",
"            positive['text'],\n",
"        )\n",
"\n",
"def get_schedule_linear(\n",
"    optimizer,\n",
"    warmup_steps,\n",
"    total_training_steps,\n",
"    steps_shift=0,\n",
"    last_epoch=-1,\n",
"):\n",
"    \"\"\"\n",
"    Create a schedule with a learning rate that decreases linearly after\n",
"    linearly increasing during a warmup period.\n",
"\n",
"    Source: https://github.com/facebookresearch/DPR/blob/1ee31c6c53/dpr/utils/model_utils.py\n",
"\n",
"    Args:\n",
"        optimizer: optimizer whose learning rate is scaled by the schedule.\n",
"        warmup_steps: number of steps over which the multiplier grows from 0 to 1.\n",
"        total_training_steps: step at which the linear decay reaches zero\n",
"            (the multiplier is floored at 1e-7).\n",
"        steps_shift: offset added to the scheduler's step counter before the\n",
"            multiplier is computed.\n",
"        last_epoch: forwarded unchanged to ``LambdaLR``.\n",
"\n",
"    Returns:\n",
"        A ``torch.optim.lr_scheduler.LambdaLR`` wrapping ``optimizer``.\n",
"    \"\"\"\n",
"\n",
"    def lr_lambda(current_step):\n",
"        current_step += steps_shift\n",
"        # Warmup phase: ramp the multiplier linearly up towards 1.\n",
"        if current_step < warmup_steps:\n",
"            return float(current_step) / float(max(1, warmup_steps))\n",
"        # Decay phase: linear ramp down, never below a tiny positive floor.\n",
"        return max(\n",
"            1e-7,\n",
"            float(total_training_steps - current_step) / float(max(1, total_training_steps - warmup_steps)),\n",
"        )\n",
"\n",
"    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)\n",
"\n",
"def criterion(S, target=None):\n",
"    \"\"\"Mean negative log-likelihood over the rows of a score matrix.\n",
"\n",
"    A log-softmax is taken along dim 1 of ``S``. When ``target`` is omitted,\n",
"    row ``i`` is expected to score highest at column ``i`` (the diagonal).\n",
"\n",
"    Args:\n",
"        S: score matrix; each row holds the scores for one example.\n",
"        target: optional tensor of correct column indices, one per row.\n",
"\n",
"    Returns:\n",
"        The loss averaged over rows, as a scalar tensor.\n",
"    \"\"\"\n",
"    log_probs = F.log_softmax(S, dim=1)\n",
"    if target is None:\n",
"        # Build the diagonal target directly on the device of the scores.\n",
"        target = torch.arange(S.shape[0], device=S.device)\n",
"\n",
"    # Caller-supplied targets may live on another device; move them over.\n",
"    target = target.to(log_probs.device)\n",
"    loss = F.nll_loss(log_probs, target, reduction=\"mean\")\n",
"\n",
"    return loss"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "66f6c6cb-91db-455e-beb6-7dccb1962537",
"metadata": {},
"outputs": [],
"source": [
"# Training hyperparameters\n",
"num_epochs = 40\n",
"batch_size = 128\n",
"max_length = 256  # max_length handed to both tokenizers\n",
"\n",
"# Optimizer and learning-rate schedule\n",
"learning_rate = 2e-5  # 2e-5 for NQ, 1e-5 for other datasets\n",
"warmup_steps = 1237\n",
"\n",
"# Use the GPU when one is available, otherwise fall back to the CPU.\n",
"if torch.cuda.is_available():\n",
"    device = torch.device(\"cuda\")\n",
"else:\n",
"    device = torch.device(\"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "40018b4e-db1e-4fa9-9f99-26292cd50119",
"metadata": {
"collapsed": true,
"jupyter": {
"outputs_hidden": true
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
}
],
"source": [
"# Questions and passages each get their own tokenizer and encoder, all\n",
"# initialised from the same pretrained checkpoint.\n",
"q_tokenizer = hft.AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
"q_encoder = hft.AutoModel.from_pretrained(\"bert-base-uncased\")\n",
"q_encoder = q_encoder.to(device)\n",
"\n",
"ctx_tokenizer = hft.AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
"ctx_encoder = hft.AutoModel.from_pretrained(\"bert-base-uncased\")\n",
"ctx_encoder = ctx_encoder.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "17fe0612-88d0-4013-88d7-29ba31d96873",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 33.6 s, sys: 11.3 s, total: 44.9 s\n",
"Wall time: 44.9 s\n"
]
}
],
"source": [
"%%time\n",
"# Open the files in `with` blocks so the handles are closed once parsing is\n",
"# done, and read them as UTF-8 regardless of the platform's default encoding.\n",
"with open('data/biencoder-nq-train.json', encoding='utf-8') as f:\n",
"    train_json = json.load(f)\n",
"with open('data/biencoder-nq-dev.json', encoding='utf-8') as f:\n",
"    valid_json = json.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "5cd6e088-b706-469c-b95c-4f664f1ed962",
"metadata": {},
"outputs": [],
"source": [
"# Reshuffle the training records every epoch and serve them in fixed-size batches.\n",
"train_loader = torch.utils.data.DataLoader(\n",
"    dataset=NQDataset(train_json),\n",
"    shuffle=True,\n",
"    batch_size=batch_size,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "beb5f7fd-9778-4e7e-8fd5-138041085fea",
"metadata": {},
"outputs": [],
"source": [
"# The schedule covers every batch of every epoch.\n",
"total_training_steps = num_epochs * len(train_loader)\n",
"\n",
"# A single optimizer updates both encoders; weight decay is disabled.\n",
"optimizer = torch.optim.AdamW(\n",
"    params=[\n",
"        {\"params\": ctx_encoder.parameters()},\n",
"        {\"params\": q_encoder.parameters()},\n",
"    ],\n",
"    weight_decay=0.0,\n",
"    lr=learning_rate,\n",
")\n",
"scheduler = get_schedule_linear(optimizer, warmup_steps, total_training_steps)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "5d99c605-32a8-43f8-9e06-53d4b2219318",
"metadata": {
"collapsed": true,
"jupyter": {
"outputs_hidden": true
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch #0: 100%|██████████| 460/460 [07:38<00:00, 1.00it/s, loss=4.66]\n",
"Epoch #1: 100%|██████████| 460/460 [07:38<00:00, 1.00it/s, loss=2.72]\n",
"Epoch #2: 100%|██████████| 460/460 [07:35<00:00, 1.01it/s, loss=0.653]\n",
"Epoch #3: 100%|██████████| 460/460 [07:36<00:00, 1.01it/s, loss=0.406]\n",
"Epoch #4: 100%|██████████| 460/460 [07:35<00:00, 1.01it/s, loss=0.381] \n",
"Epoch #5: 100%|██████████| 460/460 [07:35<00:00, 1.01it/s, loss=0.242] \n",
"Epoch #6: 100%|██████████| 460/460 [07:32<00:00, 1.02it/s, loss=0.224] \n",
"Epoch #7: 100%|██████████| 460/460 [07:34<00:00, 1.01it/s, loss=0.0726]\n",
"Epoch #8: 100%|██████████| 460/460 [07:34<00:00, 1.01it/s, loss=0.0708]\n",
"Epoch #9: 100%|██████████| 460/460 [07:33<00:00, 1.01it/s, loss=0.11] \n",
"Epoch #10: 100%|██████████| 460/460 [07:33<00:00, 1.01it/s, loss=0.0283] \n",
"Epoch #11: 100%|██████████| 460/460 [07:32<00:00, 1.02it/s, loss=0.0475] \n",
"Epoch #12: 100%|██████████| 460/460 [07:34<00:00, 1.01it/s, loss=0.0762] \n",
"Epoch #13: 73%|███████▎ | 334/460 [05:29<02:04, 1.01it/s, loss=0.019] IOPub message rate exceeded.\n",
"The Jupyter server will temporarily stop sending output\n",
"to the client in order to avoid crashing it.\n",
"To change this limit, set the config variable\n",
"`--ServerApp.iopub_msg_rate_limit`.\n",
"\n",
"Current values:\n",
"ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
"ServerApp.rate_limit_window=3.0 (secs)\n",
"\n",
"Epoch #25: 100%|██████████| 460/460 [07:34<00:00, 1.01it/s, loss=0.0115] \n",
"Epoch #26: 100%|██████████| 460/460 [07:33<00:00, 1.01it/s, loss=0.00752] \n",
"Epoch #27: 100%|██████████| 460/460 [07:33<00:00, 1.01it/s, loss=0.0207] \n",
"Epoch #28: 27%|██▋ | 125/460 [02:04<05:30, 1.01it/s, loss=0.00159] IOPub message rate exceeded.\n",
"The Jupyter server will temporarily stop sending output\n",
"to the client in order to avoid crashing it.\n",
"To change this limit, set the config variable\n",
"`--ServerApp.iopub_msg_rate_limit`.\n",
"\n",
"Current values:\n",
"ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
"ServerApp.rate_limit_window=3.0 (secs)\n",
"\n",
"Epoch #39: 100%|██████████| 460/460 [07:33<00:00, 1.01it/s, loss=0.000152]\n"
]
}
],
"source": [
"# Maps epoch number -> list of training losses recorded during that epoch.\n",
"results = {}\n",
"\n",
"# The tokenizer settings never change between batches, so build them once.\n",
"tokenizer_kwargs = dict(max_length=max_length, return_tensors=\"pt\", truncation=True, padding=True)\n",
"\n",
"for epoch_num in range(num_epochs):\n",
"    results[epoch_num] = []\n",
"\n",
"    ctx_encoder.train()\n",
"    q_encoder.train()\n",
"\n",
"    pbar = tqdm(train_loader, desc=f'Epoch #{epoch_num}')\n",
"\n",
"    # Each batch yields the questions, passage titles and passage texts.\n",
"    for q, t, c in pbar:\n",
"        queries = q_tokenizer(list(q), **tokenizer_kwargs).to(device)\n",
"        # Title and text are encoded together as a sentence pair.\n",
"        passages = ctx_tokenizer(list(t), list(c), **tokenizer_kwargs).to(device)\n",
"\n",
"        optimizer.zero_grad()\n",
"\n",
"        Q = q_encoder(**queries).pooler_output\n",
"        P = ctx_encoder(**passages).pooler_output.to(Q.device)\n",
"        # S[i, j] is the dot product of question i with passage j; criterion's\n",
"        # default target treats the diagonal as the correct pairing.\n",
"        S = torch.mm(Q, P.T)\n",
"\n",
"        loss = criterion(S)\n",
"\n",
"        loss.backward()\n",
"\n",
"        optimizer.step()\n",
"        scheduler.step()\n",
"\n",
"        # Read the scalar once and reuse it for the progress bar and the log.\n",
"        loss_value = loss.item()\n",
"        pbar.set_postfix({'loss': loss_value})\n",
"\n",
"        # NOTE(review): assumed to be per-batch logging inside the inner loop - confirm.\n",
"        results[epoch_num].append(loss_value)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "32bee074-3854-4e4a-8b18-9228f03d97bb",
"metadata": {},
"outputs": [],
"source": [
"os.makedirs('models/', exist_ok=True)\n",
"\n",
"# Dump the recorded losses; json.dump writes the integer epoch keys as strings.\n",
"with open('models/results.json', 'w') as f:\n",
"    json.dump(results, f)\n",
"\n",
"# Write each trained encoder to its own directory.\n",
"ctx_encoder.save_pretrained('models/dpr-nq-reproduced/ctx-encoder')\n",
"q_encoder.save_pretrained('models/dpr-nq-reproduced/q-encoder')"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "88818375-d83a-4d17-8014-3b8c8bf95a25",
"metadata": {
"collapsed": true,
"jupyter": {
"outputs_hidden": true
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"('models/dpr-nq-reproduced/q-encoder/tokenizer_config.json',\n",
" 'models/dpr-nq-reproduced/q-encoder/special_tokens_map.json',\n",
" 'models/dpr-nq-reproduced/q-encoder/vocab.txt',\n",
" 'models/dpr-nq-reproduced/q-encoder/added_tokens.json',\n",
" 'models/dpr-nq-reproduced/q-encoder/tokenizer.json')"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Write each tokenizer into the same directory as its encoder. The value\n",
"# returned by the final call is what the notebook shows as this cell's output.\n",
"ctx_tokenizer.save_pretrained(save_directory='models/dpr-nq-reproduced/ctx-encoder')\n",
"q_tokenizer.save_pretrained(save_directory='models/dpr-nq-reproduced/q-encoder')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment