-
-
Save jamescalam/64e38a2a8e84db61e5739f9fe41c12f2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from sentence_transformers import (\n", | |
" InputExample, SentenceTransformer\n", | |
")\n", | |
"from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator\n", | |
"from datasets import load_dataset\n", | |
"import json\n", | |
"from tqdm.auto import tqdm" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Reusing dataset glue (C:\\Users\\James\\.cache\\huggingface\\datasets\\glue\\mrpc\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n", | |
"Reusing dataset glue (C:\\Users\\James\\.cache\\huggingface\\datasets\\glue\\qqp\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n", | |
"Reusing dataset glue (C:\\Users\\James\\.cache\\huggingface\\datasets\\glue\\stsb\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n", | |
"Reusing dataset glue (C:\\Users\\James\\.cache\\huggingface\\datasets\\glue\\rte\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n", | |
"Loading cached processed dataset at C:\\Users\\James\\.cache\\huggingface\\datasets\\glue\\rte\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\\cache-c700890cca65ec1d.arrow\n", | |
"100%|██████████| 295/295 [00:00<00:00, 595724.45it/s]\n" | |
] | |
} | |
], | |
"source": [ | |
# Build the evaluation sets: four GLUE validation splits plus a local
# medical question-pairs file, all normalised to sentence1/sentence2/label.
data = {}

data['mrpc'] = load_dataset('glue', 'mrpc', split='validation')

# QQP uses question1/question2; rename so every set shares column names.
qqp = load_dataset('glue', 'qqp', split='validation')
data['qqp'] = qqp.rename_columns(
    {'question1': 'sentence1', 'question2': 'sentence2'}
)

data['stsb'] = load_dataset('glue', 'stsb', split='validation')

# Flip RTE's 0/1 labels (1 -> 0, 0 -> 1) — presumably to align its label
# convention with the other pair datasets; confirm against training code.
rte = load_dataset('glue', 'rte', split='validation')
data['rte'] = rte.map(lambda x: {'label': 0 if bool(x['label']) else 1})

# Local medical question-pairs dev set: a plain list of dicts carrying
# the same sentence1/sentence2/label keys as the GLUE splits above.
with open('data/med_qp_dev.json', 'r') as fp:
    med_json = json.load(fp)
data['medqp'] = [
    {
        'sentence1': row['question_1'],
        'sentence2': row['question_2'],
        'label': row['label']
    }
    for row in tqdm(med_json['data'])
]
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mrpc\n", | |
"qqp\n", | |
"stsb\n", | |
"rte\n", | |
"medqp\n" | |
] | |
} | |
], | |
"source": [ | |
# Convert each domain's rows into sentence-transformers InputExample
# objects (sentence pair + label), replacing the raw rows in `data`.
for domain in data:
    print(domain)
    examples = []
    for row in data[domain]:
        pair = [row['sentence1'], row['sentence2']]
        examples.append(InputExample(texts=pair, label=row['label']))
    data[domain] = examples
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
from pathlib import Path

# Lazily enumerate the fine-tuned model directories in the working
# directory; their names follow `bert-S{source}_T{target}` (see below).
model_paths = Path('.').glob('bert-S*')
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"*(model names are `bert-S{source domain}-T{target domain}`)*" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mrpc / bert-Smedqp_Tmrpc\n", | |
"qqp / bert-Smedqp_Tqqp\n", | |
"rte / bert-Smedqp_Trte\n", | |
"stsb / bert-Smedqp_Tstsb\n", | |
"medqp / bert-Smrpc_Tmedqp\n", | |
"qqp / bert-Smrpc_Tqqp\n", | |
"rte / bert-Smrpc_Trte\n", | |
"stsb / bert-Smrpc_Tstsb\n", | |
"medqp / bert-Sqqp_Tmedqp\n", | |
"mrpc / bert-Sqqp_Tmrpc\n", | |
"rte / bert-Sqqp_Trte\n", | |
"stsb / bert-Sqqp_Tstsb\n", | |
"medqp / bert-Srte_Tmedqp\n", | |
"mrpc / bert-Srte_Tmrpc\n", | |
"qqp / bert-Srte_Tqqp\n", | |
"stsb / bert-Srte_Tstsb\n", | |
"medqp / bert-Sstsb_Tmedqp\n", | |
"mrpc / bert-Sstsb_Tmrpc\n", | |
"qqp / bert-Sstsb_Tqqp\n", | |
"rte / bert-Sstsb_Trte\n" | |
] | |
} | |
], | |
"source": [ | |
import pandas as pd

# Cross-domain evaluation: score each fine-tuned model on its *target*
# domain's dev set with the embedding-similarity evaluator.
perf = pd.DataFrame({
    'model': [],
    'target': [],
    'score': []
})

records = []
for path in model_paths:
    path = str(path)
    for domain in data:
        # Model dirs are named `bert-S{source}_T{target}`; only evaluate
        # each model against the domain it was adapted to.
        if 'T' + domain not in path:
            continue
        print(f"{domain} / {path}")
        model = SentenceTransformer(path)
        evaluator = EmbeddingSimilarityEvaluator.from_input_examples(
            data[domain], write_csv=False
        )
        records.append({
            'model': path,
            'target': domain,
            'score': round(evaluator(model), 3)
        })

# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0; it was
# also O(n^2) (one new frame per row). Collect records and concat once.
if records:
    perf = pd.concat([perf, pd.DataFrame(records)], ignore_index=True)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"WARNING:root:No sentence-transformers model found with name C:\\Users\\James/.cache\\torch\\sentence_transformers\\bert-base-uncased. Creating a new one with MEAN pooling.\n", | |
"Some weights of the model checkpoint at C:\\Users\\James/.cache\\torch\\sentence_transformers\\bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']\n", | |
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", | |
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", | |
"WARNING:root:No sentence-transformers model found with name C:\\Users\\James/.cache\\torch\\sentence_transformers\\bert-base-uncased. Creating a new one with MEAN pooling.\n", | |
"Some weights of the model checkpoint at C:\\Users\\James/.cache\\torch\\sentence_transformers\\bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']\n", | |
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", | |
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", | |
"WARNING:root:No sentence-transformers model found with name C:\\Users\\James/.cache\\torch\\sentence_transformers\\bert-base-uncased. Creating a new one with MEAN pooling.\n", | |
"Some weights of the model checkpoint at C:\\Users\\James/.cache\\torch\\sentence_transformers\\bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']\n", | |
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", | |
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", | |
"WARNING:root:No sentence-transformers model found with name C:\\Users\\James/.cache\\torch\\sentence_transformers\\bert-base-uncased. Creating a new one with MEAN pooling.\n", | |
"Some weights of the model checkpoint at C:\\Users\\James/.cache\\torch\\sentence_transformers\\bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']\n", | |
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", | |
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", | |
"WARNING:root:No sentence-transformers model found with name C:\\Users\\James/.cache\\torch\\sentence_transformers\\bert-base-uncased. Creating a new one with MEAN pooling.\n", | |
"Some weights of the model checkpoint at C:\\Users\\James/.cache\\torch\\sentence_transformers\\bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']\n", | |
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", | |
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" | |
] | |
} | |
], | |
"source": [ | |
# Baseline: vanilla bert-base-uncased (no fine-tuning) on every domain.
# Load the model once — it is identical for all domains; the original
# reloaded the same checkpoint on every iteration.
model = SentenceTransformer('bert-base-uncased')

baseline_records = []
for domain in data:
    evaluator = EmbeddingSimilarityEvaluator.from_input_examples(
        data[domain], write_csv=False
    )
    baseline_records.append({
        'model': 'bert-base-uncased',
        'target': domain,
        'score': round(evaluator(model), 3)
    })

# DataFrame.append was removed in pandas 2.0; append all rows via one concat.
perf = pd.concat([perf, pd.DataFrame(baseline_records)], ignore_index=True)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>model</th>\n", | |
" <th>target</th>\n", | |
" <th>score</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>bert-Smedqp_Tmrpc</td>\n", | |
" <td>mrpc</td>\n", | |
" <td>0.468</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>bert-Smedqp_Tqqp</td>\n", | |
" <td>qqp</td>\n", | |
" <td>0.492</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>bert-Smedqp_Trte</td>\n", | |
" <td>rte</td>\n", | |
" <td>0.057</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>bert-Smedqp_Tstsb</td>\n", | |
" <td>stsb</td>\n", | |
" <td>0.635</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>bert-Smrpc_Tmedqp</td>\n", | |
" <td>medqp</td>\n", | |
" <td>0.495</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>bert-Smrpc_Tqqp</td>\n", | |
" <td>qqp</td>\n", | |
" <td>0.440</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>6</th>\n", | |
" <td>bert-Smrpc_Trte</td>\n", | |
" <td>rte</td>\n", | |
" <td>0.066</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>7</th>\n", | |
" <td>bert-Smrpc_Tstsb</td>\n", | |
" <td>stsb</td>\n", | |
" <td>0.759</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>8</th>\n", | |
" <td>bert-Sqqp_Tmedqp</td>\n", | |
" <td>medqp</td>\n", | |
" <td>0.484</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>9</th>\n", | |
" <td>bert-Sqqp_Tmrpc</td>\n", | |
" <td>mrpc</td>\n", | |
" <td>0.371</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>10</th>\n", | |
" <td>bert-Sqqp_Trte</td>\n", | |
" <td>rte</td>\n", | |
" <td>0.008</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>11</th>\n", | |
" <td>bert-Sqqp_Tstsb</td>\n", | |
" <td>stsb</td>\n", | |
" <td>0.762</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>12</th>\n", | |
" <td>bert-Srte_Tmedqp</td>\n", | |
" <td>medqp</td>\n", | |
" <td>0.519</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>13</th>\n", | |
" <td>bert-Srte_Tmrpc</td>\n", | |
" <td>mrpc</td>\n", | |
" <td>0.403</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>14</th>\n", | |
" <td>bert-Srte_Tqqp</td>\n", | |
" <td>qqp</td>\n", | |
" <td>0.488</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>15</th>\n", | |
" <td>bert-Srte_Tstsb</td>\n", | |
" <td>stsb</td>\n", | |
" <td>0.622</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>16</th>\n", | |
" <td>bert-Sstsb_Tmedqp</td>\n", | |
" <td>medqp</td>\n", | |
" <td>0.553</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>17</th>\n", | |
" <td>bert-Sstsb_Tmrpc</td>\n", | |
" <td>mrpc</td>\n", | |
" <td>0.507</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>18</th>\n", | |
" <td>bert-Sstsb_Tqqp</td>\n", | |
" <td>qqp</td>\n", | |
" <td>0.543</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>19</th>\n", | |
" <td>bert-Sstsb_Trte</td>\n", | |
" <td>rte</td>\n", | |
" <td>0.154</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>20</th>\n", | |
" <td>bert-base-uncased</td>\n", | |
" <td>mrpc</td>\n", | |
" <td>0.388</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>21</th>\n", | |
" <td>bert-base-uncased</td>\n", | |
" <td>qqp</td>\n", | |
" <td>0.411</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>22</th>\n", | |
" <td>bert-base-uncased</td>\n", | |
" <td>stsb</td>\n", | |
" <td>0.615</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>23</th>\n", | |
" <td>bert-base-uncased</td>\n", | |
" <td>rte</td>\n", | |
" <td>0.086</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>24</th>\n", | |
" <td>bert-base-uncased</td>\n", | |
" <td>medqp</td>\n", | |
" <td>0.506</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" model target score\n", | |
"0 bert-Smedqp_Tmrpc mrpc 0.468\n", | |
"1 bert-Smedqp_Tqqp qqp 0.492\n", | |
"2 bert-Smedqp_Trte rte 0.057\n", | |
"3 bert-Smedqp_Tstsb stsb 0.635\n", | |
"4 bert-Smrpc_Tmedqp medqp 0.495\n", | |
"5 bert-Smrpc_Tqqp qqp 0.440\n", | |
"6 bert-Smrpc_Trte rte 0.066\n", | |
"7 bert-Smrpc_Tstsb stsb 0.759\n", | |
"8 bert-Sqqp_Tmedqp medqp 0.484\n", | |
"9 bert-Sqqp_Tmrpc mrpc 0.371\n", | |
"10 bert-Sqqp_Trte rte 0.008\n", | |
"11 bert-Sqqp_Tstsb stsb 0.762\n", | |
"12 bert-Srte_Tmedqp medqp 0.519\n", | |
"13 bert-Srte_Tmrpc mrpc 0.403\n", | |
"14 bert-Srte_Tqqp qqp 0.488\n", | |
"15 bert-Srte_Tstsb stsb 0.622\n", | |
"16 bert-Sstsb_Tmedqp medqp 0.553\n", | |
"17 bert-Sstsb_Tmrpc mrpc 0.507\n", | |
"18 bert-Sstsb_Tqqp qqp 0.543\n", | |
"19 bert-Sstsb_Trte rte 0.154\n", | |
"20 bert-base-uncased mrpc 0.388\n", | |
"21 bert-base-uncased qqp 0.411\n", | |
"22 bert-base-uncased stsb 0.615\n", | |
"23 bert-base-uncased rte 0.086\n", | |
"24 bert-base-uncased medqp 0.506" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"perf" | |
] | |
} | |
], | |
"metadata": { | |
"interpreter": { | |
"hash": "2ada91ca7be38ac141a70d8e06f4253d3e90604f2701bfa98443d880c4baa087" | |
}, | |
"kernelspec": { | |
"display_name": "Python 3.8.8 64-bit ('search': conda)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.8" | |
}, | |
"orig_nbformat": 4 | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment