04_finetune.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/jamescalam/15d968bed79b884bf50090a34093f508/04_finetune.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VE-w5Hy7r1PY"
},
"source": [
"# Fine-tuning with Margin MSE Loss\n",
"\n",
"Now that we have our margin-labeled *(Q, P<sup>+</sup>, P<sup>-</sup>)* triplets, we can begin fine-tuning a bi-encoder model with Margin MSE loss. We will start by defining a data-loading step that uses the standard `InputExample` format of *sentence-transformers*."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "M3ZDhGdHr1PZ",
"outputId": "217f4dec-c1fa-4888-efac-9e8ddbaf3978",
"colab": {
"referenced_widgets": [
"a4549bb0e64f495b91db5f6eea7a17f6"
]
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a4549bb0e64f495b91db5f6eea7a17f6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"  0%|          | 0/200000 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"200000"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from tqdm.auto import tqdm\n",
"from sentence_transformers import InputExample\n",
"\n",
"training_data = []\n",
"with open('data/triplets_margin.tsv', 'r', encoding='utf-8') as fp:\n",
"    lines = fp.read().split('\\n')\n",
"\n",
"# loop through each line and build an InputExample triplet\n",
"for line in tqdm(lines):\n",
"    if not line:\n",
"        continue  # skip a trailing empty line at end of file\n",
"    q, p, n, margin = line.split('\\t')\n",
"    training_data.append(InputExample(\n",
"        texts=[q, p, n],\n",
"        label=float(margin)\n",
"    ))\n",
"\n",
"len(training_data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mcnQLX8Gr1Pa"
},
"source": [
"Let's see the contents of one `InputExample` object."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ssBtBww1r1Pb",
"outputId": "2f1a64ca-67fe-4963-d45d-999ade6df98c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Query: why does prediction performance decline with large data points\n",
"Passage +: 2) Discussion: Lines 392-396: \"Second, it seems that the prediction performance drops with the incre\n",
"Passage -: \" 2) Discussion: Lines 392-396: \"Second, it seems that the prediction performance drops with the inc\n",
"Margin: 0.28338432\n",
"\n"
]
}
],
"source": [
"print(f\"\"\"\n",
"Query: {training_data[0].texts[0]}\n",
"Passage +: {training_data[0].texts[1][:100]}\n",
"Passage -: {training_data[0].texts[2][:100]}\n",
"Margin: {training_data[0].label}\n",
"\"\"\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NCEjg9H6r1Pb"
},
"source": [
"We load these triplets into a PyTorch `DataLoader` that batches and shuffles them. Margin MSE works best with a large `batch_size`; the `32` used here is a reasonable choice given typical GPU memory limits."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tFfA3Sjbr1Pb"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.cuda.empty_cache()\n",
"\n",
"batch_size = 32\n",
"\n",
"loader = torch.utils.data.DataLoader(\n",
"    training_data, batch_size=batch_size, shuffle=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Z9ZMpOAcr1Pc"
},
"source": [
"Next, we initialize the bi-encoder model that we will fine-tune via domain adaptation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4GTcK7J8r1Pc",
"outputId": "e2e7b65a-cfef-4bc4-f103-a957b227c7d0"
},
"outputs": [
{
"data": {
"text/plain": [
"SentenceTransformer(\n",
"  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: DistilBertModel \n",
"  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})\n",
")"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sentence_transformers import SentenceTransformer\n",
"\n",
"model = SentenceTransformer('msmarco-distilbert-base-tas-b')\n",
"model.max_seq_length = 256\n",
"model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TweZBa8qr1Pd"
},
"source": [
"Then we initialize the Margin MSE loss function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LGMMtL37r1Pd"
},
"outputs": [],
"source": [
"from sentence_transformers import losses\n",
"\n",
"loss = losses.MarginMSELoss(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kV3Xt-9lr1Pd",
"outputId": "44c7199c-a429-4c24-bc85-ac4bfddc1bc9",
"colab": {
"referenced_widgets": [
"2055a68e819c4a2295d62029458f7c7e",
"9e6a24e3909645e299b2e5ec192b2ef8"
]
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.7/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
"  FutureWarning,\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2055a68e819c4a2295d62029458f7c7e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch:   0%|          | 0/1 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9e6a24e3909645e299b2e5ec192b2ef8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Iteration:   0%|          | 0/6250 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"epochs = 1\n",
"warmup_steps = int(len(loader) * epochs * 0.1)\n",
"\n",
"model.fit(\n",
"    train_objectives=[(loader, loss)],\n",
"    epochs=epochs,\n",
"    warmup_steps=warmup_steps,\n",
"    output_path='msmarco-distilbert-base-tas-b-covid',\n",
"    show_progress_bar=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5cd3lQM3r1Pd"
},
"source": [
"The model is saved in the `msmarco-distilbert-base-tas-b-covid` directory."
]
}
],
"metadata": {
"environment": {
"kernel": "conda-root-py",
"name": "common-cu110.m91",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/base-cu110:m91"
},
"interpreter": {
"hash": "5188bc372fa413aa2565ae5d28228f50ad7b2c4ebb4a82c5900fd598adbb6408"
},
"kernelspec": {
"display_name": "Python [conda env:root] *",
"language": "python",
"name": "conda-root-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
},
"colab": {
"name": "04_finetune.ipynb",
"provenance": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 0
}
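For intuition, the Margin MSE objective the notebook trains with can be sketched in plain PyTorch. This is an illustrative re-implementation, not the exact `sentence-transformers` code: the student bi-encoder's score margin (dot products, as TAS-B models use) is regressed onto the teacher cross-encoder's margin label.

```python
import torch
import torch.nn.functional as F

def margin_mse(q, p, n, teacher_margin):
    """Sketch of Margin MSE: MSE between the student's score margin
    (dot-product similarity) and the teacher's margin label."""
    pos_score = (q * p).sum(dim=-1)   # dot(query, positive passage)
    neg_score = (q * n).sum(dim=-1)   # dot(query, negative passage)
    student_margin = pos_score - neg_score
    return F.mse_loss(student_margin, teacher_margin)

# toy batch: 4 random 768-d "embeddings" standing in for encoder outputs
q, p, n = (torch.randn(4, 768) for _ in range(3))
teacher = torch.randn(4)
loss_value = margin_mse(q, p, n, teacher)
```

When the student's margins exactly match the teacher's, the loss is zero, so minimizing it distills the cross-encoder's ranking behavior into the bi-encoder.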