@zredlined
Last active May 1, 2024 02:00
Gretel Tabular LLM - Textbooks notebook for workshop
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/zredlined/50ef5cbf09aa63302f963e14dc1bf5a9/notebook.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"<br>\n",
"\n",
"<center><a href=https://gretel.ai/><img src=\"https://gretel-public-website.s3.us-west-2.amazonaws.com/assets/brand/gretel_brand_wordmark.svg\" alt=\"Gretel\" width=\"350\"/></a></center>\n",
"\n",
"<br>\n",
"\n",
"\n",
"# 🚀 Augmenting LLM training data with high-quality _synthetic_ examples\n",
"In this notebook, we will leverage [Gretel's Tabular LLM](https://gretel.ai/tabular-llm) to generate diverse, high-quality training examples to efficiently train/fine-tune better LLMs with less data. Our goal is to demonstrate how to get started creating high-quality synthetic data for LLM training and facilitate further research into safeguards for completion models.\n",
"\n",
"## Background\n",
"Recent research has shown that training small, efficient language models on high-quality, diverse data can achieve state-of-the-art results, as demonstrated by Microsoft's [phi-1.5](https://arxiv.org/abs/2309.05463) and [Orca 2](https://arxiv.org/abs/2311.11045) models.\n",
"\n",
"Creating diverse synthetic training data is challenging but vital to reduce overfitting and improve generalization. We will demonstrate how to boost generation diversity using an approach similar to the [TinyStories](https://arxiv.org/abs/2305.07759) study, in which the authors chose random words from a fixed vocabulary to inject into the prompt.\n",
"\n",
"## Prerequisites\n",
"\n",
"Before diving into the notebook, there are a couple of prerequisites:\n",
"\n",
"1. **Gretel API Key**: You'll need an API key from Gretel. If you don't have one already, you can obtain it from [Gretel's console](https://console.gretel.ai/users/me/key). This key lets us use Gretel's services to generate our synthetic datasets.\n",
"\n",
"2. **Access to Gretel's Tabular LLM**: To use the Tabular LLM, you need access to its early preview. If you're not already signed up, you can request early access at [Gretel's Tabular LLM page](https://gretel.ai/tabular-llm).\n",
"\n",
"Let's get started!\n"
],
"metadata": {
"id": "oXMAurxdjngb"
}
},
{
"cell_type": "markdown",
"source": [
"## 💾 Install and import necessary packages"
],
"metadata": {
"id": "3tMfsaYPjt1p"
}
},
{
"cell_type": "code",
"source": [
"%%capture\n",
"!pip install datasets keybert keyphrase_vectorizers\n",
"!pip install -Uqq gretel_client"
],
"metadata": {
"id": "9vcn14--cbq1"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TCQQA5M2cXvV"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from datasets import load_dataset\n",
"from keybert import KeyBERT\n",
"from keyphrase_vectorizers import KeyphraseCountVectorizer\n",
"\n",
"from gretel_client import Gretel"
]
},
{
"cell_type": "markdown",
"source": [
"## 🛜 Configure your Gretel session and initialize Tabular LLM\n",
"\n",
"- Running the cell below will prompt you for your Gretel API key, which you can retrieve [here](https://console.gretel.ai/users/me/key)."
],
"metadata": {
"id": "ENUkdyRlj4D0"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iNuaoVQFcXvW"
},
"outputs": [],
"source": [
"gretel = Gretel(api_key=\"prompt\")\n",
"\n",
"tabllm = gretel.factories.initialize_inference_api(backend_model=\"gretelai/tabular-v0c\")"
]
},
{
"cell_type": "markdown",
"source": [
"## ⚙️ Set demo parameters"
],
"metadata": {
"id": "WZRoXR3ZpG1T"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3ZldYHe2cXvW"
},
"outputs": [],
"source": [
"DATASET_NAME = \"databricks/databricks-dolly-15k\"\n",
"CATEGORY = \"closed_qa\"\n",
"MAX_WORDS = 400\n",
"NUM_EXAMPLES = 10\n",
"NUM_SELECT_PHRASES = 2\n",
"UPSAMPLE_MULTIPLE = 3\n",
"RANDOM_SEED = len(\"GRETEL\")"
]
},
{
"cell_type": "markdown",
"source": [
"## 💾 Load and preprocess dataset\n",
"\n",
"In the cell below we perform the following preprocessing steps:\n",
"\n",
"- Load the Dolly dataset and convert to a pandas `DataFrame`\n",
"\n",
"- Select examples in the set `CATEGORY`\n",
"\n",
"- Clean the text by replacing line breaks with spaces and dropping non-ASCII characters\n",
"\n",
"- Remove examples whose combined `context` and `response` contain `MAX_WORDS` or more words\n",
"\n",
"- Drop unnecessary columns\n",
"\n",
"- Sample `NUM_EXAMPLES` examples for the demo"
],
"metadata": {
"id": "DAGa-3kFpQ8u"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DqLbdq9dcXvX"
},
"outputs": [],
"source": [
"%%capture\n",
"dataset = load_dataset(DATASET_NAME, split=\"train\")\n",
"\n",
"df = (\n",
" dataset\n",
" .to_pandas()\n",
" .query(\"category==@CATEGORY\")\n",
" .applymap(lambda x: x.replace('\\n', ' ').replace('\\r', ' ').encode('ascii', 'ignore').decode('ascii'))\n",
" .assign(num_words=lambda df_: df_[\"context\"].str.cat(df_[\"response\"], sep=\" \").str.split().apply(len))\n",
" .query(\"num_words < @MAX_WORDS\")\n",
" .drop(columns=[\"category\", \"num_words\"])\n",
" .sample(NUM_EXAMPLES, random_state=RANDOM_SEED)\n",
" .reset_index(drop=True)\n",
")"
]
},
{
"cell_type": "code",
"source": [
"tabllm.display_dataframe_in_notebook(df)"
],
"metadata": {
"id": "BiIO7__Mwet-"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## 🗝️ Identify key phrases\n",
"\n",
"- Here we use KeyBERT, a BERT-based keyword extraction library, together with `KeyphraseCountVectorizer` to extract the top `3 * NUM_SELECT_PHRASES` key phrases from the `context` of each example.\n",
"\n",
"- We upsample the dataset by a factor of `UPSAMPLE_MULTIPLE`, which will allow us to create multiple examples from the same context.\n",
"\n",
"- For each row, we then randomly sample `NUM_SELECT_PHRASES` of the extracted key phrases.\n",
"\n",
"- We will use these `select_phrases` to boost the diversity in our synthetic instructions."
],
"metadata": {
"id": "osYtwRmdqZVK"
}
},
{
"cell_type": "code",
"source": [
"%%capture\n",
"np.random.seed(RANDOM_SEED)\n",
"\n",
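"def sample_key_phrases(phrases):\n",
" # phrases is a list of (phrase, score) tuples from KeyBERT; cap k so short lists don't raise\n",
" k = min(NUM_SELECT_PHRASES, len(phrases))\n",
" random_idx = np.random.choice(len(phrases), k, replace=False)\n",
" return \", \".join([phrases[i][0] for i in random_idx])\n",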
"\n",
"\n",
"df[\"select_phrases\"] = KeyBERT().extract_keywords(\n",
" docs=df[\"context\"].tolist(),\n",
" vectorizer=KeyphraseCountVectorizer(), top_n=3 * NUM_SELECT_PHRASES\n",
")\n",
"\n",
"df = pd.DataFrame(np.repeat(df.values, UPSAMPLE_MULTIPLE, axis=0), columns=df.columns)\n",
"\n",
"df[\"select_phrases\"] = df[\"select_phrases\"].apply(sample_key_phrases)"
],
"metadata": {
"id": "7fH6j_wFgd-b"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# preview example context + select_phrases\n",
"tabllm.display_dataframe_in_notebook(df[[\"context\", \"select_phrases\"]].head())"
],
"metadata": {
"id": "lf3titL6gf-T"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## 🤖 Prompt Gretel's Tabular LLM to create synthetic instructions and responses"
],
"metadata": {
"id": "2c2MUSKJu6o_"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tCt4Yj-TcXvX"
},
"outputs": [],
"source": [
"prompt = \"\"\"\\\n",
"For each example in the Dataset, please act as a tutor and create detailed,\n",
"synthetic questions and answers of higher quality than the provided example.\n",
"Frame your question around one of the phrases in the 'select_phrases' column.\n",
"Ensure the data teaches concepts step-by-step and focuses on improving reasoning skills.\n",
"Focus on generating questions and answers about under-represented topics and knowledge gaps.\n",
"\n",
"Add two new columns to the Dataset:\n",
"1. 'synthetic_instruction':\n",
" * Introduce the topic from the example briefly in 1-2 sentences\n",
" * Ask a clear question related to the topic that requires logical thinking or common sense reasoning\n",
" * Provide any necessary context to set up the reasoning problem\n",
" * Do not repeat the instruction from the Dataset example\n",
"2. 'synthetic_response':\n",
" * Respond to the synthetically generated instruction thoroughly in a step-by-step manner\n",
" * Provide the complete reasoning needed to arrive at the answer\n",
" * Ensure the explanation is textbook quality with all details needed to learn the concept\n",
" * Answer in 3-5 sentences.\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TSTBhQX6cXvY"
},
"outputs": [],
"source": [
"columns = [\"instruction\", \"context\", \"response\", \"synthetic_instruction\", \"synthetic_response\"]\n",
"\n",
"synthetic = tabllm.edit(prompt=prompt, seed_data=df, top_k=40, temperature=0.8)[columns]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6PBgvz0bcXvZ"
},
"outputs": [],
"source": [
"tabllm.display_dataframe_in_notebook(synthetic)"
]
},
{
"cell_type": "code",
"source": [
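"# Compare the quality of the synthetic instruction/response text against the original Dolly text\n",
"import IPython\n",
"\n",
"from gretel_client.evaluation.text_quality_report import TextQualityReport\n",
"\n",
"# Reference data: original instruction + response\n",
"real = pd.DataFrame()\n",
"real['text'] = synthetic['instruction'] + \" \" + synthetic['response']\n",
"# Data to evaluate: synthetic instruction + response\n",
"synthetic['text'] = synthetic['synthetic_instruction'] + \" \" + synthetic['synthetic_response']\n",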
"\n",
"report = TextQualityReport(data_source=synthetic,\n",
" ref_data=real,\n",
" target='text',\n",
" record_count=len(synthetic))\n",
"report.run()\n",
"\n",
"IPython.display.HTML(report.as_html, metadata=dict(isolated=True))"
],
"metadata": {
"id": "iA172X4xYg0c"
},
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"display_name": "monogretel-dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"colab": {
"provenance": [],
"name": "Gretel Tabular LLM - Textbooks notebook for workshop",
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 0
}
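The preprocessing and key-phrase cells in the notebook above depend on pandas, NumPy, and KeyBERT. The same logic can be illustrated without them. Below is a minimal pure-Python sketch that makes several assumptions:

- Each example is a plain `dict` with `context` and `response` keys.
- Key phrases arrive as `(phrase, score)` pairs, which is the shape the notebook indexes with `phrases[i][0]`.
- The helper names `clean`, `keep_short`, and `upsample` are hypothetical and do not come from the notebook.
- The toy row is invented for illustration.
- Sampling uses `random.sample` in place of `np.random.choice(..., replace=False)`, so the phrases picked under the same seed will differ from the notebook's.

```python
import random

# Values mirror the notebook's demo parameters.
MAX_WORDS = 400
NUM_SELECT_PHRASES = 2
UPSAMPLE_MULTIPLE = 3
RANDOM_SEED = len("GRETEL")


def clean(text: str) -> str:
    """Replace newlines and carriage returns with spaces, then drop non-ASCII characters."""
    return text.replace("\n", " ").replace("\r", " ").encode("ascii", "ignore").decode("ascii")


def keep_short(rows, max_words=MAX_WORDS):
    """Keep rows whose combined context + response has fewer than max_words words."""
    return [r for r in rows if len((r["context"] + " " + r["response"]).split()) < max_words]


def upsample(rows, multiple=UPSAMPLE_MULTIPLE):
    """Repeat each row `multiple` times consecutively, like np.repeat(..., axis=0).

    Each copy is a shallow dict copy, so per-copy fields can be set independently.
    """
    return [dict(r) for r in rows for _ in range(multiple)]


def sample_key_phrases(phrases, k, rng):
    """Join up to k distinct phrases drawn without replacement from (phrase, score) pairs."""
    picked = rng.sample(phrases, min(k, len(phrases)))
    return ", ".join(phrase for phrase, _score in picked)


# Toy input invented for illustration; in the notebook, the phrases come from KeyBERT.
rows = [
    {
        "context": "Tides are caused\nmainly by the Moon's gravity.",
        "response": "The Moon.",
        "phrases": [("tides", 0.71), ("moon's gravity", 0.65), ("gravity", 0.52)],
    }
]

rng = random.Random(RANDOM_SEED)
prepared = upsample(keep_short([{**r, "context": clean(r["context"])} for r in rows]))
for row in prepared:
    # Each upsampled copy gets its own random phrase pair, which varies the prompts.
    row["select_phrases"] = sample_key_phrases(row["phrases"], NUM_SELECT_PHRASES, rng)
    print(row["select_phrases"])
```

Capping the sample size with `min(k, len(phrases))` means a context that yields fewer than `NUM_SELECT_PHRASES` phrases produces a shorter (or empty) string rather than raising an error.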