This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | from jaxtyping import Float, Int | |
| import torch | |
| from torch.nn import functional as F | |
| from torch import Tensor | |
| from typing import List, Callable, Tuple, Dict, Optional | |
| import pandas as pd | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| def get_valid_next_choices(choices_tokens, current_tokens): | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | # train_grpo.py | |
| import re | |
| from datasets import load_dataset, Dataset | |
| from transformers import AutoTokenizer | |
| from peft import LoraConfig | |
| from trl import GRPOConfig, GRPOTrainer | |
| # Load and prep dataset | |
| SYSTEM_PROMPT = """ |