Created October 8, 2022 04:57
"import os\n",
"import json\n",
"import transformers as hft\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from tqdm import tqdm"
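"class NQDataset(torch.utils.data.Dataset):\n",
"    def __init__(self, data: list):\n",
"        self.data = data\n",
"\n",
"    def __len__(self):\n",
"        return len(self.data)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        # NOTE(review): the tuple fields were lost in extraction; reconstructed from the\n",
"        # `for q, t, c in pbar` unpacking and the DPR biencoder JSON layout -- confirm.\n",
"        return (\n",
"            self.data[idx][\"question\"],\n",
"            self.data[idx][\"positive_ctxs\"][0][\"title\"],\n",
"            self.data[idx][\"positive_ctxs\"][0][\"text\"],\n",
"        )\n",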
"def get_schedule_linear(\n",
"    optimizer,\n",
"    warmup_steps,\n",
"    total_training_steps,\n",
"    steps_shift=0,\n",
"    last_epoch=-1,\n",
"):\n",
"    \"\"\"\n",
"    Create a schedule with a learning rate that decreases linearly after\n",
"    linearly increasing during a warmup period.\n",
"\n",
"    Source: facebookresearch/DPR (dpr/utils/model_utils.py)\n",
"    \"\"\"\n",
"    def lr_lambda(current_step):\n",
"        current_step += steps_shift\n",
"        if current_step < warmup_steps:\n",
"            return float(current_step) / float(max(1, warmup_steps))\n",
"        return max(\n",
"            1e-7,\n",
"            float(total_training_steps - current_step) / float(max(1, total_training_steps - warmup_steps)),\n",
"        )\n",
"    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)\n",
"def criterion(S, target=None):\n",
"    softS = F.log_softmax(S, dim=1)\n",
"    if target is None:\n",
"        target = torch.arange(0, S.shape[0])\n",
"    target = target.to(S.device)\n",
"    loss = F.nll_loss(softS, target, reduction=\"mean\")\n",
"    return loss"
"batch_size = 128\n",
"warmup_steps = 1237\n",
"num_epochs = 40\n",
"learning_rate = 2e-5 # 2e-5 for NQ, 1e-5 for other datasets\n",
"max_length = 256\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"q_tokenizer = hft.AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
"ctx_tokenizer = hft.AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
"q_encoder = hft.AutoModel.from_pretrained(\"bert-base-uncased\").to(device)\n",
"ctx_encoder = hft.AutoModel.from_pretrained(\"bert-base-uncased\").to(device)"
"train_json = json.load(open('data/biencoder-nq-train.json'))\n",
"valid_json = json.load(open('data/biencoder-nq-dev.json'))"
"train_loader = torch.utils.data.DataLoader(\n",
"    NQDataset(train_json), batch_size=batch_size, shuffle=True\n",
")\n",
"optimizer = torch.optim.AdamW(\n",
"    [{\"params\": ctx_encoder.parameters()}, {\"params\": q_encoder.parameters()}],\n",
"    lr=learning_rate,\n",
"    weight_decay=0.0,\n",
")\n",
"total_training_steps = num_epochs * len(train_loader)\n",
"scheduler = get_schedule_linear(\n",
"    optimizer, warmup_steps=warmup_steps, total_training_steps=total_training_steps\n",
")"
"Epoch #0: 100%|██████████| 460/460 [07:38<00:00, 1.00it/s, loss=4.66]\n",
"Epoch #1: 100%|██████████| 460/460 [07:38<00:00, 1.00it/s, loss=2.72]\n",
"Epoch #2: 100%|██████████| 460/460 [07:35<00:00, 1.01it/s, loss=0.653]\n",
"Epoch #3: 100%|██████████| 460/460 [07:36<00:00, 1.01it/s, loss=0.406]\n",
"Epoch #4: 100%|██████████| 460/460 [07:35<00:00, 1.01it/s, loss=0.381] \n",
"Epoch #5: 100%|██████████| 460/460 [07:35<00:00, 1.01it/s, loss=0.242] \n",
"Epoch #6: 100%|██████████| 460/460 [07:32<00:00, 1.02it/s, loss=0.224] \n",
"Epoch #7: 100%|██████████| 460/460 [07:34<00:00, 1.01it/s, loss=0.0726]\n",
"Epoch #8: 100%|██████████| 460/460 [07:34<00:00, 1.01it/s, loss=0.0708]\n",
"Epoch #9: 100%|██████████| 460/460 [07:33<00:00, 1.01it/s, loss=0.11] \n",
"Epoch #10: 100%|██████████| 460/460 [07:33<00:00, 1.01it/s, loss=0.0283] \n",
"Epoch #11: 100%|██████████| 460/460 [07:32<00:00, 1.02it/s, loss=0.0475] \n",
"Epoch #12: 100%|██████████| 460/460 [07:34<00:00, 1.01it/s, loss=0.0762] \n",
"Epoch #25: 100%|██████████| 460/460 [07:34<00:00, 1.01it/s, loss=0.0115] \n",
"Epoch #26: 100%|██████████| 460/460 [07:33<00:00, 1.01it/s, loss=0.00752] \n",
"Epoch #27: 100%|██████████| 460/460 [07:33<00:00, 1.01it/s, loss=0.0207] \n",
"Epoch #39: 100%|██████████| 460/460 [07:33<00:00, 1.01it/s, loss=0.000152]\n"
"results = {}\n",
"for epoch_num in range(num_epochs):\n",
"    results[epoch_num] = []\n",
"\n",
"    ctx_encoder.train()\n",
"    q_encoder.train()\n",
"    pbar = tqdm(train_loader, desc=f'Epoch #{epoch_num}')\n",
"\n",
"    for q, t, c in pbar:\n",
"        tokenizer_kwargs = dict(max_length=max_length, return_tensors=\"pt\", truncation=True, padding=True)\n",
"        queries = q_tokenizer(list(q), **tokenizer_kwargs).to(device)\n",
"        passages = ctx_tokenizer(list(t), list(c), **tokenizer_kwargs).to(device)\n",
"        optimizer.zero_grad()\n",
"        Q = q_encoder(**queries).pooler_output\n",
"        P = ctx_encoder(**passages).pooler_output\n",
"        S = torch.mm(Q, P.T)\n",
"        loss = criterion(S)\n",
"        loss.backward()\n",
"\n",
"        optimizer.step()\n",
"        scheduler.step()\n",
"        pbar.set_postfix({'loss': loss.item()})\n",
"\n",
"        results[epoch_num].append(loss.item())"
