Jupyter notebook for fine-tuning a T5-small model to generate SQL from natural language
{
"cells": [
{
"cell_type": "markdown",
"id": "21833f8d",
"metadata": {},
"source": [
"# Fine Tuning of a SQL Model\n",
"\n",
"### Inspired by https://huggingface.co/cssupport/t5-small-awesome-text-to-sql\n",
"\n",
"### Datasets:\n",
"- https://huggingface.co/datasets/b-mc2/sql-create-context\n",
"- https://huggingface.co/datasets/Clinton/Text-to-sql-v1\n",
"- https://huggingface.co/datasets/knowrohit07/know_sql"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "1f78e14f",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"from datasets import Dataset, DatasetDict, load_dataset, interleave_datasets, load_from_disk\n",
"from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer\n",
"import torch\n",
"import time\n",
"import evaluate\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "cd00c140",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.cuda.is_available()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1971e6c5",
"metadata": {},
"outputs": [],
"source": [
"model_name='t5-small'\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"\n",
"original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)\n",
"original_model = original_model.to('cuda')"
]
},
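{
"cell_type": "markdown",
"id": "a1000001",
"metadata": {},
"source": [
"As a quick sanity check (an optional addition to the original flow), the next cell reports the size, dtype and device of the loaded model. t5-small is a roughly 60M-parameter encoder-decoder, small enough to fully fine-tune on a single consumer GPU."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1000002",
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: parameter count, dtype and device of the base model.\n",
"params = sum(p.numel() for p in original_model.parameters())\n",
"print(f\"{model_name}: {params / 1e6:.1f}M parameters\")\n",
"first_param = next(original_model.parameters())\n",
"print(f\"dtype: {first_param.dtype}, device: {first_param.device}\")"
]
},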
{
"cell_type": "markdown",
"id": "f5e5419b",
"metadata": {},
"source": [
"# Load Datasets"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ee806dfd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded Merged Dataset\n"
]
},
{
"data": {
"text/plain": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['question', 'context', 'answer'],\n",
" num_rows: 118695\n",
" })\n",
" test: Dataset({\n",
" features: ['question', 'context', 'answer'],\n",
" num_rows: 14835\n",
" })\n",
" validation: Dataset({\n",
" features: ['question', 'context', 'answer'],\n",
" num_rows: 14838\n",
" })\n",
"})"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"try:\n",
" dataset = load_from_disk(\"merged_dataset\")\n",
" print(\"Loaded Merged Dataset\")\n",
"except:\n",
" dataset_scc_train = load_dataset(\"b-mc2/sql-create-context\", split='train[:80%]')\n",
" dataset_scc_test = load_dataset(\"b-mc2/sql-create-context\", split='train[-20%:-10%]')\n",
" dataset_scc_val = load_dataset(\"b-mc2/sql-create-context\", split='train[-10%:]')\n",
"\n",
" dataset_tts_train = load_dataset(\"Clinton/Text-to-sql-v1\", split='train[:80%]')\n",
" dataset_tts_train = dataset_tts_train.remove_columns(['source', 'text'])\n",
" dataset_tts_train = dataset_tts_train.rename_columns({'instruction': 'question', 'input': 'context', 'response': 'answer'})\n",
" dataset_tts_test = load_dataset(\"Clinton/Text-to-sql-v1\", split='train[-20%:-10%]')\n",
" dataset_tts_test = dataset_tts_test.remove_columns(['source', 'text'])\n",
" dataset_tts_test = dataset_tts_test.rename_columns({'instruction': 'question', 'input': 'context', 'response': 'answer'})\n",
" dataset_tts_val = load_dataset(\"Clinton/Text-to-sql-v1\", split='train[-10%:]')\n",
" dataset_tts_val = dataset_tts_val.remove_columns(['source', 'text'])\n",
" dataset_tts_val = dataset_tts_val.rename_columns({'instruction': 'question', 'input': 'context', 'response': 'answer'})\n",
"\n",
" dataset_ks_train = load_dataset(\"knowrohit07/know_sql\", split='validation[:80%]')\n",
" dataset_ks_test = load_dataset(\"knowrohit07/know_sql\", split='validation[-20%:-10%]')\n",
" dataset_ks_val = load_dataset(\"knowrohit07/know_sql\", split='validation[-10%:]')\n",
"\n",
" dataset = DatasetDict({ 'train': interleave_datasets([dataset_scc_train, dataset_tts_train, dataset_ks_train]),\n",
" 'test': interleave_datasets([dataset_scc_test, dataset_tts_test, dataset_ks_test]),\n",
" 'validation': interleave_datasets([dataset_scc_val, dataset_tts_val, dataset_ks_val])})\n",
"\n",
" dataset.save_to_disk(\"merged_dataset\")\n",
" print(\"Merged and Saved Dataset\")\n",
"\n",
"dataset"
]
},
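{
"cell_type": "markdown",
"id": "a1000003",
"metadata": {},
"source": [
"A minimal check (not in the original flow): `interleave_datasets` requires identical features across its sources, which is why the Clinton/Text-to-sql-v1 columns were renamed above. The cell below confirms every split ended up with the same three columns."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1000004",
"metadata": {},
"outputs": [],
"source": [
"# Confirm the merged splits all expose the same question/context/answer columns.\n",
"for split in dataset:\n",
"    assert set(dataset[split].column_names) == {'question', 'context', 'answer'}, split\n",
"print(\"All splits share the columns: question, context, answer\")"
]
},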
{
"cell_type": "code",
"execution_count": 10,
"id": "89b95075",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'question': 'On what Date did the Away team essendon play?',\n",
" 'context': 'CREATE TABLE table_name_11 (date VARCHAR, away_team VARCHAR)',\n",
" 'answer': 'SELECT date FROM table_name_11 WHERE away_team = \"essendon\"'}"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset['test'][0]"
]
},
{
"cell_type": "markdown",
"id": "e8a79425",
"metadata": {},
"source": [
"# Preprocess the Datasets\n",
"\n",
"You need to convert the datasets into explicit instructions for the LLM.\n",
"\n",
"Then preprocess the prompt-response dataset into tokens and pull out their input_ids."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ad26693b",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/118695 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/14835 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/14838 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5dd41fae15dd43d5bd3bc5a44bcb2603",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Saving the dataset (0/2 shards): 0%| | 0/118695 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Saving the dataset (0/1 shards): 0%| | 0/14835 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Saving the dataset (0/1 shards): 0%| | 0/14838 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tokenized and Saved Dataset\n"
]
}
],
"source": [
"def tokenize_function(example):\n",
" \n",
"# print(len(example[\"question\"]))\n",
" start_prompt = \"Tables:\\n\"\n",
" middle_prompt = \"\\n\\nQuestion:\\n\"\n",
" end_prompt = \"\\n\\nAnswer:\\n\"\n",
" \n",
" data_zip = zip(example['context'], example['question'])\n",
" prompt = [start_prompt + context + middle_prompt + question + end_prompt for context, question in data_zip]\n",
" example['input_ids'] = tokenizer(prompt, padding=\"max_length\", truncation=True, return_tensors=\"pt\").input_ids\n",
" example['labels'] = tokenizer(example['answer'], padding=\"max_length\", truncation=True, return_tensors=\"pt\").input_ids\n",
"# print(prompt[0])\n",
"# print()\n",
" \n",
" return example\n",
"\n",
"# The dataset actually contains 3 diff splits: train, validation, test.\n",
"# The tokenize_function code is handling all data across all splits in batches.\n",
"\n",
"try:\n",
" tokenized_datasets = load_from_disk(\"tokenized_datasets\")\n",
" print(\"Loaded Tokenized Dataset\")\n",
"except:\n",
" tokenized_datasets = dataset.map(tokenize_function, batched=True)\n",
" tokenized_datasets = tokenized_datasets.remove_columns(['question', 'context', 'answer'])\n",
" \n",
" tokenized_datasets.save_to_disk(\"tokenized_datasets\")\n",
" print(\"Tokenized and Saved Dataset\")"
]
},
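{
"cell_type": "markdown",
"id": "a1000005",
"metadata": {},
"source": [
"A round-trip check (an optional addition): decoding one tokenized example back to text confirms the prompt layout the model will actually see."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1000006",
"metadata": {},
"outputs": [],
"source": [
"# Decode the first training example; skip_special_tokens drops the padding.\n",
"sample = tokenized_datasets['train'][0]\n",
"print(tokenizer.decode(sample['input_ids'], skip_special_tokens=True))\n",
"print('---')\n",
"print(tokenizer.decode(sample['labels'], skip_special_tokens=True))"
]
},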
{
"cell_type": "code",
"execution_count": 9,
"id": "fe4bfa16",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dict_keys(['train', 'test', 'validation'])\n",
"dict_keys(['input_ids', 'labels'])\n",
"[4398, 7, 10, 205, 4386, 6048, 332, 17098, 819, 41]\n",
"[3, 23143, 14196, 2847, 17161, 599, 1935, 61, 21680, 819]\n",
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['input_ids', 'labels'],\n",
" num_rows: 118695\n",
" })\n",
" test: Dataset({\n",
" features: ['input_ids', 'labels'],\n",
" num_rows: 14835\n",
" })\n",
" validation: Dataset({\n",
" features: ['input_ids', 'labels'],\n",
" num_rows: 14838\n",
" })\n",
"})\n"
]
}
],
"source": [
"print(tokenized_datasets.keys())\n",
"print(tokenized_datasets['train'][0].keys())\n",
"print(tokenized_datasets['train'][0]['input_ids'][:10])\n",
"print(tokenized_datasets['train'][0]['labels'][:10])\n",
"print(tokenized_datasets)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "6efaa5b9",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Shapes of the datasets:\n",
"Training: (118695, 2)\n",
"Validation: (14838, 2)\n",
"Test: (14835, 2)\n",
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['input_ids', 'labels'],\n",
" num_rows: 118695\n",
" })\n",
" test: Dataset({\n",
" features: ['input_ids', 'labels'],\n",
" num_rows: 14835\n",
" })\n",
" validation: Dataset({\n",
" features: ['input_ids', 'labels'],\n",
" num_rows: 14838\n",
" })\n",
"})\n"
]
}
],
"source": [
"print(f\"Shapes of the datasets:\")\n",
"print(f\"Training: {tokenized_datasets['train'].shape}\")\n",
"print(f\"Validation: {tokenized_datasets['validation'].shape}\")\n",
"print(f\"Test: {tokenized_datasets['test'].shape}\")\n",
"\n",
"print(tokenized_datasets)"
]
},
{
"cell_type": "markdown",
"id": "4e52f581",
"metadata": {},
"source": [
"# Test the Model with Zero Shot Inferencing"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "1c6a5c0f",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------------------------------------------------------------------------------------------------\n",
"INPUT PROMPT:\n",
"Tables:\n",
"CREATE TABLE table_name_11 (date VARCHAR, away_team VARCHAR)\n",
"\n",
"Question:\n",
"On what Date did the Away team essendon play?\n",
"\n",
"Answer:\n",
"\n",
"---------------------------------------------------------------------------------------------------\n",
"BASELINE HUMAN ANSWER:\n",
"SELECT date FROM table_name_11 WHERE away_team = \"essendon\"\n",
"\n",
"---------------------------------------------------------------------------------------------------\n",
"MODEL GENERATION - ZERO SHOT:\n",
"Question\n"
]
}
],
"source": [
"index = 0\n",
"\n",
"question = dataset['test'][index]['question']\n",
"context = dataset['test'][index]['context']\n",
"answer = dataset['test'][index]['answer']\n",
"\n",
"prompt = f\"\"\"Tables:\n",
"{context}\n",
"\n",
"Question:\n",
"{question}\n",
"\n",
"Answer:\n",
"\"\"\"\n",
"\n",
"inputs = tokenizer(prompt, return_tensors='pt')\n",
"inputs = inputs.to('cuda')\n",
"\n",
"output = tokenizer.decode(\n",
" original_model.generate(\n",
" inputs[\"input_ids\"], \n",
" max_new_tokens=200,\n",
" )[0], \n",
" skip_special_tokens=True\n",
")\n",
"\n",
"dash_line = '-'.join('' for x in range(100))\n",
"print(dash_line)\n",
"print(f'INPUT PROMPT:\\n{prompt}')\n",
"print(dash_line)\n",
"print(f'BASELINE HUMAN ANSWER:\\n{answer}\\n')\n",
"print(dash_line)\n",
"print(f'MODEL GENERATION - ZERO SHOT:\\n{output}')"
]
},
{
"cell_type": "markdown",
"id": "a22a7f40",
"metadata": {},
"source": [
"So pretty poor - aka garbage."
]
},
{
"cell_type": "markdown",
"id": "c8832fa9",
"metadata": {},
"source": [
"# Perform Full Fine-Tuning"
]
},
{
"cell_type": "markdown",
"id": "24f2d995",
"metadata": {},
"source": [
"### 2 Epochs\n",
"\n",
"#### 5e-3\n",
"\n",
"Time Taken = 2h 49m 1s on a laptop with a GeForce RTX 3070 GPU\n",
"\n",
"Training Loss = 0.023100\n",
"\n",
"Validation Loss = 0.013285"
]
},
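{
"cell_type": "markdown",
"id": "a1000007",
"metadata": {},
"source": [
"To put that wall-clock time in context, the next cell (an added back-of-the-envelope sketch, assuming a single GPU and no gradient accumulation) works out the optimizer step count these settings imply."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1000008",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"# 118,695 training rows at per-device batch size 16, for 2 epochs.\n",
"rows, batch_size, epochs = 118695, 16, 2\n",
"steps_per_epoch = math.ceil(rows / batch_size)\n",
"total_steps = steps_per_epoch * epochs\n",
"print(f\"{steps_per_epoch} optimizer steps per epoch, {total_steps} in total\")\n",
"# With eval_steps=500 that is roughly 29 mid-training evaluations over the run.\n",
"print(f\"~{total_steps // 500} evaluations at eval_steps=500\")"
]
},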
{
"cell_type": "code",
"execution_count": 17,
"id": "94988713",
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(\"finetuned_model_2_epoch\")\n",
" finetuned_model = finetuned_model.to('cuda')\n",
" to_train = False\n",
"\n",
"except:\n",
" to_train = True\n",
" finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)\n",
" finetuned_model = finetuned_model.to('cuda')\n",
" tokenizer = AutoTokenizer.from_pretrained(model_name)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "ba6d32dd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: total: 0 ns\n",
"Wall time: 0 ns\n"
]
}
],
"source": [
"%%time\n",
"\n",
"if to_train:\n",
" output_dir = f'./sql-training-{str(int(time.time()))}'\n",
"\n",
" training_args = TrainingArguments(\n",
" output_dir=output_dir,\n",
" learning_rate=5e-3,\n",
" num_train_epochs=2,\n",
" per_device_train_batch_size=16, # batch size per device during training\n",
" per_device_eval_batch_size=16, # batch size for evaluation\n",
" weight_decay=0.01,\n",
" logging_steps=50,\n",
" evaluation_strategy='steps', # evaluation strategy to adopt during training\n",
" eval_steps=500, # number of steps between evaluation\n",
" )\n",
"\n",
" trainer = Trainer(\n",
" model=finetuned_model,\n",
" args=training_args,\n",
" train_dataset=tokenized_datasets['train'],\n",
" eval_dataset=tokenized_datasets['validation'],\n",
" )\n",
" \n",
" trainer.train()\n",
" \n",
" finetuned_model.save_pretrained(\"finetuned_model_2_epoch\")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "4507aa94",
"metadata": {},
"outputs": [],
"source": [
"finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(\"finetuned_model_2_epoch\")\n",
"finetuned_model = finetuned_model.to('cuda')"
]
},
{
"cell_type": "markdown",
"id": "131bc210",
"metadata": {},
"source": [
"# Test the Fine Tuned Model with Zero Shot Inferencing"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "f3fdfcf5",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------------------------------------------------------------------------------------------------\n",
"INPUT PROMPT:\n",
"Tables:\n",
"CREATE TABLE employees (\n",
" EMPLOYEE_ID decimal(6,0),\n",
" FIRST_NAME varchar(20),\n",
" LAST_NAME varchar(25),\n",
" EMAIL varchar(25),\n",
" PHONE_NUMBER varchar(20),\n",
" HIRE_DATE date,\n",
" JOB_ID varchar(10),\n",
" SALARY decimal(8,2),\n",
" COMMISSION_PCT decimal(2,2),\n",
" MANAGER_ID decimal(6,0),\n",
" DEPARTMENT_ID decimal(4,0)\n",
")\n",
"\n",
"CREATE TABLE jobs (\n",
" JOB_ID varchar(10),\n",
" JOB_TITLE varchar(35),\n",
" MIN_SALARY decimal(6,0),\n",
" MAX_SALARY decimal(6,0)\n",
")\n",
"\n",
"CREATE TABLE locations (\n",
" LOCATION_ID decimal(4,0),\n",
" STREET_ADDRESS varchar(40),\n",
" POSTAL_CODE varchar(12),\n",
" CITY varchar(30),\n",
" STATE_PROVINCE varchar(25),\n",
" COUNTRY_ID varchar(2)\n",
")\n",
"\n",
"CREATE TABLE countries (\n",
" COUNTRY_ID varchar(2),\n",
" COUNTRY_NAME varchar(40),\n",
" REGION_ID decimal(10,0)\n",
")\n",
"\n",
"CREATE TABLE job_history (\n",
" EMPLOYEE_ID decimal(6,0),\n",
" START_DATE date,\n",
" END_DATE date,\n",
" JOB_ID varchar(10),\n",
" DEPARTMENT_ID decimal(4,0)\n",
")\n",
"\n",
"CREATE TABLE regions (\n",
" REGION_ID decimal(5,0),\n",
" REGION_NAME varchar(25)\n",
")\n",
"\n",
"CREATE TABLE departments (\n",
" DEPARTMENT_ID decimal(4,0),\n",
" DEPARTMENT_NAME varchar(30),\n",
" MANAGER_ID decimal(6,0),\n",
" LOCATION_ID decimal(4,0)\n",
")\n",
"\n",
"Question:\n",
"For those employees who did not have any job in the past, give me the comparison about the amount of job_id over the job_id , and group by attribute job_id, and list from low to high by the JOB_ID please.\n",
"\n",
"Answer:\n",
"\n",
"---------------------------------------------------------------------------------------------------\n",
"BASELINE HUMAN ANSWER:\n",
"SELECT JOB_ID, COUNT(JOB_ID) FROM employees WHERE NOT EMPLOYEE_ID IN (SELECT EMPLOYEE_ID FROM job_history) GROUP BY JOB_ID ORDER BY JOB_ID\n",
"\n",
"---------------------------------------------------------------------------------------------------\n",
"FINE-TUNED MODEL - ZERO SHOT:\n",
"SELECT JOB_ID, COUNT(JOB_ID) FROM employees WHERE NOT EMPLOYEE_ID IN (SELECT EMPLOYEE_ID FROM job_history) GROUP BY JOB_ID ORDER BY JOB_ID\n"
]
}
],
"source": [
"index = 0\n",
"# index = len(dataset['test'])-200\n",
"\n",
"question = dataset['test'][index]['question']\n",
"context = dataset['test'][index]['context']\n",
"answer = dataset['test'][index]['answer']\n",
"\n",
"prompt = f\"\"\"Tables:\n",
"{context}\n",
"\n",
"Question:\n",
"{question}\n",
"\n",
"Answer:\n",
"\"\"\"\n",
"\n",
"inputs = tokenizer(prompt, return_tensors='pt')\n",
"inputs = inputs.to('cuda')\n",
"\n",
"output = tokenizer.decode(\n",
" finetuned_model.generate(\n",
" inputs[\"input_ids\"], \n",
" max_new_tokens=200,\n",
" )[0], \n",
" skip_special_tokens=True\n",
")\n",
"\n",
"dash_line = '-'.join('' for x in range(100))\n",
"print(dash_line)\n",
"print(f'INPUT PROMPT:\\n{prompt}')\n",
"print(dash_line)\n",
"print(f'BASELINE HUMAN ANSWER:\\n{answer}\\n')\n",
"print(dash_line)\n",
"print(f'FINE-TUNED MODEL - ZERO SHOT:\\n{output}')"
]
},
{
"cell_type": "markdown",
"id": "69ec82ff",
"metadata": {},
"source": [
"# Evaluate the Model Quantitatively (with ROUGE Metric)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "8e665b3b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Token indices sequence length is longer than the specified maximum sequence length for this model (1115 > 512). Running this sequence through the model will result in indexing errors\n"
]
}
],
"source": [
"# Perform inferences for test dataset. Do 25 only, due to time it takes.\n",
"\n",
"questions = dataset['test'][0:25]['question']\n",
"contexts = dataset['test'][0:25]['context']\n",
"human_baseline_answers = dataset['test'][0:25]['answer']\n",
"\n",
"original_model_answers = []\n",
"finetuned_model_answers = []\n",
"\n",
"for idx, question in enumerate(questions):\n",
" \n",
" prompt = f\"\"\"Tables:\n",
"{contexts[idx]}\n",
"\n",
"Question:\n",
"{question}\n",
"\n",
"Answer:\n",
"\"\"\"\n",
" \n",
" input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n",
" input_ids = input_ids.to('cuda')\n",
"\n",
" human_baseline_text_output = human_baseline_answers[idx]\n",
" \n",
" original_model_outputs = original_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=300))\n",
" original_model_text_output = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)\n",
" original_model_answers.append(original_model_text_output)\n",
" \n",
" finetuned_model_outputs = finetuned_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=300))\n",
" finetuned_model_text_output = tokenizer.decode(finetuned_model_outputs[0], skip_special_tokens=True)\n",
" finetuned_model_answers.append(finetuned_model_text_output)\n",
"\n",
"zipped_summaries = list(zip(human_baseline_answers, original_model_answers, finetuned_model_answers))\n",
" \n",
"df = pd.DataFrame(zipped_summaries, columns = ['human_baseline_answers', 'original_model_answers', 'finetuned_model_answers'])\n",
"# df"
]
},
{
"cell_type": "markdown",
"id": "b00766ae",
"metadata": {},
"source": [
"Compute ROUGE score for this subset of the data."
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "18975f9d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ORIGINAL MODEL:\n",
"{'rouge1': 0.031970284742291306, 'rouge2': 0.005, 'rougeL': 0.03070044347245003, 'rougeLsum': 0.03121247624254732}\n",
"FINE-TUNED MODEL:\n",
"{'rouge1': 0.923359923692127, 'rouge2': 0.8863291968739871, 'rougeL': 0.9176464597549342, 'rougeLsum': 0.9182149521321223}\n"
]
}
],
"source": [
"rouge = evaluate.load('rouge')\n",
"\n",
"original_model_results = rouge.compute(\n",
" predictions=original_model_answers,\n",
" references=human_baseline_answers[0:len(original_model_answers)],\n",
" use_aggregator=True,\n",
" use_stemmer=True,\n",
")\n",
"print('ORIGINAL MODEL:')\n",
"print(original_model_results)\n",
"\n",
"\n",
"finetuned_model_results = rouge.compute(\n",
" predictions=finetuned_model_answers,\n",
" references=human_baseline_answers[0:len(finetuned_model_answers)],\n",
" use_aggregator=True,\n",
" use_stemmer=True,\n",
")\n",
"print('FINE-TUNED MODEL:')\n",
"print(finetuned_model_results)"
]
},
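{
"cell_type": "markdown",
"id": "a100000b",
"metadata": {},
"source": [
"ROUGE measures n-gram overlap, so it can flatter SQL that is close but not identical to the reference. A stricter complementary check (an added sketch, using simple whitespace normalisation rather than a proper SQL parser) is exact match:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a100000c",
"metadata": {},
"outputs": [],
"source": [
"# Exact-match rate after whitespace normalisation; stricter than ROUGE.\n",
"def normalise(sql):\n",
"    return ' '.join(sql.strip().split())\n",
"\n",
"exact = sum(\n",
"    normalise(pred) == normalise(ref)\n",
"    for pred, ref in zip(finetuned_model_answers, human_baseline_answers)\n",
")\n",
"print(f\"Fine-tuned exact-match: {exact}/{len(finetuned_model_answers)}\")"
]
},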
{
"cell_type": "code",
"execution_count": null,
"id": "fc7ef16d",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}