Gretel Tabular LLM - Textbooks notebook for workshop
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/zredlined/50ef5cbf09aa63302f963e14dc1bf5a9/notebook.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"<br>\n", | |
"\n", | |
"<center><a href=https://gretel.ai/><img src=\"https://gretel-public-website.s3.us-west-2.amazonaws.com/assets/brand/gretel_brand_wordmark.svg\" alt=\"Gretel\" width=\"350\"/></a></center>\n", | |
"\n", | |
"<br>\n", | |
"\n", | |
"\n", | |
"# 🚀 Augmenting LLM training data with high-quality _synthetic_ examples\n", | |
"In this notebook, we will leverage [Gretel's Tabular LLM](https://gretel.ai/tabular-llm) to generate diverse, high-quality training examples to efficiently train/fine-tune better LLMs with less data. Our goal is to demonstrate how to get started creating high-quality synthetic data for LLM training and facilitate further research into safeguards for completion models.\n", | |
"\n", | |
"## Background\n", | |
"Recent research has shown that training small, efficient language models on high-quality, diverse data can achieve state-of-the-art results, as demonstrated by Microsoft's [phi-1.5](https://arxiv.org/abs/2309.05463) and [Orca2](https://arxiv.org/abs/2311.11045) models.\n", | |
"\n", | |
"Creating diverse synthetic training data is challenging but vital to reduce overfitting and improve generalization. We will demonstrate how to boost generation diversity using an approach similar to the [TinyStories](https://arxiv.org/abs/2305.07759) study, in which the authors chose random words from a fixed vocabulary to inject into the prompt.\n", | |
"\n", | |
"## Prerequisites\n", | |
"\n", | |
"Before diving into the notebook, there are a couple of prerequisites:\n", | |
"\n", | |
"1. **Gretel API Key**: You'll need an API key from Gretel. If you don't have one already, you can obtain it from [Gretel's console]((https://console.gretel.ai/users/me/key)). This key will enable us to use Gretel's services for generating our synthetic datasets.\n", | |
"\n", | |
"2. **Access to Gretel's Tabular LLM**: To utilize the specific features of the Tabular LLM, you need to have access to the early preview. If you're not already signed up, you can request early access at [Gretel's Tabular LLM page](https://gretel.ai/tabular-llm).\n", | |
"\n", | |
"Let's get started!\n" | |
], | |
"metadata": { | |
"id": "oXMAurxdjngb" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## 💾 Install and import necessary packages" | |
], | |
"metadata": { | |
"id": "3tMfsaYPjt1p" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"%%capture\n", | |
"!pip install datasets keybert keyphrase_vectorizers\n", | |
"!pip install -Uqq gretel_client" | |
], | |
"metadata": { | |
"id": "9vcn14--cbq1" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "TCQQA5M2cXvV" | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"\n", | |
"from datasets import load_dataset\n", | |
"from keybert import KeyBERT\n", | |
"from keyphrase_vectorizers import KeyphraseCountVectorizer\n", | |
"\n", | |
"from gretel_client import Gretel" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## 🛜 Configure your Gretel session and initialize Tabular LLM\n", | |
"\n", | |
"- Running the cell below will prompt you for your Gretel API key, which you can retrieve [here](https://console.gretel.ai/users/me/key)." | |
], | |
"metadata": { | |
"id": "ENUkdyRlj4D0" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "iNuaoVQFcXvW" | |
}, | |
"outputs": [], | |
"source": [ | |
"gretel = Gretel(api_key=\"prompt\")\n", | |
"\n", | |
"tabllm = gretel.factories.initialize_inference_api(backend_model=\"gretelai/tabular-v0c\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## ⚙️ Set demo parameters" | |
], | |
"metadata": { | |
"id": "WZRoXR3ZpG1T" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "3ZldYHe2cXvW" | |
}, | |
"outputs": [], | |
"source": [ | |
"DATASET_NAME = \"databricks/databricks-dolly-15k\"\n", | |
"CATEGORY = \"closed_qa\"\n", | |
"MAX_WORDS = 400\n", | |
"NUM_EXAMPLES = 10\n", | |
"NUM_SELECT_PHRASES = 2\n", | |
"UPSAMPLE_MULTIPLE = 3\n", | |
"RANDOM_SEED = len(\"GRETEL\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## 💾 Load and preprocess dataset\n", | |
"\n", | |
"In the cell below we perform the following preprocessing steps:\n", | |
"\n", | |
"- Load the Dolly dataset and convert to a pandas `DataFrame`\n", | |
"\n", | |
"- Select examples in the set `CATEGORY`\n", | |
"\n", | |
"- Clean text and convert to ascii\n", | |
"\n", | |
"- Remove examples with more words than `MAX_WORDS`\n", | |
"\n", | |
"- Drop unnecessary columns\n", | |
"\n", | |
"- Sample `NUM_EXAMPLES` examples for the demo" | |
], | |
"metadata": { | |
"id": "DAGa-3kFpQ8u" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "DqLbdq9dcXvX" | |
}, | |
"outputs": [], | |
"source": [ | |
"%%capture\n", | |
"dataset = load_dataset(DATASET_NAME, split=\"train\")\n", | |
"\n", | |
"df = (\n", | |
" dataset\n", | |
" .to_pandas()\n", | |
" .query(\"category==@CATEGORY\")\n", | |
" .applymap(lambda x: x.replace('\\n', ' ').replace('\\r', ' ').encode('ascii', 'ignore').decode('ascii'))\n", | |
" .assign(num_words=lambda df_: df_[\"context\"].str.cat(df_[\"response\"], sep=\" \").str.split().apply(len))\n", | |
" .query(\"num_words < @MAX_WORDS\")\n", | |
" .drop(columns=[\"category\", \"num_words\"])\n", | |
" .sample(NUM_EXAMPLES, random_state=RANDOM_SEED)\n", | |
" .reset_index(drop=True)\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"tabllm.display_dataframe_in_notebook(df)" | |
], | |
"metadata": { | |
"id": "BiIO7__Mwet-" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## 🗝️ Identify key phrases\n", | |
"\n", | |
"- Here we use a BERT-based model to extract interesting/important key phrases from the `context` of each example.\n", | |
"\n", | |
"- We upsample the dataset by a factor of `UPSAMPLE_MULTIPLE`, which will allow us to create multiple examples from the same context.\n", | |
"\n", | |
"- We then sample `NUM_SELECT_PHRASES` from the extracted key phrases.\n", | |
"\n", | |
"- We will use these `select_phrases` to boost the diversity in our synthetic instructions." | |
], | |
"metadata": { | |
"id": "osYtwRmdqZVK" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"%%capture\n", | |
"np.random.seed(RANDOM_SEED)\n", | |
"\n", | |
"def sample_key_phrases(phrases):\n", | |
" random_idx = np.random.choice(len(phrases), NUM_SELECT_PHRASES, replace=False)\n", | |
" return \", \".join([phrases[i][0] for i in random_idx])\n", | |
"\n", | |
"\n", | |
"df[\"select_phrases\"] = KeyBERT().extract_keywords(\n", | |
" docs=df[\"context\"].tolist(),\n", | |
" vectorizer=KeyphraseCountVectorizer(), top_n=3 * NUM_SELECT_PHRASES\n", | |
")\n", | |
"\n", | |
"df = pd.DataFrame(np.repeat(df.values, UPSAMPLE_MULTIPLE, axis=0), columns=df.columns)\n", | |
"\n", | |
"df[\"select_phrases\"] = df[\"select_phrases\"].apply(sample_key_phrases)" | |
], | |
"metadata": { | |
"id": "7fH6j_wFgd-b" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# preview example context + select_phrases\n", | |
"tabllm.display_dataframe_in_notebook(df[[\"context\", \"select_phrases\"]].head())" | |
], | |
"metadata": { | |
"id": "lf3titL6gf-T" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## 🤖 Prompt Gretel's Tabular LLM to create synthetic instructions and responses" | |
], | |
"metadata": { | |
"id": "2c2MUSKJu6o_" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "tCt4Yj-TcXvX" | |
}, | |
"outputs": [], | |
"source": [ | |
"prompt = \"\"\"\\\n", | |
"For each example in the Dataset, please act as a tutor and create high quality,\n", | |
"detailed synthetic question and answers of higher quality than the provided example.\n", | |
"Frame your question around one of the phrases in the 'select_phrases' column.\n", | |
"Ensure the data teaches concepts step-by-step and focuses on improving reasoning skills.\n", | |
"Focus on generating questions and answers about under-represented topics and knowledge gaps.\n", | |
"\n", | |
"Add two new columns to the Dataset:\n", | |
"1. 'synthetic_instruction':\n", | |
" * Introduce the topic from the example briefly in 1-2 sentences\n", | |
" * Ask a clear question related to the topic that requires logical thinking or common sense reasoning\n", | |
" * Provide any necessary context to set up the reasoning problem\n", | |
" * Do not repeat the instruction from the Dataset example\n", | |
"2. 'synthetic_response':\n", | |
" * Respond to the synthetically generated instruction thoroughly in a step-by-step manner\n", | |
" * Provide the complete reasoning needed to arrive at the answer\n", | |
" * Ensure the explanation is textbook quality with all details needed to learn the concept\n", | |
" * Answer in 3-5 sentences.\n", | |
"\"\"\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "TSTBhQX6cXvY" | |
}, | |
"outputs": [], | |
"source": [ | |
"columns = [\"instruction\", \"context\", \"response\", \"synthetic_instruction\", \"synthetic_response\"]\n", | |
"\n", | |
"synthetic = tabllm.edit(prompt=prompt, seed_data=df, top_k=40, temperature=0.8)[columns]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "6PBgvz0bcXvZ" | |
}, | |
"outputs": [], | |
"source": [ | |
"tabllm.display_dataframe_in_notebook(synthetic)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Comparing quality\n", | |
"import IPython\n", | |
"\n", | |
"from gretel_client.evaluation.text_quality_report import TextQualityReport\n", | |
"\n", | |
"real=pd.DataFrame()\n", | |
"real['text'] = synthetic['instruction'] + \" \" + synthetic['response']\n", | |
"synthetic['text'] = synthetic['synthetic_instruction'] + \" \" + synthetic['synthetic_response']\n", | |
"\n", | |
"report = TextQualityReport(data_source=synthetic,\n", | |
" ref_data=real,\n", | |
" target='text',\n", | |
" record_count=len(synthetic))\n", | |
"report.run()\n", | |
"\n", | |
"IPython.display.HTML(report.as_html, metadata=dict(isolated=True))" | |
], | |
"metadata": { | |
"id": "iA172X4xYg0c" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "monogretel-dev", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.12" | |
}, | |
"colab": { | |
"provenance": [], | |
"name": "Gretel Tabular LLM - Textbooks notebook for workshop", | |
"include_colab_link": true | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |