{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/jamescalam/15d968bed79b884bf50090a34093f508/04_finetune.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VE-w5Hy7r1PY"
},
"source": [
"# Fine-tuning with MSEMargin Loss\n",
"\n",
"Now that we have our margin labeled *(Q, P<sup>+</sup>, P<sup>-</sup>)* pairs, we can begin fine-tuning a bi-encoder model with MSEMargin loss. We will start by defining a data loading function that uses the standard `InputExample` format of *sentence-transformers*."
]
},
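{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each line of `data/triplets_margin.tsv` holds four tab-separated fields (query, positive passage, negative passage, and margin score), matching how the next cell parses it:\n",
"\n",
"```\n",
"<query>\\t<positive passage>\\t<negative passage>\\t<margin>\n",
"```"
]
},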
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "M3ZDhGdHr1PZ",
"outputId": "217f4dec-c1fa-4888-efac-9e8ddbaf3978",
"colab": {
"referenced_widgets": [
"a4549bb0e64f495b91db5f6eea7a17f6"
]
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a4549bb0e64f495b91db5f6eea7a17f6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/200000 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"200000"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from tqdm.auto import tqdm\n",
"from sentence_transformers import InputExample\n",
"\n",
"training_data = []\n",
"\n",
"with open('data/triplets_margin.tsv', 'r', encoding='utf-8') as fp:\n",
" lines = fp.read().split('\\n')\n",
"# loop through each line and return InputExample\n",
"for line in tqdm(lines):\n",
" q, p, n, margin = line.split('\\t')\n",
" training_data.append(InputExample(\n",
" texts=[q, p, n],\n",
" label=float(margin)\n",
" ))\n",
"\n",
"len(training_data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mcnQLX8Gr1Pa"
},
"source": [
"Let's see the contents of one `InputExample` object."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ssBtBww1r1Pb",
"outputId": "2f1a64ca-67fe-4963-d45d-999ade6df98c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Query: why does prediction performance decline with large data points\n",
"Passage +: 2) Discussion: Lines 392-396: \"Second, it seems that the prediction performance drops with the incre\n",
"Passage -: \" 2) Discussion: Lines 392-396: \"Second, it seems that the prediction performance drops with the inc\n",
"Margin: 0.28338432\n",
"\n"
]
}
],
"source": [
"print(f\"\"\"\n",
"Query: {training_data[0].texts[0]}\n",
"Passage +: {training_data[0].texts[1][:100]}\n",
"Passage -: {training_data[0].texts[2][:100]}\n",
"Margin: {training_data[0].label}\n",
"\"\"\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NCEjg9H6r1Pb"
},
"source": [
"We load these pairs into a generator `DataLoader`. Margin MSE works best with a large `batch_size`, the `32` used here is reasonable."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tFfA3Sjbr1Pb"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.cuda.empty_cache()\n",
"\n",
"batch_size = 32\n",
"\n",
"loader = torch.utils.data.DataLoader(\n",
" training_data, batch_size=batch_size, shuffle=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Z9ZMpOAcr1Pc"
},
"source": [
"Next we initialize a bi-encoder model that we will be fine-tuning using domain adaption."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4GTcK7J8r1Pc",
"outputId": "e2e7b65a-cfef-4bc4-f103-a957b227c7d0"
},
"outputs": [
{
"data": {
"text/plain": [
"SentenceTransformer(\n",
" (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: DistilBertModel \n",
" (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})\n",
")"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sentence_transformers import SentenceTransformer\n",
"\n",
"model = SentenceTransformer('msmarco-distilbert-base-tas-b')\n",
"model.max_seq_length = 256\n",
"model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TweZBa8qr1Pd"
},
"source": [
"Then initialize the Margin MSE loss function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LGMMtL37r1Pd"
},
"outputs": [],
"source": [
"from sentence_transformers import losses\n",
"\n",
"loss = losses.MarginMSELoss(model)"
]
},
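{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition: using the model's dot-product similarity $s$, `MarginMSELoss` minimizes the mean squared error between the bi-encoder's score margin and the cross-encoder teacher's margin label $M$,\n",
"\n",
"$$\\mathcal{L} = \\operatorname{MSE}\\big(s(Q, P^+) - s(Q, P^-),\\; M\\big)$$\n",
"\n",
"so the bi-encoder learns to reproduce the teacher's score *differences* rather than its absolute scores. We now fine-tune for a single epoch, warming up the learning rate over the first 10% of training steps."
]
},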
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kV3Xt-9lr1Pd",
"outputId": "44c7199c-a429-4c24-bc85-ac4bfddc1bc9",
"colab": {
"referenced_widgets": [
"2055a68e819c4a2295d62029458f7c7e",
"9e6a24e3909645e299b2e5ec192b2ef8"
]
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.7/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
" FutureWarning,\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2055a68e819c4a2295d62029458f7c7e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch: 0%| | 0/1 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9e6a24e3909645e299b2e5ec192b2ef8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Iteration: 0%| | 0/6250 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"epochs = 1\n",
"warmup_steps = int(len(loader) * epochs * 0.1)\n",
"\n",
"model.fit(\n",
" train_objectives=[(loader, loss)],\n",
" epochs=epochs,\n",
" warmup_steps=warmup_steps,\n",
" output_path='msmarco-distilbert-base-tas-b-covid',\n",
" show_progress_bar=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5cd3lQM3r1Pd"
},
"source": [
"The model is saved in the `msmarco-distilbert-base-tas-b-covid` directory."
]
}
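,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (a minimal usage sketch with a made-up query and passage), we can reload the fine-tuned bi-encoder from that directory and score a query-passage pair with dot product, the similarity TAS-B models are trained with."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sentence_transformers import SentenceTransformer, util\n",
"\n",
"# reload the fine-tuned model from the output directory\n",
"tuned = SentenceTransformer('msmarco-distilbert-base-tas-b-covid')\n",
"\n",
"# encode a (hypothetical) query and passage, then score with dot product\n",
"q_emb = tuned.encode('how does covid-19 spread between people')\n",
"p_emb = tuned.encode('the virus spreads mainly through respiratory droplets')\n",
"util.dot_score(q_emb, p_emb)"
]
}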
],
"metadata": {
"environment": {
"kernel": "conda-root-py",
"name": "common-cu110.m91",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/base-cu110:m91"
},
"interpreter": {
"hash": "5188bc372fa413aa2565ae5d28228f50ad7b2c4ebb4a82c5900fd598adbb6408"
},
"kernelspec": {
"display_name": "Python [conda env:root] *",
"language": "python",
"name": "conda-root-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
},
"colab": {
"name": "04_finetune.ipynb",
"provenance": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 0
}