Last active
February 10, 2025 05:08
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# --- GRPO to teach poker (using llm-poker) ---
import json
import re

from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

from llm_poker.environment import PokerTable
from llm_poker.llm_player import LLMPlayer
class GRPOPlayer(LLMPlayer):
    """Poker player that always plays a move chosen ahead of time.

    Instead of querying an LLM for each decision, this player echoes the
    single action supplied at construction time — this lets GRPO training
    inject a candidate move into a live hand of poker.
    """

    def __init__(self, name, model_id, stack, provided_move):
        super().__init__(name=name, model_id=model_id, stack=stack)
        # The action string to replay on every turn ("fold"/"call"/"raise").
        self.provided_move = provided_move

    def request_action(self, **kwargs):
        """Return the pre-specified move, ignoring any table state."""
        return {"action": self.provided_move, "raise_amount": None}
def agent_reward(completions, **kwargs):
    """GRPO reward function: score each completion by playing a poker hand.

    Each completion is expected to be a JSON object such as
    ``{"action": "call"}`` (also ``"fold"`` or ``"raise"``).  Output that is
    not valid JSON — or is valid JSON but not an object — earns -80; a JSON
    object with an unknown action earns -100.  Otherwise the proposed move
    is played by a scripted agent against two default opponents, and the
    reward is the agent's net chip change over the hand.

    Args:
        completions: list of completions; each item is a list of messages
            whose first entry carries the model text under "content".
        **kwargs: extra fields passed by the GRPO trainer (unused).

    Returns:
        A list of numeric rewards, one per completion.
    """
    rewards = []
    for completion in completions:
        content = completion[0]["content"]
        try:
            # Parse the entire output as JSON. Expect format: {"action": "call"} (or "fold"/"raise")
            json_data = json.loads(content)
        except json.JSONDecodeError:
            rewards.append(-80)  # Malformed JSON output.
            continue
        # Valid JSON that is not an object (e.g. a list or a bare string)
        # cannot carry an action; treat it like malformed output instead of
        # crashing on .get().
        if not isinstance(json_data, dict):
            rewards.append(-80)
            continue
        move = json_data.get("action")
        if move not in {"fold", "call", "raise"}:
            rewards.append(-100)  # Invalid action provided.
            continue

        starting_stack = 10000
        # Create our agent with the provided move.
        agent = GRPOPlayer(name="Agent", model_id="gpt-4o", stack=starting_stack, provided_move=move)
        # Create two opponents using the default behavior from LLMPlayer.
        opponent1 = LLMPlayer(name="Opponent1", model_id="gpt-4o", stack=starting_stack)
        opponent2 = LLMPlayer(name="Opponent2", model_id="gpt-4o", stack=starting_stack)
        # Set up the poker table with the desired blinds and min_raise.
        table = PokerTable(
            players=[agent, opponent1, opponent2],
            min_raise=500,
            small_blind=50,
            big_blind=100
        )
        # Play one hand of poker. The table internally deals cards and runs betting rounds.
        table.play_hand()
        # Compute the reward as the net change in the agent's chip stack.
        reward = agent.stack - starting_stack
        rewards.append(reward)
    return rewards
# One fixed prompt repeated 200 times: GRPO only needs the prompt text —
# the learning signal comes from the reward of actually playing the move.
_poker_prompt = {
    "role": "user",
    "content": "Let's play a round of poker. What is your move? Please respond with a JSON object (e.g., {\"action\": \"call\"})."
}
dataset = Dataset.from_list([{"prompt": [_poker_prompt]} for _ in range(200)])
def main():
    """Configure and launch GRPO training of a small Qwen model on poker."""
    # Training hyper-parameters; completions are short JSON actions, so a
    # 128-token budget for prompt and completion should suffice.
    config = GRPOConfig(
        output_dir="Qwen2.5-0.5B-GRPO-Pokeragent",
        logging_steps=5,
        gradient_accumulation_steps=4,
        max_completion_length=128,
        max_prompt_length=128,
        bf16=True,
        # log_completions=True,
    )
    GRPOTrainer(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        reward_funcs=agent_reward,
        args=config,
        train_dataset=dataset,
        # tools=[get_value],
    ).train()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment