@willccbb
Last active February 19, 2025 03:30
GRPO Llama-1B
# train_grpo.py
#
# See https://github.com/willccbb/verifiers for ongoing developments
#
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip().replace(",", "").replace("$", "")

# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            #{'role': 'user', 'content': 'What is the largest single-digit prime number?'},
            #{'role': 'assistant', 'content': XML_COT_FORMAT.format(
            #    reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
            #    answer="7"
            #)},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

#model_name = "meta-llama/Llama-3.2-1B-Instruct"
model_name = "Qwen/Qwen2.5-1.5B-Instruct"

if "Llama" in model_name:
    output_dir = "outputs/Llama-1B-GRPO"
    run_name = "Llama-1B-GRPO-gsm8k"
else:
    output_dir="outputs/Qwen-1.5B-GRPO"
    run_name="Qwen-1.5B-GRPO-gsm8k"

training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=16,
    max_prompt_length=256,
    max_completion_length=786,
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    report_to="wandb",
    log_on_each_node=False,
)
peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=None
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# use peft at your own risk; not working for me with multi-GPU training
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func],
    args=training_args,
    train_dataset=dataset,
    #peft_config=peft_config
)
trainer.train()
@s-smits

s-smits commented Feb 3, 2025

Hey everyone,

I've been seeing the popularity of this script and tried to further improve it with an idea I wanted to share and get some feedback on. It seems like fine-tuning the reward values can be a bit of a manual process, and I was wondering if we could automate it using Optuna.

import re
import torch
import optuna
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

# --- Load and prep dataset ---

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            #{'role': 'user', 'content': 'What is the largest single-digit prime number?'},
            #{'role': 'assistant', 'content': XML_COT_FORMAT.format(
            #    reasoning="9 is divisble by 3 and 8 is divisible by 2, but 7 is prime.",
            #    answer="7"
            #)},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions(split="train")
# Use the test set for evaluation
validation_dataset = get_gsm8k_questions(split="test")

# --- Reward functions with dynamic return values ---

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    correct_reward = kwargs.get("correct_reward", 2.0)
    incorrect_reward = kwargs.get("incorrect_reward", 0.0)
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [correct_reward if r == a else incorrect_reward for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    int_reward = kwargs.get("int_reward", 0.5)
    non_int_reward = kwargs.get("non_int_reward", 0.0)
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [int_reward if r.isdigit() else non_int_reward for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    match_reward = kwargs.get("strict_match_reward", 0.5)
    no_match_reward = kwargs.get("strict_no_match_reward", 0.0)
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [match_reward if match else no_match_reward for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    match_reward = kwargs.get("soft_match_reward", 0.5)
    no_match_reward = kwargs.get("soft_no_match_reward", 0.0)
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [match_reward if match else no_match_reward for match in matches]

def count_xml(text, xml_count_reward=0.125) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += xml_count_reward
    if text.count("\n</reasoning>\n") == 1:
        count += xml_count_reward
    if text.count("\n<answer>\n") == 1:
        count += xml_count_reward
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += xml_count_reward
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    xml_count_reward = kwargs.get("xml_count_reward", 0.125)
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c, xml_count_reward) for c in contents]

# --- Model and Config ---

#model_name = "meta-llama/Llama-3.2-1B-Instruct"
model_name = "Qwen/Qwen2.5-1.5B-Instruct"

if "Llama" in model_name:
    output_dir = "outputs/Llama-1B-GRPO"
    run_name = "Llama-1B-GRPO-gsm8k"
else:
    output_dir="outputs/Qwen-1.5B-GRPO"
    run_name="Qwen-1.5B-GRPO-gsm8k"

# --- Optuna Objective Function ---

def objective(trial):
    # Define hyperparameters to optimize
    learning_rate = trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)
    weight_decay = trial.suggest_float("weight_decay", 0.0, 0.2)
    warmup_ratio = trial.suggest_float("warmup_ratio", 0.0, 0.2)
    max_grad_norm = trial.suggest_float("max_grad_norm", 0.1, 1.0)

    # Dynamic reward function weights
    xmlcount_weight = trial.suggest_float("xmlcount_weight", 0.5, 1.5)
    soft_format_weight = trial.suggest_float("soft_format_weight", 0.5, 1.5)
    strict_format_weight = trial.suggest_float("strict_format_weight", 0.5, 1.5)
    int_weight = trial.suggest_float("int_weight", 0.5, 1.5)
    correctness_weight = trial.suggest_float("correctness_weight", 1.0, 3.0)

    # Dynamic reward function parameters
    correct_reward = trial.suggest_float("correct_reward", 1.0, 3.0)
    incorrect_reward = trial.suggest_float("incorrect_reward", -1.0, 0.0)
    int_reward = trial.suggest_float("int_reward", 0.1, 1.0)
    non_int_reward = trial.suggest_float("non_int_reward", -0.5, 0.0)
    strict_match_reward = trial.suggest_float("strict_match_reward", 0.1, 1.0)
    strict_no_match_reward = trial.suggest_float("strict_no_match_reward", -0.5, 0.0)
    soft_match_reward = trial.suggest_float("soft_match_reward", 0.1, 1.0)
    soft_no_match_reward = trial.suggest_float("soft_no_match_reward", -0.5, 0.0)
    xml_count_reward = trial.suggest_float("xml_count_reward", 0.05, 0.2)

    # Define GRPOConfig with optimized hyperparameters
    training_args = GRPOConfig(
        output_dir=output_dir,
        run_name=run_name,
        learning_rate=learning_rate,
        adam_beta1 = 0.9,
        adam_beta2 = 0.99,
        weight_decay = weight_decay,
        warmup_ratio = warmup_ratio,
        lr_scheduler_type='cosine',
        logging_steps=1,
        bf16=True,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        num_generations=16,
        max_prompt_length=256,
        max_completion_length=786,
        num_train_epochs=1,
        save_steps=100,
        max_grad_norm=max_grad_norm,
        report_to="wandb",
        log_on_each_node=False,
    )

    # Define LoRA config (optional)
    peft_config = LoraConfig(
        r=16,
        lora_alpha=64,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
        task_type="CAUSAL_LM",
        lora_dropout=0.05,
    )

    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map=None
    ).to("cuda")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # Define GRPOTrainer with dynamic reward functions
    def weighted_xmlcount_reward_func(completions, **kwargs):
      return [x * xmlcount_weight for x in xmlcount_reward_func(completions, xml_count_reward=xml_count_reward, **kwargs)]

    def weighted_soft_format_reward_func(completions, **kwargs):
      return [x * soft_format_weight for x in soft_format_reward_func(completions, soft_match_reward=soft_match_reward, soft_no_match_reward=soft_no_match_reward, **kwargs)]

    def weighted_strict_format_reward_func(completions, **kwargs):
      return [x * strict_format_weight for x in strict_format_reward_func(completions, strict_match_reward=strict_match_reward, strict_no_match_reward=strict_no_match_reward, **kwargs)]

    def weighted_int_reward_func(completions, **kwargs):
      return [x * int_weight for x in int_reward_func(completions, int_reward=int_reward, non_int_reward=non_int_reward, **kwargs)]

    def weighted_correctness_reward_func(prompts, completions, answer, **kwargs):
      return [x * correctness_weight for x in correctness_reward_func(prompts, completions, answer, correct_reward=correct_reward, incorrect_reward=incorrect_reward, **kwargs)]

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            weighted_xmlcount_reward_func,
            weighted_soft_format_reward_func,
            weighted_strict_format_reward_func,
            weighted_int_reward_func,
            weighted_correctness_reward_func
        ],
        args=training_args,
        train_dataset=dataset,
        #peft_config=peft_config
    )

    # Train the model
    trainer.train()

    # --- Evaluation on the GSM8K test set ---
    def evaluate_model(model, tokenizer, dataset):
        model.eval()
        correct_predictions = 0
        total_predictions = 0

        for i in range(len(dataset)):
            sample = dataset[i]
            prompt = sample['prompt']
            true_answer = sample['answer']

            inputs = tokenizer.apply_chat_template(prompt, return_tensors="pt").to("cuda")
            with torch.no_grad():
                outputs = model.generate(inputs, max_new_tokens=786, pad_token_id=tokenizer.eos_token_id)

            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            predicted_answer = extract_xml_answer(generated_text)

            if predicted_answer == true_answer:
                correct_predictions += 1
            total_predictions += 1

        accuracy = correct_predictions / total_predictions
        return accuracy

    accuracy = evaluate_model(model, tokenizer, validation_dataset)
    print(f"Accuracy: {accuracy}")
    return accuracy

# --- Optuna Study Setup ---

study = optuna.create_study(direction="maximize") # Maximize accuracy

# Initial parameter suggestions
initial_params = {
    "learning_rate": 5e-6,
    "weight_decay": 0.1,
    "warmup_ratio": 0.1,
    "max_grad_norm": 0.1,
    "xmlcount_weight": 1.0,
    "soft_format_weight": 1.0,
    "strict_format_weight": 1.0,
    "int_weight": 1.0,
    "correctness_weight": 2.0,
    "correct_reward": 2.0,
    "incorrect_reward": 0.0,
    "int_reward": 0.5,
    "non_int_reward": 0.0,
    "strict_match_reward": 0.5,
    "strict_no_match_reward": 0.0,
    "soft_match_reward": 0.5,
    "soft_no_match_reward": 0.0,
    "xml_count_reward": 0.125
}

# Run initial trial
study.enqueue_trial(initial_params)
study.optimize(objective, n_trials=1)

# Run remaining trials with free optimization
study.optimize(objective, n_trials=9) # 9 more trials for a total of 10

# Print best hyperparameters
print("Best hyperparameters:", study.best_trial.params)

Basically, it sets up the reward functions to accept dynamic values, controlled by Optuna. Optuna then tweaks these values during training, searching for the combination that leads to the best performance on the GSM8K test set.

I haven't actually tested this yet, so it's still a rough overview of how it could look. Just wanted to share the idea. Let me know what you think! Any thoughts or suggestions are greatly appreciated!

@dignfei

dignfei commented Feb 4, 2025

Why do I see more and more (\boxed{ANSWE}) in the outputs? The code gives no reward for this format, so why does it appear more and more during training?

@HaileyStorm

HaileyStorm commented Feb 4, 2025

So... what the heck.

Realization!

7900XTX pip freeze: trl @ git+https://github.com/huggingface/trl.git@2ce36ae889f286dad91dc3ac6b55904864bf9254
H100 pip freeze: trl @ git+https://github.com/huggingface/trl.git@1c35a48b50f54b92c6b820437aaf75c4e3d777ce

See also x.com/abacaj/status/1885856348505616771?t=s6Cm6ahCx_hXvm_jNEENlw&s=19 by @abacaj , which I think might be the same thing.

Trying to confirm this is the issue I was having, but no H100s where I have credit at the moment.

Ping @ianand and (sorry if this is wrong call) @willccbb .

[screenshot]

I was wrong, my issue was with how trl is calculating the GRPO update. Note: No vllm here.

@handrew

handrew commented Feb 5, 2025

Thanks for the work. Just an FYI for those following along, from the Hugging Face page:

"Avoid adding a system prompt; all instructions should be contained within the user prompt."

@andrewsiah

andrewsiah commented Feb 5, 2025

How do you all log extra metrics to wandb when running the file with accelerate launch?

For example, in correctness_reward_func I want to log some additional metrics to wandb.

But it keeps crashing. What is the right way to log to the wandb run that the trainer sets up automatically?


def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]["content"] for completion in completions]
    q = prompts[0][-1]["content"]
    extracted_responses = [extract_xml_answer(r) for r in responses]

    # Get current step from trainer's state
    current_step = trainer.state.global_step if hasattr(trainer, "state") else 0

    # Initialize logger if not already done
    global example_logger
    if not hasattr(correctness_reward_func, "example_logger"):
        example_logger = LocalExampleLogger()
        correctness_reward_func.example_logger = example_logger

    # Log each example
    for i in range(len(responses)):
        example_dict = {
            "step": current_step,
            "question": q,
            "true_answer": answer[i],
            "response": responses[i],
            "extracted_response": extracted_responses[i],
            "correct": extracted_responses[i] == answer[i],
            "generation_idx": i,  # Which generation attempt this was
        }
        example_logger.log_example(example_dict)

    # Calculate marker counts and correctness for all responses
    is_correct = [r == a for r, a in zip(extracted_responses, answer)]
    uncertainty_counts = [count_uncertainty_markers(r) for r in responses]
    internal_dialogue_counts = [count_internal_dialogue_markers(r) for r in responses]
    reflective_counts = [count_reflective_markers(r) for r in responses]

    # Separate counts for correct and incorrect responses
    correct_indices = [i for i, correct in enumerate(is_correct) if correct]
    incorrect_indices = [i for i, correct in enumerate(is_correct) if not correct]

    # Log metrics using trainer's accelerator
    if hasattr(trainer, "accelerator"):
        # EXAMPLE EXTRA METRIC HERE
        metrics = {
            "correctness/correct_count": len(correct_indices)
        }
        trainer.accelerator.log(metrics, step=current_step)

    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

@ianand

ianand commented Feb 5, 2025

@HaileyStorm re: https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb?permalink_comment_id=5423553#gistcomment-5423553

I was wrong, my issue was with how trl is calculating the GRPO update. Note: No vllm here.

Can you elaborate on how you determined that?

@ianand

ianand commented Feb 5, 2025

@andrewsiah

andrewsiah commented Feb 6, 2025

Somehow mine has near 0% accuracy when running on Qwen2.5 0.5B base (while others have reported 40+%).

I found that my model does not conform to the XML response format; I wonder what prompts you used to replicate the results?


My configs:

training:
  output_dir: "outputs/Qwen-0.5B-GRPO"
  run_name: "Qwen-0.5B-GRPO-gsm8k"
  learning_rate: 5.0e-6
  gradient_checkpointing: true
  # Evaluation settings
  do_eval: true
  eval_steps: 50
  per_device_eval_batch_size: 128  # Adjust based on your GPU memory
  eval_strategy: "steps"
  beta: 0.02

My prompt:

prompts:
  system_prompt: |
    Respond in the following format:

    <reasoning>
    {reasoning}
    </reasoning>
    <answer>
    {answer}
    </answer>

  xml_cot_format: |
    <reasoning>
    {reasoning}
    </reasoning>
    <answer>
    {answer}
    </answer>

@NickyDark1

What would it take to be able to train in a 'multi-turn' setting?

@willccbb
Author

willccbb commented Feb 6, 2025

@NickyDark1 will have something to show you very soon on that :)

@cgpeter96

An out-of-the-box DeepSpeed version for running GRPO on multiple GPUs, e.g., 8x A100: https://gist.github.com/cgpeter96/53ffcd5b49c10e8de5303059c21388ac

@HaileyStorm

@ianand

@HaileyStorm re: willccbb/4676755236bb08cab5f4e54a0475d6fb?permalink_comment_id=5423553#gistcomment-5423553

I was wrong, my issue was with how trl is calculating the GRPO update. Note: No vllm here.

Can you elaborate on how you determined that?

Please see the screenshots. "bs" there is batch size and "ga" is gradient accumulation; everything else is equal between runs. That, and I've seen commentary that the trl GRPO calculation is only technically correct under certain conditions (I think per-device batch size = 1 is the constraint?).

@HaileyStorm

HaileyStorm commented Feb 7, 2025

@andrewsiah

How do you all add extra metrics to log onto wandb when running the file with accelerate launch?

Simply import wandb at the top of your file, and in your reward function do:

wandb.log({
    "debug/question": q,
    "debug/true_answer": answer[i],
    "debug/response": responses[i],
    "debug/extracted_response": extracted_responses[i],
    "debug/correct": extracted_responses[i] == answer[i],
    "debug/generation_idx": i,  # Which generation attempt this was
}, step=current_step)

Notice the inclusion of the "debug" logging group; call that what you want (all the existing stuff goes in "train" or "train/reward").

ETA: This works because by the time the reward function gets called, the Trainer will have already done the wandb.init().
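
Putting that together with the reward function from the gist, a self-contained sketch might look like this (it assumes the Trainer has already called wandb.init(), and omits the explicit step argument so wandb uses its own step counter):

import wandb

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]["content"] for completion in completions]
    q = prompts[0][-1]["content"]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    # Log each generation to the active wandb run under a "debug" group.
    for i, ext in enumerate(extracted_responses):
        wandb.log({
            "debug/question": q,
            "debug/response": responses[i],
            "debug/extracted_response": ext,
            "debug/correct": ext == answer[i],
            "debug/generation_idx": i,
        })
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]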

@NickyDark1

@willccbb, great, thank you very much for that script; it helped me understand a lot more.

@andrewsiah

@andrewsiah

How do you all add extra metrics to log onto wandb when running the file with accelerate launch?

Simply import wandb at the top of your file, and in your reward function do:

wandb.log({
    "debug/question": q,
    "debug/true_answer": answer[i],
    "debug/response": responses[i],
    "debug/extracted_response": extracted_responses[i],
    "debug/correct": extracted_responses[i] == answer[i],
    "debug/generation_idx": i,  # Which generation attempt this was
}, step=current_step)

Notice the inclusion of the "debug" logging group; call that what you want (all the existing stuff goes in "train" or "train/reward").

ETA: This works because by the time the reward function gets called, the Trainer will have already done the wandb.init().

Thank you so much!

@ahaym

ahaym commented Feb 9, 2025

FYI: There were changes from trl@cf97133 that change the relationship between num_generations and per_device_train_batch_size that could lead to these errors:

The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly divisible by the number of generations per prompt ({self.num_generations})

To get the same behavior as before, set both num_generations and per_device_train_batch_size to 16. If you want to split generations across GPUs for a lighter memory load, you can set per_device_train_batch_size even lower; see the PR.
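
As a concrete illustration of that constraint (the numbers below are only an example, not a recommendation):

from trl import GRPOConfig

num_processes = 4                    # GPUs used by accelerate launch
per_device_train_batch_size = 4
num_generations = 16                 # must divide num_processes * per_device_train_batch_size
assert (num_processes * per_device_train_batch_size) % num_generations == 0  # 4 * 4 = 16, OK

training_args = GRPOConfig(
    output_dir="outputs/Qwen-1.5B-GRPO",
    per_device_train_batch_size=per_device_train_batch_size,
    num_generations=num_generations,
)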

@NickyDark1

ahaym thanks

@ggjge

ggjge commented Feb 12, 2025

After implementing some optimizations to the grpo trainer and tweaking params, I'm successfully running training of qwen2.5-0.5B-instruct on a free google colab T4 GPU, at ~13hours/epoch. There's hope for the GPU poor.

Will be posting updates here

@qunash can you share your parameters?

@ggjge

ggjge commented Feb 12, 2025

  • Tune beta to 0.01

Do you mean adam_beta1 = 0.01 and adam_beta2 = 0.01? The default values are 0.9 and 0.99.

@zaddy6

zaddy6 commented Feb 12, 2025

Has anyone been able to train anything >3B? I still can't fine-tune larger models without OOM, even with 8x H100.

@fsxbhyy

fsxbhyy commented Feb 12, 2025

Thanks willccbb!
When training the Qwen model on GPU, I encountered the error:
"probability tensor contains either inf, nan or element < 0"
I would appreciate any insights on a possible solution.

@ggjge

ggjge commented Feb 14, 2025

Trained on runpod, took about 90 minutes to do 250 steps

@qunash

I was able to get it running using a similar setup as @qrdlgit above:

Trained on runpod, took about 90 minutes to do 250 steps
1x H100 NVL (94 GB VRAM)
94 GB RAM • 16 vCPU
Total Disk: 40 GB
pip install git+https://github.com/huggingface/trl.git accelerate transformers datasets peft wandb tqdm
Note that i had to pip install flash_attn as well

But I had to install a specific version of flash-attn to get it to work:

!pip install flash-attn==2.3.6


thx @willccbb.

@ianand Could you please share your detailed configs? I've tried many settings, but the reward stays at its original level and nothing changes.

@willccbb
Author

closely related project i'm working on to make RL with verifiers easier: https://github.com/willccbb/verifiers

currently the main focus is on supporting multi-step rollouts (tool use, multi-agent envs, games, code repls, etc)

to make an "environment" for running TRL GRPO (vLLM-only), all that's needed is to extend MultiStepEnv with methods that compute env responses + decide when a trajectory is finished (at the [{'role': 'user', 'content' : ...}] level, no tensor slicing required)

class MultiStepEnv(BaseEnv):
    def __init__(self,
                 system_prompt: str = "",
                 few_shot: List[Dict[str, str]] = [],
                 sampling_args: Dict[str, Any] = {},
                 **kwargs):
        super().__init__(**kwargs)
        self.system_prompt = system_prompt
        self.few_shot = few_shot
        self.sampling_args = sampling_args

    @abstractmethod
    def is_completed(self, messages: List[Dict[str, str]], **kwargs: Any) -> bool:
        pass

    @abstractmethod
    def env_response(self, messages: List[Dict[str, str]], **kwargs: Any) -> Dict[str, str]:
        pass
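
As a toy illustration only (the class name and termination rule below are hypothetical, not part of the library), a subclass could look like:

from typing import Any, Dict, List

class EchoEnv(MultiStepEnv):
    def is_completed(self, messages: List[Dict[str, str]], **kwargs: Any) -> bool:
        # Stop the rollout once the model's last message contains "DONE".
        return "DONE" in messages[-1]["content"]

    def env_response(self, messages: List[Dict[str, str]], **kwargs: Any) -> Dict[str, str]:
        # Otherwise reply with a user turn echoing the model's last message.
        return {"role": "user", "content": f"You said: {messages[-1]['content']}"}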

will be adding more env examples in the coming days/weeks

also have some basic support for encapsulating dataset/rubric construction inside envs

if you're a fan of grpo_demo.py, please consider checking it out :)

@AlexChaloner

I think I've spotted an error in the formatting reward regexes. Someone correct me if I'm wrong!
E.g.
r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
Isn't working well because .*? does not match newline characters, and the reasoning normally includes newlines.
This directly contradicts the other XML reward, which incentivises newlines, i.e. \n</reasoning>\n.

My suggestion is instead:
r"<reasoning>[\s\S]*</reasoning>\s*<answer>.*?</answer>"
i.e. [\s\S]* matches both whitespace and non-whitespace characters, so newlines are included. Alternatively, you can use re.match(pattern, r, flags=re.DOTALL).
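
A quick standalone check of that behavior (not part of the training script):

import re

pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
completion = "<reasoning>\nstep 1\nstep 2\n</reasoning>\n<answer>\n42\n</answer>"

print(bool(re.match(pattern, completion)))                   # False: '.' does not match newlines
print(bool(re.match(pattern, completion, flags=re.DOTALL)))  # True once DOTALL is enabled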

@willccbb
Author

@AlexChaloner yes good catch, fixed

@junming-yang

Why do I see more and more (\boxed{ANSWE}) in the outputs? The code gives no reward for this format, so why does it appear more and more during training?

It's probably because this format appears in Qwen's pretraining data.

@ggjge

ggjge commented Feb 18, 2025

Has anyone run into the reward not changing at all? I've tried many sets of hyperparameters and even a cold start, but it still doesn't move.

@qunash

qunash commented Feb 18, 2025

After implementing some optimizations to the grpo trainer and tweaking params, I'm successfully running training of qwen2.5-0.5B-instruct on a free google colab T4 GPU, at ~13hours/epoch. There's hope for the GPU poor.
Will be posting updates here

@qunash can you share your parameters?

https://gist.github.com/qunash/820c86d1d267ec8051d9f68b4f4bb656

@zdgithub

zdgithub commented Feb 19, 2025

Somehow mine has near 0% accuracy when running on Qwen2.5 0.5B base (while others have reported 40+%).

I found that my model does not conform to the XML response format; I wonder what prompts you used to replicate the results?


My configs:

training:
  output_dir: "outputs/Qwen-0.5B-GRPO"
  run_name: "Qwen-0.5B-GRPO-gsm8k"
  learning_rate: 5.0e-6
  gradient_checkpointing: true
  # Evaluation settings
  do_eval: true
  eval_steps: 50
  per_device_eval_batch_size: 128  # Adjust based on your GPU memory
  eval_strategy: "steps"
  beta: 0.02

My prompt:

prompts:
  system_prompt: |
    Respond in the following format:

    <reasoning>
    {reasoning}
    </reasoning>
    <answer>
    {answer}
    </answer>

  xml_cot_format: |
    <reasoning>
    {reasoning}
    </reasoning>
    <answer>
    {answer}
    </answer>

@andrewsiah Hi, I encountered the same problem: the reward is always 0. Did you solve it?

@ggjge

ggjge commented Feb 19, 2025

After implementing some optimizations to the grpo trainer and tweaking params, I'm successfully running training of qwen2.5-0.5B-instruct on a free google colab T4 GPU, at ~13hours/epoch. There's hope for the GPU poor.
Will be posting updates here

@qunash can you share your parameters?

https://gist.github.com/qunash/820c86d1d267ec8051d9f68b4f4bb656

@qunash
Thanks a lot for sharing!
I have a question about your format reward pattern: it seems to only match a string that is exactly "\n"?
