@suneel-pi
Created September 16, 2025 06:15
pi-scorer-on-unsloth.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/suneel-pi/811ae833c8a3597dee7808a23176d1c5/pi-scorer-on-unsloth.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pi-masthead"
},
"source": [
"<a href=\"https://withpi.ai\"><img src=\"https://play.withpi.ai/logo/logoFullBlack.svg\" width=\"240\"></a>\n",
"\n",
"<a href=\"https://code.withpi.ai\"><font size=\"4\">Documentation</font></a>\n",
"\n",
"<a href=\"https://build.withpi.ai\"><font size=\"4\">Copilot</font></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KiS-7CbBYS0D"
},
"source": [
"Are you constantly relying on LLM-as-a-judge to evaluate your model’s performance?\n",
"\n",
"Have you ever wanted to assess your model at every training checkpoint but hesitated because LLM-as-a-judge is too slow and expensive?\n",
"\n",
"**Now you can — with [Pi-Scorer](https://build.withpi.ai).**\n",
"\n",
"[Pi-Scorer](https://build.withpi.ai) offers an alternative to LLM-as-a-judge with several advantages:\n",
"\n",
"* Significantly faster\n",
"\n",
"* Highly consistent — always returns the same score for the same inputs\n",
"\n",
"* Eliminates the need for prompt tuning or adjustments\n",
"\n",
"In this Colab, we integrate [Pi-Scorer](https://build.withpi.ai) as the reward function within the [Unsloth](https://unsloth.ai/) GRPO training loop, based on the [Unsloth Qwen2.5_(3B)-GRPO.ipynb colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen2.5_(3B)-GRPO.ipynb) notebook."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NVoflExGYvu6"
},
"source": [
"### Installation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cuA79HQOBfWX"
},
"outputs": [],
"source": [
"from google.colab import userdata\n",
"import os\n",
"\n",
"# Get PI API key: https://build.withpi.ai/account/keys\n",
"os.environ[\"WITHPI_API_KEY\"] = userdata.get('WITHPI_API_KEY')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pfP5p84HYvu6"
},
"outputs": [],
"source": [
"%%capture\n",
"import os\n",
"if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
" !pip install unsloth vllm\n",
"else:\n",
" # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n",
" !pip install --no-deps unsloth vllm==0.8.5.post1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aC-VIBYpYvu6"
},
"outputs": [],
"source": [
"#@title Colab Extra Install { display-mode: \"form\" }\n",
"%%capture\n",
"import os\n",
"if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
" !pip install unsloth vllm\n",
"else:\n",
" !pip install --no-deps unsloth vllm==0.8.5.post1\n",
" # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n",
" # Skip restarting message in Colab\n",
" import sys, re, requests; modules = list(sys.modules.keys())\n",
" for x in modules: sys.modules.pop(x) if \"PIL\" in x or \"google\" in x else None\n",
" !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft \"trl==0.15.2\" triton cut_cross_entropy unsloth_zoo\n",
" !pip install sentencepiece protobuf \"datasets>=3.4.1\" huggingface_hub hf_transfer\n",
" !pip install transformers==4.51.3\n",
"\n",
" # vLLM requirements - vLLM breaks Colab due to reinstalling numpy\n",
" f = requests.get(\"https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt\").content\n",
" with open(\"vllm_requirements.txt\", \"wb\") as file:\n",
" file.write(re.sub(rb\"(transformers|numpy|xformers)[^\\n]{1,}\\n\", b\"\", f))\n",
" !pip install -r vllm_requirements.txt"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0kisw2lPYvu6"
},
"source": [
"### Load Unsloth Model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Joje4qPsyxM9"
},
"source": [
"Load up `Qwen 2.5 3B Instruct`, and set parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DkIvEkIIkEyB"
},
"outputs": [],
"source": [
"from unsloth import FastLanguageModel, is_bfloat16_supported\n",
"import torch\n",
"max_seq_length = 1024 # Can increase for longer reasoning traces\n",
"lora_rank = 64 # Larger rank = smarter, but slower\n",
"\n",
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
" model_name = \"Qwen/Qwen2.5-3B-Instruct\",\n",
" max_seq_length = max_seq_length,\n",
" load_in_4bit = True, # False for LoRA 16bit\n",
" fast_inference = True, # Enable vLLM fast inference\n",
" max_lora_rank = lora_rank,\n",
" gpu_memory_utilization = 0.5, # Reduce if out of memory\n",
")\n",
"\n",
"model = FastLanguageModel.get_peft_model(\n",
" model,\n",
" r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
" target_modules = [\n",
" \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
" \"gate_proj\", \"up_proj\", \"down_proj\",\n",
" ], # Remove QKVO if out of memory\n",
" lora_alpha = lora_rank,\n",
" use_gradient_checkpointing = \"unsloth\", # Enable long context finetuning\n",
" random_state = 3407,\n",
")"
]
},
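{
"cell_type": "markdown",
"metadata": {
"id": "pi-lora-check-md"
},
"source": [
"Optional: as a quick check of the LoRA setup, print how many parameters the adapter will actually train. This is a minimal sketch that assumes the model returned by `FastLanguageModel.get_peft_model` exposes the standard PEFT `print_trainable_parameters()` method."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pi-lora-check-code"
},
"outputs": [],
"source": [
"# Optional: report trainable vs. total parameters for the LoRA adapter configured above\n",
"# (assumes the standard PEFT API is available on the Unsloth-wrapped model).\n",
"model.print_trainable_parameters()"
]
},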
{
"cell_type": "markdown",
"metadata": {
"id": "0Y56ln_izS9E"
},
"source": [
"### Data Preparation and PI Reward Functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cXk993X6C2ZZ"
},
"outputs": [],
"source": [
"from datasets import load_dataset, Dataset\n",
"import requests\n",
"\n",
"# Load and prep dataset\n",
"SYSTEM_PROMPT = \"\"\"\n",
"Given a menu item, predict its macronutrient values (protein,\n",
"carbs, fats).\n",
"\"\"\"\n",
"\n",
"# We added your Hugging Face data below\n",
"dataset = load_dataset(\"jdnvn/menu-items-allmenus\", split=\"train\")\n",
"dataset = dataset.rename_column(\"description\", \"input\")\n",
"dataset = dataset.select(range(500))\n",
"dataset = dataset.map(\n",
" lambda x: {\n",
" \"prompt\": [\n",
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
" {\"role\": \"user\", \"content\": str(x[\"input\"])},\n",
" ]\n",
" }\n",
")\n",
"print(dataset[0])\n",
"\n",
"\n",
"# Pi constants\n",
"PI_API_URL = \"https://api.withpi.ai/v1/scoring_system/score\"\n",
"HEADERS = {\n",
" \"Content-Type\": \"application/json\",\n",
" \"x-api-key\": os.environ.get(\"WITHPI_API_KEY\"),\n",
"}\n",
"\n",
"# Pi util functions\n",
"def get_pi_score(input: str, output: str, question: str) -> float:\n",
" payload = {\n",
" \"llm_input\": input,\n",
" \"llm_output\": output,\n",
" \"scoring_spec\": [{\"question\": question}]\n",
" }\n",
" # Can add retry if needed.\n",
" response = requests.post(PI_API_URL, headers=HEADERS, json=payload)\n",
" return response.json()[\"total_score\"]\n",
"\n",
"def get_scores(prompts, completions, question: str) -> list[float]:\n",
" inputs = [prompt[-1][\"content\"] for prompt in prompts]\n",
" outputs = [completion[0][\"content\"] for completion in completions]\n",
" return [get_pi_score(input, output, question) for input, output in zip(inputs, outputs)]\n",
"\n",
"# Reward functions\n",
"def pi_accuracy_of_macros(prompts, completions, **kwargs) -> list[float]:\n",
" return get_scores(prompts, completions, \"Does the response provide accurate macronutrient values (e.g., protein, carbs, fats) for the given menu item?\")\n",
"\n",
"def pi_relevance_to_item(prompts, completions, **kwargs) -> list[float]:\n",
" return get_scores(prompts, completions, \"Are the predicted macronutrient values relevant to the specific menu item provided in the input?\")\n",
"\n",
"def pi_clarity_of_output(prompts, completions, **kwargs) -> list[float]:\n",
" return get_scores(prompts, completions, \"Is the presentation of the macronutrient values clear and easy to understand?\")\n",
"\n",
"def pi_unit_consistency(prompts, completions, **kwargs) -> list[float]:\n",
" return get_scores(prompts, completions, \"Are the macronutrient values expressed in consistent and standard units (e.g., grams)?\")\n",
"\n",
"def pi_completeness_of_macros(prompts, completions, **kwargs) -> list[float]:\n",
" return get_scores(prompts, completions, \"Does the response include all three macronutrient categories: protein, carbs, and fats?\")\n",
"\n",
"def pi_avoids_extraneous_data(prompts, completions, **kwargs) -> list[float]:\n",
" return get_scores(prompts, completions, \"Does the response avoid including irrelevant or extraneous information not related to macronutrient values?\")\n",
"\n",
"def pi_numerical_precision(prompts, completions, **kwargs) -> list[float]:\n",
" return get_scores(prompts, completions, \"Are the macronutrient values provided with appropriate numerical precision (e.g., up to one decimal place if necessary)?\")\n",
"\n",
"def pi_input_alignment(prompts, completions, **kwargs) -> list[float]:\n",
" return get_scores(prompts, completions, \"Does the response align with the specific details of the input menu item (e.g., portion size, preparation method)?\")\n",
"\n",
"def pi_absence_of_errors(prompts, completions, **kwargs) -> list[float]:\n",
" return get_scores(prompts, completions, \"Is the response free from numerical or typographical errors in the macronutrient values?\")\n",
"\n",
"def pi_output_structure(prompts, completions, **kwargs) -> list[float]:\n",
" return get_scores(prompts, completions, \"Is the output structured in a way that clearly separates the macronutrient categories?\")"
]
},
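{
"cell_type": "markdown",
"metadata": {
"id": "pi-sanity-check-md"
},
"source": [
"Optional: before wiring these reward functions into GRPO, you can call `get_pi_score` directly on a hand-written example to confirm the API key works and to see the kind of score Pi returns. The menu item and response below are toy values chosen for illustration; the same input/output pair should always come back with the same score."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pi-sanity-check-code"
},
"outputs": [],
"source": [
"# Optional sanity check: score one hand-written example with the helper defined above.\n",
"# The menu item and response are illustrative placeholders, not rows from the dataset.\n",
"example_input = \"Grilled Chicken Salad with Ranch Dressing\"\n",
"example_output = \"Protein: 35 g, Carbs: 12 g, Fats: 28 g\"\n",
"score = get_pi_score(\n",
"    example_input,\n",
"    example_output,\n",
"    \"Does the response include all three macronutrient categories: protein, carbs, and fats?\",\n",
")\n",
"print(\"Pi score:\", score)  # The same inputs should always return the same score"
]
},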
{
"cell_type": "markdown",
"metadata": {
"id": "bTnL_tJnzh2L"
},
"source": [
"<a name=\"Train\"></a>\n",
"### Train the model\n",
"\n",
"Now set up GRPO Trainer and all configurations!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ptqkXK2D4d6p"
},
"outputs": [],
"source": [
"from trl import GRPOConfig, GRPOTrainer\n",
"training_args = GRPOConfig(\n",
" use_vllm = True, # use vLLM for fast inference!\n",
" learning_rate = 5e-6,\n",
" adam_beta1 = 0.9,\n",
" adam_beta2 = 0.99,\n",
" weight_decay = 0.1,\n",
" warmup_ratio = 0.1,\n",
" lr_scheduler_type = \"cosine\",\n",
" optim = \"adamw_8bit\",\n",
" logging_steps = 1,\n",
" bf16 = is_bfloat16_supported(),\n",
" fp16 = not is_bfloat16_supported(),\n",
" per_device_train_batch_size = 1,\n",
" gradient_accumulation_steps = 1, # Increase to 4 for smoother training\n",
" num_generations = 8, # Decrease if out of memory\n",
" max_prompt_length = 1024,\n",
" max_completion_length = 200,\n",
" # num_train_epochs = 1, # Set to 1 for a full training run\n",
" max_steps = 50,\n",
" save_steps = 50,\n",
" max_grad_norm = 0.1,\n",
" report_to = \"none\", # Can use Weights & Biases\n",
" output_dir = \"outputs\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vzOuSVCL_GA9"
},
"outputs": [],
"source": [
"trainer = GRPOTrainer(\n",
" model = model,\n",
" processing_class = tokenizer,\n",
" reward_funcs = [\n",
" pi_accuracy_of_macros,\n",
" pi_relevance_to_item,\n",
" pi_clarity_of_output,\n",
" pi_unit_consistency,\n",
" pi_completeness_of_macros,\n",
" pi_avoids_extraneous_data,\n",
" pi_numerical_precision,\n",
" pi_input_alignment,\n",
" pi_absence_of_errors,\n",
" pi_output_structure,\n",
" ],\n",
" args = training_args,\n",
" train_dataset = dataset,\n",
")\n",
"trainer.train()"
]
},
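{
"cell_type": "markdown",
"metadata": {
"id": "pi-reward-logs-md"
},
"source": [
"With `logging_steps = 1`, the trainer records reward statistics at every step, so you can watch the Pi scores move during training. The sketch below pulls the last few entries out of `trainer.state.log_history`; the exact key names for the per-function rewards depend on the TRL version, so it simply filters for keys containing `reward`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pi-reward-logs-code"
},
"outputs": [],
"source": [
"# Inspect the reward values TRL logged during training.\n",
"# Key names vary across TRL versions, so filter generically for anything reward-related.\n",
"for entry in trainer.state.log_history[-3:]:\n",
"    reward_keys = {k: v for k, v in entry.items() if \"reward\" in k}\n",
"    if reward_keys:\n",
"        print(f\"step {entry.get('step')}: {reward_keys}\")"
]
},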
{
"cell_type": "markdown",
"metadata": {
"id": "yUbluAAhD0Lg"
},
"source": [
"<a name=\"Inference\"></a>\n",
"### Inference\n",
"Now let's try the model we just trained! First, let's first try the model without any GRPO trained:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IqzsdZzeDM_m"
},
"outputs": [],
"source": [
"from vllm import SamplingParams\n",
"inputs = [\n",
" \"\"\"Grilled Chicken Salad with Ranch Dressing\"\"\",\n",
" \"\"\"Double Cheeseburger with Fries\"\"\",\n",
" \"\"\"Vegan Buddha Bowl with Quinoa and Avocado\"\"\",\n",
" \"\"\"Pepperoni Pizza Slice\"\"\",\n",
" \"\"\"Chocolate Milkshake\"\"\",\n",
"]\n",
"for input in inputs:\n",
" text = tokenizer.apply_chat_template(\n",
" [\n",
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
" {\"role\": \"user\", \"content\": input},\n",
" ],\n",
" tokenize=False,\n",
" add_generation_prompt=True,\n",
" )\n",
"\n",
"\n",
" sampling_params = SamplingParams(\n",
" temperature=0.4,\n",
" top_p=0.95,\n",
" max_tokens=1024,\n",
" )\n",
" output = (\n",
" model.fast_generate(\n",
" [text],\n",
" sampling_params=sampling_params,\n",
" lora_request=None,\n",
" )[0]\n",
" .outputs[0]\n",
" .text\n",
" )\n",
" print(\"INPUT:\", input)\n",
" print(\"OUTPUT\", \"\\n\", output)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ICDjepxv4D0M"
},
"source": [
"Now we load the LoRA and test:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jK7sz95932CO"
},
"outputs": [],
"source": [
"model.save_lora(\"grpo_saved_lora\")\n",
"for input in inputs:\n",
" text = tokenizer.apply_chat_template(\n",
" [\n",
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
" {\"role\": \"user\", \"content\": input},\n",
" ],\n",
" tokenize=False,\n",
" add_generation_prompt=True,\n",
" )\n",
"\n",
"\n",
" sampling_params = SamplingParams(\n",
" temperature=0.4,\n",
" top_p=0.95,\n",
" max_tokens=1024,\n",
" )\n",
" output = (\n",
" model.fast_generate(\n",
" [text],\n",
" sampling_params=sampling_params,\n",
" lora_request=None,\n",
" )[0]\n",
" .outputs[0]\n",
" .text\n",
" )\n",
" print(\"INPUT:\", input)\n",
" print(\"OUTPUT\", \"\\n\", output)"
]
}
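,
{
"cell_type": "markdown",
"metadata": {
"id": "pi-compare-md"
},
"source": [
"Because Pi scores are fast and deterministic, you can also quantify the before/after comparison instead of eyeballing it. The sketch below reuses `get_pi_score` and one of the scoring questions from the reward functions to score a base-model completion against a GRPO-LoRA completion for a single prompt; extend it over all of `inputs` (or a held-out set) as needed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pi-compare-code"
},
"outputs": [],
"source": [
"# Score base vs. GRPO-trained outputs with Pi on one example prompt,\n",
"# reusing get_pi_score and a scoring question defined earlier.\n",
"question = \"Does the response include all three macronutrient categories: protein, carbs, and fats?\"\n",
"prompt_text = tokenizer.apply_chat_template(\n",
"    [\n",
"        {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
"        {\"role\": \"user\", \"content\": inputs[0]},\n",
"    ],\n",
"    tokenize=False,\n",
"    add_generation_prompt=True,\n",
")\n",
"sampling_params = SamplingParams(temperature=0.4, top_p=0.95, max_tokens=1024)\n",
"\n",
"base_out = model.fast_generate([prompt_text], sampling_params=sampling_params, lora_request=None)[0].outputs[0].text\n",
"lora_out = model.fast_generate([prompt_text], sampling_params=sampling_params, lora_request=model.load_lora(\"grpo_saved_lora\"))[0].outputs[0].text\n",
"\n",
"print(\"Base model Pi score:\", get_pi_score(inputs[0], base_out, question))\n",
"print(\"GRPO LoRA Pi score: \", get_pi_score(inputs[0], lora_out, question))"
]
}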
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}