Skip to content

Instantly share code, notes, and snippets.

@jamescalam
Created December 26, 2021 07:20
Show Gist options
  • Save jamescalam/64e38a2a8e84db61e5739f9fe41c12f2 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from sentence_transformers import (\n",
" InputExample, SentenceTransformer\n",
")\n",
"from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator\n",
"from datasets import load_dataset\n",
"import json\n",
"from tqdm.auto import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Reusing dataset glue (C:\\Users\\James\\.cache\\huggingface\\datasets\\glue\\mrpc\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n",
"Reusing dataset glue (C:\\Users\\James\\.cache\\huggingface\\datasets\\glue\\qqp\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n",
"Reusing dataset glue (C:\\Users\\James\\.cache\\huggingface\\datasets\\glue\\stsb\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n",
"Reusing dataset glue (C:\\Users\\James\\.cache\\huggingface\\datasets\\glue\\rte\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n",
"Loading cached processed dataset at C:\\Users\\James\\.cache\\huggingface\\datasets\\glue\\rte\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\\cache-c700890cca65ec1d.arrow\n",
"100%|██████████| 295/295 [00:00<00:00, 595724.45it/s]\n"
]
}
],
"source": [
"# Load the validation split for each evaluation domain into a dict\n",
"# keyed by domain name; all domains end up with a\n",
"# sentence1/sentence2/label record schema.\n",
"data = {}\n",
"data['mrpc'] = load_dataset('glue', 'mrpc', split='validation')\n",
"\n",
"# QQP ships question1/question2 columns; rename them so every domain\n",
"# exposes the same sentence1/sentence2 fields.\n",
"data['qqp'] = load_dataset('glue', 'qqp', split='validation')\n",
"data['qqp'] = data['qqp'].rename_columns({\n",
" 'question1': 'sentence1',\n",
" 'question2': 'sentence2'\n",
"})\n",
"\n",
"data['stsb'] = load_dataset('glue', 'stsb', split='validation')\n",
"\n",
"# RTE labels are inverted here (0 <-> 1) -- NOTE(review): presumably to\n",
"# make 1 mean \"similar/entailed\" for the similarity evaluator; confirm\n",
"# the intended polarity against GLUE's RTE label encoding.\n",
"data['rte'] = load_dataset('glue', 'rte', split='validation')\n",
"data['rte'] = data['rte'].map(lambda x:\n",
" {'label': 0 if bool(x['label']) else 1}\n",
")\n",
"\n",
"# Medical question-pairs dev set from a local JSON file; reshape each\n",
"# row into the shared sentence1/sentence2/label schema. Unlike the GLUE\n",
"# entries this is a plain list of dicts, which the downstream loop\n",
"# iterates the same way.\n",
"with open('data/med_qp_dev.json', 'r') as fp:\n",
" med_json = json.load(fp)\n",
"data['medqp'] = []\n",
"for row in tqdm(med_json['data']):\n",
" data['medqp'].append({\n",
" 'sentence1': row['question_1'],\n",
" 'sentence2': row['question_2'],\n",
" 'label': row['label']\n",
" })"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mrpc\n",
"qqp\n",
"stsb\n",
"rte\n",
"medqp\n"
]
}
],
"source": [
"# Convert each domain's records into sentence-transformers InputExample\n",
"# objects (sentence pair + label), replacing the raw records in place.\n",
"for domain in data.keys():\n",
"    print(domain)\n",
"    examples = []\n",
"    for row in data[domain]:\n",
"        pair = [row['sentence1'], row['sentence2']]\n",
"        examples.append(InputExample(texts=pair, label=row['label']))\n",
"    data[domain] = examples"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"# Lazily collect the fine-tuned model directories in the working\n",
"# directory; names follow bert-S{source}_T{target} (see markdown below).\n",
"# NOTE: glob returns a one-shot generator, so this can only be iterated\n",
"# once (it is consumed by the evaluation loop below).\n",
"model_paths = Path('./').glob('bert-S*')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*(model names are `bert-S{source domain}-T{target domain}`)*"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mrpc / bert-Smedqp_Tmrpc\n",
"qqp / bert-Smedqp_Tqqp\n",
"rte / bert-Smedqp_Trte\n",
"stsb / bert-Smedqp_Tstsb\n",
"medqp / bert-Smrpc_Tmedqp\n",
"qqp / bert-Smrpc_Tqqp\n",
"rte / bert-Smrpc_Trte\n",
"stsb / bert-Smrpc_Tstsb\n",
"medqp / bert-Sqqp_Tmedqp\n",
"mrpc / bert-Sqqp_Tmrpc\n",
"rte / bert-Sqqp_Trte\n",
"stsb / bert-Sqqp_Tstsb\n",
"medqp / bert-Srte_Tmedqp\n",
"mrpc / bert-Srte_Tmrpc\n",
"qqp / bert-Srte_Tqqp\n",
"stsb / bert-Srte_Tstsb\n",
"medqp / bert-Sstsb_Tmedqp\n",
"mrpc / bert-Sstsb_Tmrpc\n",
"qqp / bert-Sstsb_Tqqp\n",
"rte / bert-Sstsb_Trte\n"
]
}
],
"source": [
"import pandas as pd\n",
"perf = pd.DataFrame({\n",
" 'model': [],\n",
" 'target': [],\n",
" 'score': []\n",
"})\n",
"\n",
"for path in model_paths:\n",
" path = str(path)\n",
" for domain in data.keys():\n",
" if 'T'+domain not in path: continue\n",
" print(f\"{domain} / {path}\")\n",
" model = SentenceTransformer(path)\n",
" evaluator = EmbeddingSimilarityEvaluator.from_input_examples(\n",
" data[domain], write_csv=False\n",
" )\n",
" perf = perf.append({\n",
" 'model': path,\n",
" 'target': domain,\n",
" 'score': round(evaluator(model), 3)\n",
" }, ignore_index=True)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:root:No sentence-transformers model found with name C:\\Users\\James/.cache\\torch\\sentence_transformers\\bert-base-uncased. Creating a new one with MEAN pooling.\n",
"Some weights of the model checkpoint at C:\\Users\\James/.cache\\torch\\sentence_transformers\\bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"WARNING:root:No sentence-transformers model found with name C:\\Users\\James/.cache\\torch\\sentence_transformers\\bert-base-uncased. Creating a new one with MEAN pooling.\n",
"Some weights of the model checkpoint at C:\\Users\\James/.cache\\torch\\sentence_transformers\\bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"WARNING:root:No sentence-transformers model found with name C:\\Users\\James/.cache\\torch\\sentence_transformers\\bert-base-uncased. Creating a new one with MEAN pooling.\n",
"Some weights of the model checkpoint at C:\\Users\\James/.cache\\torch\\sentence_transformers\\bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"WARNING:root:No sentence-transformers model found with name C:\\Users\\James/.cache\\torch\\sentence_transformers\\bert-base-uncased. Creating a new one with MEAN pooling.\n",
"Some weights of the model checkpoint at C:\\Users\\James/.cache\\torch\\sentence_transformers\\bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"WARNING:root:No sentence-transformers model found with name C:\\Users\\James/.cache\\torch\\sentence_transformers\\bert-base-uncased. Creating a new one with MEAN pooling.\n",
"Some weights of the model checkpoint at C:\\Users\\James/.cache\\torch\\sentence_transformers\\bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
}
],
"source": [
"# Baseline: vanilla bert-base-uncased (auto-wrapped with MEAN pooling,\n",
"# per the warnings above) evaluated on every domain. Uses pd.concat on\n",
"# a collected record list instead of the removed DataFrame.append\n",
"# (deprecated pandas 1.4, removed in 2.0).\n",
"base_rows = []\n",
"for domain in data.keys():\n",
"    model = SentenceTransformer('bert-base-uncased')\n",
"    evaluator = EmbeddingSimilarityEvaluator.from_input_examples(\n",
"        data[domain], write_csv=False\n",
"    )\n",
"    base_rows.append({\n",
"        'model': 'bert-base-uncased',\n",
"        'target': domain,\n",
"        'score': round(evaluator(model), 3)\n",
"    })\n",
"perf = pd.concat([perf, pd.DataFrame(base_rows)], ignore_index=True)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>model</th>\n",
" <th>target</th>\n",
" <th>score</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>bert-Smedqp_Tmrpc</td>\n",
" <td>mrpc</td>\n",
" <td>0.468</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>bert-Smedqp_Tqqp</td>\n",
" <td>qqp</td>\n",
" <td>0.492</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>bert-Smedqp_Trte</td>\n",
" <td>rte</td>\n",
" <td>0.057</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>bert-Smedqp_Tstsb</td>\n",
" <td>stsb</td>\n",
" <td>0.635</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>bert-Smrpc_Tmedqp</td>\n",
" <td>medqp</td>\n",
" <td>0.495</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>bert-Smrpc_Tqqp</td>\n",
" <td>qqp</td>\n",
" <td>0.440</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>bert-Smrpc_Trte</td>\n",
" <td>rte</td>\n",
" <td>0.066</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>bert-Smrpc_Tstsb</td>\n",
" <td>stsb</td>\n",
" <td>0.759</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>bert-Sqqp_Tmedqp</td>\n",
" <td>medqp</td>\n",
" <td>0.484</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>bert-Sqqp_Tmrpc</td>\n",
" <td>mrpc</td>\n",
" <td>0.371</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>bert-Sqqp_Trte</td>\n",
" <td>rte</td>\n",
" <td>0.008</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>bert-Sqqp_Tstsb</td>\n",
" <td>stsb</td>\n",
" <td>0.762</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>bert-Srte_Tmedqp</td>\n",
" <td>medqp</td>\n",
" <td>0.519</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>bert-Srte_Tmrpc</td>\n",
" <td>mrpc</td>\n",
" <td>0.403</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>bert-Srte_Tqqp</td>\n",
" <td>qqp</td>\n",
" <td>0.488</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>bert-Srte_Tstsb</td>\n",
" <td>stsb</td>\n",
" <td>0.622</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>bert-Sstsb_Tmedqp</td>\n",
" <td>medqp</td>\n",
" <td>0.553</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <td>bert-Sstsb_Tmrpc</td>\n",
" <td>mrpc</td>\n",
" <td>0.507</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <td>bert-Sstsb_Tqqp</td>\n",
" <td>qqp</td>\n",
" <td>0.543</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>bert-Sstsb_Trte</td>\n",
" <td>rte</td>\n",
" <td>0.154</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <td>bert-base-uncased</td>\n",
" <td>mrpc</td>\n",
" <td>0.388</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21</th>\n",
" <td>bert-base-uncased</td>\n",
" <td>qqp</td>\n",
" <td>0.411</td>\n",
" </tr>\n",
" <tr>\n",
" <th>22</th>\n",
" <td>bert-base-uncased</td>\n",
" <td>stsb</td>\n",
" <td>0.615</td>\n",
" </tr>\n",
" <tr>\n",
" <th>23</th>\n",
" <td>bert-base-uncased</td>\n",
" <td>rte</td>\n",
" <td>0.086</td>\n",
" </tr>\n",
" <tr>\n",
" <th>24</th>\n",
" <td>bert-base-uncased</td>\n",
" <td>medqp</td>\n",
" <td>0.506</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" model target score\n",
"0 bert-Smedqp_Tmrpc mrpc 0.468\n",
"1 bert-Smedqp_Tqqp qqp 0.492\n",
"2 bert-Smedqp_Trte rte 0.057\n",
"3 bert-Smedqp_Tstsb stsb 0.635\n",
"4 bert-Smrpc_Tmedqp medqp 0.495\n",
"5 bert-Smrpc_Tqqp qqp 0.440\n",
"6 bert-Smrpc_Trte rte 0.066\n",
"7 bert-Smrpc_Tstsb stsb 0.759\n",
"8 bert-Sqqp_Tmedqp medqp 0.484\n",
"9 bert-Sqqp_Tmrpc mrpc 0.371\n",
"10 bert-Sqqp_Trte rte 0.008\n",
"11 bert-Sqqp_Tstsb stsb 0.762\n",
"12 bert-Srte_Tmedqp medqp 0.519\n",
"13 bert-Srte_Tmrpc mrpc 0.403\n",
"14 bert-Srte_Tqqp qqp 0.488\n",
"15 bert-Srte_Tstsb stsb 0.622\n",
"16 bert-Sstsb_Tmedqp medqp 0.553\n",
"17 bert-Sstsb_Tmrpc mrpc 0.507\n",
"18 bert-Sstsb_Tqqp qqp 0.543\n",
"19 bert-Sstsb_Trte rte 0.154\n",
"20 bert-base-uncased mrpc 0.388\n",
"21 bert-base-uncased qqp 0.411\n",
"22 bert-base-uncased stsb 0.615\n",
"23 bert-base-uncased rte 0.086\n",
"24 bert-base-uncased medqp 0.506"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Final results table: each adapted model's score on its target domain,\n",
"# followed by the bert-base-uncased baseline rows.\n",
"perf"
]
}
],
"metadata": {
"interpreter": {
"hash": "2ada91ca7be38ac141a70d8e06f4253d3e90604f2701bfa98443d880c4baa087"
},
"kernelspec": {
"display_name": "Python 3.8.8 64-bit ('search': conda)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment