pi-scorer-on-unsloth.ipynb
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/suneel-pi/811ae833c8a3597dee7808a23176d1c5/pi-scorer-on-unsloth.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "pi-masthead" | |
| }, | |
| "source": [ | |
| "<a href=\"https://withpi.ai\"><img src=\"https://play.withpi.ai/logo/logoFullBlack.svg\" width=\"240\"></a>\n", | |
| "\n", | |
| "<a href=\"https://code.withpi.ai\"><font size=\"4\">Documentation</font></a>\n", | |
| "\n", | |
| "<a href=\"https://build.withpi.ai\"><font size=\"4\">Copilot</font></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "KiS-7CbBYS0D" | |
| }, | |
| "source": [ | |
| "Are you constantly relying on LLM-as-a-judge to evaluate your model’s performance?\n", | |
| "\n", | |
| "Have you ever wanted to assess your model at every training checkpoint but hesitated because LLM-as-a-judge is too slow and expensive?\n", | |
| "\n", | |
| "**Now you can — with [Pi-Scorer](https://build.withpi.ai).**\n", | |
| "\n", | |
| "[Pi-Scorer](https://build.withpi.ai) offers an alternative to LLM-as-a-judge with several advantages:\n", | |
| "\n", | |
| "* Significantly faster\n", | |
| "\n", | |
| "* Highly consistent — always returns the same score for the same inputs\n", | |
| "\n", | |
| "* Eliminates the need for prompt tuning or adjustments\n", | |
| "\n", | |
| "In this Colab, we integrate [Pi-Scorer](https://build.withpi.ai) as the reward function within the [Unsloth](https://unsloth.ai/) GRPO training loop, based on the [Unsloth Qwen2.5_(3B)-GRPO.ipynb](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen2.5_(3B)-GRPO.ipynb) Colab notebook." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "NVoflExGYvu6" | |
| }, | |
| "source": [ | |
| "### Installation" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "cuA79HQOBfWX" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from google.colab import userdata\n", | |
| "import os\n", | |
| "\n", | |
| "# Get PI API key: https://build.withpi.ai/account/keys\n", | |
| "os.environ[\"WITHPI_API_KEY\"] = userdata.get('WITHPI_API_KEY')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "pfP5p84HYvu6" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "%%capture\n", | |
| "import os\n", | |
| "if \"COLAB_\" not in \"\".join(os.environ.keys()):\n", | |
| " !pip install unsloth vllm\n", | |
| "else:\n", | |
| " # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n", | |
| " !pip install --no-deps unsloth vllm==0.8.5.post1" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "aC-VIBYpYvu6" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "#@title Colab Extra Install { display-mode: \"form\" }\n", | |
| "%%capture\n", | |
| "import os\n", | |
| "if \"COLAB_\" not in \"\".join(os.environ.keys()):\n", | |
| " !pip install unsloth vllm\n", | |
| "else:\n", | |
| " !pip install --no-deps unsloth vllm==0.8.5.post1\n", | |
| " # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n", | |
| " # Skip restarting message in Colab\n", | |
| " import sys, re, requests; modules = list(sys.modules.keys())\n", | |
| " for x in modules: sys.modules.pop(x) if \"PIL\" in x or \"google\" in x else None\n", | |
| " !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft \"trl==0.15.2\" triton cut_cross_entropy unsloth_zoo\n", | |
| " !pip install sentencepiece protobuf \"datasets>=3.4.1\" huggingface_hub hf_transfer\n", | |
| " !pip install transformers==4.51.3\n", | |
| "\n", | |
| " # vLLM requirements - vLLM breaks Colab due to reinstalling numpy\n", | |
| " f = requests.get(\"https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt\").content\n", | |
| " with open(\"vllm_requirements.txt\", \"wb\") as file:\n", | |
| " file.write(re.sub(rb\"(transformers|numpy|xformers)[^\\n]{1,}\\n\", b\"\", f))\n", | |
| " !pip install -r vllm_requirements.txt" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "0kisw2lPYvu6" | |
| }, | |
| "source": [ | |
| "### Load Unsloth Model" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "Joje4qPsyxM9" | |
| }, | |
| "source": [ | |
| "Load `Qwen 2.5 3B Instruct` and set the parameters." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "DkIvEkIIkEyB" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from unsloth import FastLanguageModel, is_bfloat16_supported\n", | |
| "import torch\n", | |
| "max_seq_length = 1024 # Can increase for longer reasoning traces\n", | |
| "lora_rank = 64 # Larger rank = smarter, but slower\n", | |
| "\n", | |
| "model, tokenizer = FastLanguageModel.from_pretrained(\n", | |
| " model_name = \"Qwen/Qwen2.5-3B-Instruct\",\n", | |
| " max_seq_length = max_seq_length,\n", | |
| " load_in_4bit = True, # False for LoRA 16bit\n", | |
| " fast_inference = True, # Enable vLLM fast inference\n", | |
| " max_lora_rank = lora_rank,\n", | |
| " gpu_memory_utilization = 0.5, # Reduce if out of memory\n", | |
| ")\n", | |
| "\n", | |
| "model = FastLanguageModel.get_peft_model(\n", | |
| " model,\n", | |
| " r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n", | |
| " target_modules = [\n", | |
| " \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n", | |
| " \"gate_proj\", \"up_proj\", \"down_proj\",\n", | |
| " ], # Remove QKVO if out of memory\n", | |
| " lora_alpha = lora_rank,\n", | |
| " use_gradient_checkpointing = \"unsloth\", # Enable long context finetuning\n", | |
| " random_state = 3407,\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "0Y56ln_izS9E" | |
| }, | |
| "source": [ | |
| "### Data Preparation and PI Reward Functions" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "cXk993X6C2ZZ" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from datasets import load_dataset, Dataset\n", | |
| "import os  # Needed below to read WITHPI_API_KEY\n", | |
| "import requests\n", | |
| "\n", | |
| "# Load and prep dataset\n", | |
| "SYSTEM_PROMPT = \"\"\"\n", | |
| "Given a menu item, predict its macronutrient values (protein,\n", | |
| "carbs, fats).\n", | |
| "\"\"\"\n", | |
| "\n", | |
| "# Load the menu-item dataset from Hugging Face and keep the first 500 rows\n", | |
| "dataset = load_dataset(\"jdnvn/menu-items-allmenus\", split=\"train\")\n", | |
| "dataset = dataset.rename_column(\"description\", \"input\")\n", | |
| "dataset = dataset.select(range(500))\n", | |
| "dataset = dataset.map(\n", | |
| " lambda x: {\n", | |
| " \"prompt\": [\n", | |
| " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", | |
| " {\"role\": \"user\", \"content\": str(x[\"input\"])},\n", | |
| " ]\n", | |
| " }\n", | |
| ")\n", | |
| "print(dataset[0])\n", | |
| "\n", | |
| "\n", | |
| "# Pi constants\n", | |
| "PI_API_URL = \"https://api.withpi.ai/v1/scoring_system/score\"\n", | |
| "HEADERS = {\n", | |
| " \"Content-Type\": \"application/json\",\n", | |
| " \"x-api-key\": os.environ.get(\"WITHPI_API_KEY\"),\n", | |
| "}\n", | |
| "\n", | |
| "# Pi util functions\n", | |
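| "def get_pi_score(input: str, output: str, question: str) -> float:\n", | |
| "    \"\"\"Score one (input, output) pair against a single Pi scoring question.\"\"\"\n", | |
| "    payload = {\n", | |
| "        \"llm_input\": input,\n", | |
| "        \"llm_output\": output,\n", | |
| "        \"scoring_spec\": [{\"question\": question}]\n", | |
| "    }\n", | |
| "    # Can add retry if needed. The 30 s timeout is an arbitrary choice.\n", | |
| "    response = requests.post(PI_API_URL, headers=HEADERS, json=payload, timeout=30)\n", | |
| "    response.raise_for_status()  # Surface HTTP errors instead of a confusing KeyError\n", | |
| "    return response.json()[\"total_score\"]\n", | |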
| "\n", | |
| "def get_scores(prompts, completions, question: str) -> list[float]:\n", | |
| " inputs = [prompt[-1][\"content\"] for prompt in prompts]\n", | |
| " outputs = [completion[0][\"content\"] for completion in completions]\n", | |
| " return [get_pi_score(input, output, question) for input, output in zip(inputs, outputs)]\n", | |
| "\n", | |
| "# Reward functions\n", | |
| "def pi_accuracy_of_macros(prompts, completions, **kwargs) -> list[float]:\n", | |
| " return get_scores(prompts, completions, \"Does the response provide accurate macronutrient values (e.g., protein, carbs, fats) for the given menu item?\")\n", | |
| "\n", | |
| "def pi_relevance_to_item(prompts, completions, **kwargs) -> list[float]:\n", | |
| " return get_scores(prompts, completions, \"Are the predicted macronutrient values relevant to the specific menu item provided in the input?\")\n", | |
| "\n", | |
| "def pi_clarity_of_output(prompts, completions, **kwargs) -> list[float]:\n", | |
| " return get_scores(prompts, completions, \"Is the presentation of the macronutrient values clear and easy to understand?\")\n", | |
| "\n", | |
| "def pi_unit_consistency(prompts, completions, **kwargs) -> list[float]:\n", | |
| " return get_scores(prompts, completions, \"Are the macronutrient values expressed in consistent and standard units (e.g., grams)?\")\n", | |
| "\n", | |
| "def pi_completeness_of_macros(prompts, completions, **kwargs) -> list[float]:\n", | |
| " return get_scores(prompts, completions, \"Does the response include all three macronutrient categories: protein, carbs, and fats?\")\n", | |
| "\n", | |
| "def pi_avoids_extraneous_data(prompts, completions, **kwargs) -> list[float]:\n", | |
| " return get_scores(prompts, completions, \"Does the response avoid including irrelevant or extraneous information not related to macronutrient values?\")\n", | |
| "\n", | |
| "def pi_numerical_precision(prompts, completions, **kwargs) -> list[float]:\n", | |
| " return get_scores(prompts, completions, \"Are the macronutrient values provided with appropriate numerical precision (e.g., up to one decimal place if necessary)?\")\n", | |
| "\n", | |
| "def pi_input_alignment(prompts, completions, **kwargs) -> list[float]:\n", | |
| " return get_scores(prompts, completions, \"Does the response align with the specific details of the input menu item (e.g., portion size, preparation method)?\")\n", | |
| "\n", | |
| "def pi_absence_of_errors(prompts, completions, **kwargs) -> list[float]:\n", | |
| " return get_scores(prompts, completions, \"Is the response free from numerical or typographical errors in the macronutrient values?\")\n", | |
| "\n", | |
| "def pi_output_structure(prompts, completions, **kwargs) -> list[float]:\n", | |
| " return get_scores(prompts, completions, \"Is the output structured in a way that clearly separates the macronutrient categories?\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "bTnL_tJnzh2L" | |
| }, | |
| "source": [ | |
| "<a name=\"Train\"></a>\n", | |
| "### Train the model\n", | |
| "\n", | |
| "Now set up the GRPO trainer and its configuration:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "ptqkXK2D4d6p" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from trl import GRPOConfig, GRPOTrainer\n", | |
| "training_args = GRPOConfig(\n", | |
| " use_vllm = True, # use vLLM for fast inference!\n", | |
| " learning_rate = 5e-6,\n", | |
| " adam_beta1 = 0.9,\n", | |
| " adam_beta2 = 0.99,\n", | |
| " weight_decay = 0.1,\n", | |
| " warmup_ratio = 0.1,\n", | |
| " lr_scheduler_type = \"cosine\",\n", | |
| " optim = \"adamw_8bit\",\n", | |
| " logging_steps = 1,\n", | |
| " bf16 = is_bfloat16_supported(),\n", | |
| " fp16 = not is_bfloat16_supported(),\n", | |
| " per_device_train_batch_size = 1,\n", | |
| " gradient_accumulation_steps = 1, # Increase to 4 for smoother training\n", | |
| " num_generations = 8, # Decrease if out of memory\n", | |
| " max_prompt_length = 824, # Prompt + completion (200) should fit within max_seq_length (1024)\n", | |
| " max_completion_length = 200,\n", | |
| " # num_train_epochs = 1, # Set to 1 for a full training run\n", | |
| " max_steps = 50,\n", | |
| " save_steps = 50,\n", | |
| " max_grad_norm = 0.1,\n", | |
| " report_to = \"none\", # Can use Weights & Biases\n", | |
| " output_dir = \"outputs\",\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "vzOuSVCL_GA9" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "trainer = GRPOTrainer(\n", | |
| " model = model,\n", | |
| " processing_class = tokenizer,\n", | |
| " reward_funcs = [\n", | |
| " pi_accuracy_of_macros,\n", | |
| " pi_relevance_to_item,\n", | |
| " pi_clarity_of_output,\n", | |
| " pi_unit_consistency,\n", | |
| " pi_completeness_of_macros,\n", | |
| " pi_avoids_extraneous_data,\n", | |
| " pi_numerical_precision,\n", | |
| " pi_input_alignment,\n", | |
| " pi_absence_of_errors,\n", | |
| " pi_output_structure,\n", | |
| " ],\n", | |
| " args = training_args,\n", | |
| " train_dataset = dataset,\n", | |
| ")\n", | |
| "trainer.train()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "yUbluAAhD0Lg" | |
| }, | |
| "source": [ | |
| "<a name=\"Inference\"></a>\n", | |
| "### Inference\n", | |
| "Now let's try the model we just trained! First, let's try the model without any GRPO training:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "IqzsdZzeDM_m" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from vllm import SamplingParams\n", | |
| "inputs = [\n", | |
| " \"\"\"Grilled Chicken Salad with Ranch Dressing\"\"\",\n", | |
| " \"\"\"Double Cheeseburger with Fries\"\"\",\n", | |
| " \"\"\"Vegan Buddha Bowl with Quinoa and Avocado\"\"\",\n", | |
| " \"\"\"Pepperoni Pizza Slice\"\"\",\n", | |
| " \"\"\"Chocolate Milkshake\"\"\",\n", | |
| "]\n", | |
| "for input in inputs:\n", | |
| " text = tokenizer.apply_chat_template(\n", | |
| " [\n", | |
| " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", | |
| " {\"role\": \"user\", \"content\": input},\n", | |
| " ],\n", | |
| " tokenize=False,\n", | |
| " add_generation_prompt=True,\n", | |
| " )\n", | |
| "\n", | |
| "\n", | |
| " sampling_params = SamplingParams(\n", | |
| " temperature=0.4,\n", | |
| " top_p=0.95,\n", | |
| " max_tokens=1024,\n", | |
| " )\n", | |
| " output = (\n", | |
| " model.fast_generate(\n", | |
| " [text],\n", | |
| " sampling_params=sampling_params,\n", | |
| " lora_request=None,\n", | |
| " )[0]\n", | |
| " .outputs[0]\n", | |
| " .text\n", | |
| " )\n", | |
| " print(\"INPUT:\", input)\n", | |
| " print(\"OUTPUT\", \"\\n\", output)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "ICDjepxv4D0M" | |
| }, | |
| "source": [ | |
| "Now we load the LoRA and test:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "jK7sz95932CO" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "model.save_lora(\"grpo_saved_lora\")\n", | |
| "for input in inputs:\n", | |
| " text = tokenizer.apply_chat_template(\n", | |
| " [\n", | |
| " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", | |
| " {\"role\": \"user\", \"content\": input},\n", | |
| " ],\n", | |
| " tokenize=False,\n", | |
| " add_generation_prompt=True,\n", | |
| " )\n", | |
| "\n", | |
| "\n", | |
| " sampling_params = SamplingParams(\n", | |
| " temperature=0.4,\n", | |
| " top_p=0.95,\n", | |
| " max_tokens=1024,\n", | |
| " )\n", | |
| " output = (\n", | |
| " model.fast_generate(\n", | |
| " [text],\n", | |
| " sampling_params=sampling_params,\n", | |
| " lora_request=model.load_lora(\"grpo_saved_lora\"),  # Apply the GRPO-trained LoRA\n", | |
| " )[0]\n", | |
| " .outputs[0]\n", | |
| " .text\n", | |
| " )\n", | |
| " print(\"INPUT:\", input)\n", | |
| " print(\"OUTPUT\", \"\\n\", output)" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "accelerator": "GPU", | |
| "colab": { | |
| "gpuType": "T4", | |
| "provenance": [], | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "display_name": "Python 3 (ipykernel)", | |
| "language": "python", | |
| "name": "python3" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 0 | |
| } |
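
The `get_pi_score` helper in the notebook carries the comment "Can add retry if needed." One way to do that is sketched below, assuming transient failures are worth retrying with exponential backoff. The name `call_with_retry` and its parameters (`max_attempts`, `base_delay`, `sleep`) are hypothetical and are not part of the Pi or Unsloth APIs.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def call_with_retry(
    fn: Callable[[], T],
    max_attempts: int = 3,          # hypothetical default, tune as needed
    base_delay: float = 1.0,        # hypothetical default, in seconds
    sleep: Callable[[float], None] = time.sleep,  # injectable for testing
) -> T:
    """Call fn(); on an exception, wait and retry, up to max_attempts calls."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(1, max_attempts + 1):
        try:
            return fn()
        except Exception:
            if attempt == max_attempts:
                raise  # Out of attempts: surface the last error
            # Exponential backoff: base_delay, 2 * base_delay, 4 * base_delay, ...
            sleep(base_delay * 2 ** (attempt - 1))
    raise AssertionError("unreachable")
```

Inside `get_scores`, each call could then be wrapped as `call_with_retry(lambda: get_pi_score(input, output, question))`. Catching every `Exception` is deliberately broad for a sketch; a real training run may want to retry only on network or HTTP errors.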