{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Datasets\n", | |
"\n", | |
"We will test AugSBERT domain transfer on **five** datasets:\n", | |
"\n", | |
"| Dataset | Description |\n", | |
"| --- | --- |\n", | |
"| STSb | Semantic Textual Similarity Benchmark: simple sentence pairs, each with a similarity score (rescaled to a continuous value from `0` to `1`) |\n", | |
"| Medical question pairs | Medical question pairs labeled as similar `1` or dissimilar `0` |\n", | |
"| Quora-QP | Quora question pairs labeled as duplicates `1` or non-duplicates `0` |\n", | |
"| Microsoft Research Paraphrase Corpus (MRPC) | Sentence pairs collected from news articles, labeled as equivalent `1` or not equivalent `0` |\n", | |
"| Recognizing Textual Entailment (RTE) | Sentence pairs labeled as entailment `0` or not entailment `1` |\n", | |
"\n", | |
"We download each dataset as follows:\n", | |
"\n", | |
"## STSb" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from datasets import load_dataset" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Reusing dataset glue (C:\\Users\\James\\.cache\\huggingface\\datasets\\glue\\stsb\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"Dataset({\n", | |
" features: ['sentence1', 'sentence2', 'label', 'idx'],\n", | |
" num_rows: 5749\n", | |
"})" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"stsb = load_dataset('glue', 'stsb', split='train')\n", | |
"stsb" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"---\n", | |
"\n", | |
"## Medical Question Pairs" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Downloading: 2.83kB [00:00, 2.85MB/s] \n", | |
"Downloading: 1.22kB [00:00, 2.46MB/s] \n", | |
"Using custom data configuration default\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'label': 1, 'question_1': 'After how many hour from drinking an antibiotic can I drink alcohol?', 'dr_id': 1, 'question_2': 'I have a party tonight and I took my last dose of Azithromycin this morning. Can I have a few drinks?'}\n", | |
"{'label': 0, 'question_1': 'After how many hour from drinking an antibiotic can I drink alcohol?', 'dr_id': 1, 'question_2': 'I vomited this morning and I am not sure if it is the side effect of my antibiotic or the alcohol I took last night...'}\n" | |
] | |
} | |
], | |
"source": [ | |
"med_qp = load_dataset('medical_questions_pairs', streaming=True)\n", | |
"for i, row in enumerate(med_qp['train']):\n", | |
" if i == 2: break\n", | |
" print(row)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"This dataset has no `validation` split, so we create one by randomly holding out roughly 10% of the rows as a dev set, then save both the train and dev sets locally." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from random import random" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train: 2753\n", | |
"dev: 295\n" | |
] | |
} | |
], | |
"source": [ | |
"to_save = {'train': [], 'dev': []}\n", | |
"\n", | |
"for row in med_qp['train']:\n", | |
" if random() > 0.9:\n", | |
" to_save['dev'].append(row)\n", | |
" else:\n", | |
" to_save['train'].append(row)\n", | |
"\n", | |
"for split in ['train', 'dev']:\n", | |
" print(f\"{split}: {len(to_save[split])}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"import json\n", | |
"\n", | |
"# create the output directory if it does not already exist\n", | |
"os.makedirs('data', exist_ok=True)\n", | |
"\n", | |
"for split in ['train', 'dev']:\n", | |
" with open(f\"data/med_qp_{split}.json\", 'w') as fp:\n", | |
" json.dump({'data': to_save[split]}, fp)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We then load the saved train set and rename its fields to match the other datasets:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"2753" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import json\n", | |
"\n", | |
"with open('data/med_qp_train.json', 'r') as fp:\n", | |
" med_json = json.load(fp)\n", | |
"medqp = []\n", | |
"for row in med_json['data']:\n", | |
" medqp.append({\n", | |
" 'sentence1': row['question_1'],\n", | |
" 'sentence2': row['question_2'],\n", | |
" 'label': row['label']\n", | |
" })\n", | |
"len(medqp)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"---\n", | |
"\n", | |
"## Quora-QP" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Reusing dataset glue (C:\\Users\\James\\.cache\\huggingface\\datasets\\glue\\qqp\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" | |
] | |
} | |
], | |
"source": [ | |
"qqp = load_dataset('glue', 'qqp', split='train')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We rename the question columns so the feature names match the other datasets:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"qqp = qqp.rename_columns({\n", | |
" 'question1': 'sentence1',\n", | |
" 'question2': 'sentence2'\n", | |
" })" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Dataset({\n", | |
" features: ['sentence1', 'sentence2', 'label', 'idx'],\n", | |
" num_rows: 363846\n", | |
"})" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"qqp" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'sentence1': 'How is the life of a math student? Could you describe your own experiences?',\n", | |
" 'sentence2': 'Which level of prepration is enough for the exam jlpt5?',\n", | |
" 'label': 0,\n", | |
" 'idx': 0}" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"qqp[0]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"---\n", | |
"\n", | |
"## MRPC" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Reusing dataset glue (C:\\Users\\James\\.cache\\huggingface\\datasets\\glue\\mrpc\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"Dataset({\n", | |
" features: ['sentence1', 'sentence2', 'label', 'idx'],\n", | |
" num_rows: 3668\n", | |
"})" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"mrpc = load_dataset('glue', 'mrpc', split='train')\n", | |
"mrpc" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'sentence1': 'Amrozi accused his brother , whom he called \" the witness \" , of deliberately distorting his evidence .',\n", | |
" 'sentence2': 'Referring to him as only \" the witness \" , Amrozi accused his brother of deliberately distorting his evidence .',\n", | |
" 'label': 1,\n", | |
" 'idx': 0}" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"mrpc[0]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"---\n", | |
"\n", | |
"## RTE" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Reusing dataset glue (C:\\Users\\James\\.cache\\huggingface\\datasets\\glue\\rte\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"Dataset({\n", | |
" features: ['sentence1', 'sentence2', 'label', 'idx'],\n", | |
" num_rows: 2490\n", | |
"})" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"rte = load_dataset('glue', 'rte', split='train')\n", | |
"rte" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"The `label` for RTE is inverted relative to the other datasets (`0` marks entailing pairs and `1` marks non-entailing pairs), so we swap it." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'sentence1': 'No Weapons of Mass Destruction Found in Iraq Yet.',\n", | |
" 'sentence2': 'Weapons of Mass Destruction Found in Iraq.',\n", | |
" 'label': 1,\n", | |
" 'idx': 0}" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"rte[0]" | |
] | |
} | |
], | |
"metadata": { | |
"interpreter": { | |
"hash": "5188bc372fa413aa2565ae5d28228f50ad7b2c4ebb4a82c5900fd598adbb6408" | |
}, | |
"kernelspec": { | |
"display_name": "Python 3.8.8 64-bit ('ml': conda)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.8" | |
}, | |
"orig_nbformat": 4 | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
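
The RTE section states that the labels are inverted and need swapping, but no cell shows that step. Below is a minimal sketch of one way to do it, written against plain Python dicts so it runs without downloading anything. The helper name `flip_binary_label` and the `sample_rows` list are illustrative assumptions, not part of the original notebook.

```python
def flip_binary_label(row):
    """Return a copy of `row` with its binary label inverted (0 <-> 1)."""
    if row["label"] not in (0, 1):
        raise ValueError(f"expected a binary label, got {row['label']!r}")
    flipped = dict(row)  # shallow copy so the input row is left untouched
    flipped["label"] = 1 - row["label"]
    return flipped


# Hypothetical sample in the same shape as the RTE row shown in the notebook.
sample_rows = [
    {
        "sentence1": "No Weapons of Mass Destruction Found in Iraq Yet.",
        "sentence2": "Weapons of Mass Destruction Found in Iraq.",
        "label": 1,
        "idx": 0,
    },
]

flipped_rows = [flip_binary_label(row) for row in sample_rows]
print(flipped_rows[0]["label"])  # prints 0: the non-entailing pair is now "dissimilar"
```

With the Hugging Face dataset itself, the same function could probably be applied with `rte.map(flip_binary_label)`. In that case the class names attached to the `label` column would no longer match the integer values, so it may be safer to treat the column as a plain `0`/`1` similarity flag afterwards.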