# Source: public GitHub gist by @strangeloopcanon (last active February 10, 2025).
# --- GRPO to teach poker (using llm-poker) ---
from llm_poker.environment import PokerTable
from llm_poker.llm_player import LLMPlayer
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
import re
import json
class GRPOPlayer(LLMPlayer):
    """
    Scripted poker player used during GRPO rollouts.

    Rather than asking an LLM what to do, it always plays the single move
    handed to it at construction time, so the trainer fully controls the
    agent's action for the hand being scored.
    """

    def __init__(self, name, model_id, stack, provided_move):
        super().__init__(name=name, model_id=model_id, stack=stack)
        # The move chosen by the policy under training ("fold"/"call"/"raise").
        self.provided_move = provided_move

    def request_action(self, **kwargs):
        """Ignore the table state; always answer with the scripted move."""
        return dict(action=self.provided_move, raise_amount=None)
def agent_reward(completions, **kwargs):
    """
    GRPO reward function: score each completion by playing one hand of poker.

    Each completion is expected to be a single JSON object such as
    {"action": "call"} where the action is one of "fold", "call", or "raise".
    Malformed or non-object JSON earns -80; a well-formed object with an
    unrecognized action earns -100. Otherwise the move is injected into a
    GRPOPlayer, one hand is played against two default LLM opponents, and
    the reward is the agent's net chip change over that hand.

    Args:
        completions: list of chat-format completions; completions[i][0]["content"]
            holds the model's raw text output.
        **kwargs: extra columns passed through by the GRPO trainer (unused).

    Returns:
        list of numeric rewards, one per completion.
    """
    rewards = []
    for completion in completions:
        content = completion[0]["content"]
        # Keep the try body minimal: only the parse itself can raise here.
        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            rewards.append(-80)  # Malformed JSON output.
            continue
        # json.loads accepts non-objects too (e.g. '"call"' or '["call"]');
        # calling .get on those would crash, so treat them as malformed.
        if not isinstance(parsed, dict):
            rewards.append(-80)
            continue
        move = parsed.get("action")
        if move not in {"fold", "call", "raise"}:
            rewards.append(-100)  # Invalid action provided.
            continue
        starting_stack = 10000
        # Our agent plays the parsed move; opponents keep default LLM behavior.
        agent = GRPOPlayer(name="Agent", model_id="gpt-4o", stack=starting_stack, provided_move=move)
        opponent1 = LLMPlayer(name="Opponent1", model_id="gpt-4o", stack=starting_stack)
        opponent2 = LLMPlayer(name="Opponent2", model_id="gpt-4o", stack=starting_stack)
        # Fixed blinds/min-raise; play_hand deals cards and runs the betting rounds.
        table = PokerTable(
            players=[agent, opponent1, opponent2],
            min_raise=500,
            small_blind=50,
            big_blind=100,
        )
        table.play_hand()
        # Reward = net change in the agent's chip stack for the hand.
        rewards.append(agent.stack - starting_stack)
    return rewards
# Minimal training set: the identical prompt repeated 200 times, since all
# reward variation comes from the poker environment rather than the prompt.
_POKER_PROMPT = {
    "role": "user",
    "content": "Let's play a round of poker. What is your move? Please respond with a JSON object (e.g., {\"action\": \"call\"})."
}
dataset = Dataset.from_list([{"prompt": [_POKER_PROMPT]}] * 200)
def main():
    """Configure and launch GRPO fine-tuning of Qwen2.5-0.5B on the poker reward."""
    config = GRPOConfig(
        output_dir="Qwen2.5-0.5B-GRPO-Pokeragent",
        logging_steps=5,
        gradient_accumulation_steps=4,
        max_completion_length=128,
        max_prompt_length=128,
        bf16=True,
    )
    trainer = GRPOTrainer(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        args=config,
        reward_funcs=agent_reward,
        train_dataset=dataset,
    )
    trainer.train()
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
# (End of gist.)