Action Masked PPO for Connect Four in Jupyter Notebook Form
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Uses Stable-Baselines3 to train agents in the Connect Four environment using invalid action masking.\n",
"\n",
"For information about invalid action masking in PettingZoo, see https://pettingzoo.farama.org/api/aec/#action-masking\n",
"For more information about invalid action masking in SB3, see https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html\n",
"\n",
"Author: Elliot (https://github.com/elliottower)\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import glob\n",
"import os\n",
"import time\n",
"\n",
"import pettingzoo.utils\n",
"from pettingzoo.classic import connect_four_v3\n",
"from sb3_contrib import MaskablePPO\n",
"from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy\n",
"from sb3_contrib.common.wrappers import ActionMasker\n"
]
},
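{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the PettingZoo action-mask convention: each raw observation from `connect_four_v3` is a dict holding an `observation` array plus an `action_mask` vector of legal columns. The throwaway cell below (illustrative `_demo_*` names, not part of the training pipeline) prints both, so the wrapper defined next is easier to follow.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch (assumes only the imports above): inspect one raw observation.\n",
"_demo_env = connect_four_v3.env()\n",
"_demo_env.reset(seed=0)\n",
"_demo_obs, _, _, _, _ = _demo_env.last()\n",
"print(_demo_obs[\"observation\"].shape)  # board planes for the current player\n",
"print(_demo_obs[\"action_mask\"])  # 1 = legal column, 0 = full column\n",
"_demo_env.close()\n"
]
},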
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class SB3ActionMaskWrapper(pettingzoo.utils.BaseWrapper):\n",
" \"\"\"Wrapper to allow PettingZoo environments to be used with SB3 illegal action masking.\"\"\"\n",
"\n",
" def reset(self, seed=None, options=None):\n",
" \"\"\"Gymnasium-like reset function which assigns obs/action spaces to be the same for each agent.\n",
"\n",
" This is required as SB3 is designed for single-agent RL and doesn't expect obs/action spaces to be functions\n",
" \"\"\"\n",
" super().reset(seed, options)\n",
"\n",
" # Strip the action mask out from the observation space\n",
" self.observation_space = super().observation_space(self.possible_agents[0])[\n",
" \"observation\"\n",
" ]\n",
" self.action_space = super().action_space(self.possible_agents[0])\n",
"\n",
" # Return initial observation, info (PettingZoo AEC envs do not by default)\n",
" return self.observe(self.agent_selection), {}\n",
"\n",
" def step(self, action):\n",
" \"\"\"Gymnasium-like step function, returning observation, reward, termination, truncation, info.\"\"\"\n",
" super().step(action)\n",
" return super().last()\n",
"\n",
" def observe(self, agent):\n",
" \"\"\"Return only raw observation, removing action mask.\"\"\"\n",
" return super().observe(agent)[\"observation\"]\n",
"\n",
" def action_mask(self):\n",
" \"\"\"Separate function used in order to access the action mask.\"\"\"\n",
" return super().observe(self.agent_selection)[\"action_mask\"]\n"
]
},
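{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick round-trip sketch of the wrapper's Gymnasium-style interface (illustrative `_check_env` name): `reset` returns `(observation, info)`, `action_mask()` exposes the legal moves, and `step` returns the usual five-tuple.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch: exercise SB3ActionMaskWrapper's single-agent style API.\n",
"_check_env = SB3ActionMaskWrapper(connect_four_v3.env())\n",
"obs, info = _check_env.reset(seed=0)\n",
"mask = _check_env.action_mask()  # legal columns for the agent about to move\n",
"obs, reward, termination, truncation, info = _check_env.step(int(mask.argmax()))\n",
"print(reward, termination, truncation)\n",
"_check_env.close()\n"
]
},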
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def mask_fn(env):\n",
" # Do whatever you'd like in this function to return the action mask\n",
" # for the current env. In this example, we assume the env has a\n",
" # helpful method we can rely on.\n",
" return env.action_mask()\n"
]
},
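{
"cell_type": "markdown",
"metadata": {},
"source": [
"How the mask reaches SB3 (a sketch using sb3_contrib's `get_action_masks` helper): wrapping with `ActionMasker` attaches an `action_masks()` method backed by `mask_fn`, which MaskablePPO queries during rollouts.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sb3_contrib.common.maskable.utils import get_action_masks\n",
"\n",
"# Illustrative sketch mirroring the wrapping order used in train_action_mask below.\n",
"_inner_env = SB3ActionMaskWrapper(connect_four_v3.env())\n",
"_inner_env.reset(seed=0)  # re-defines the spaces, as in training\n",
"_masked_env = ActionMasker(_inner_env, mask_fn)\n",
"print(get_action_masks(_masked_env))  # the same vector mask_fn returns\n",
"_inner_env.close()\n"
]
},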
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def train_action_mask(env_fn, steps=10_000, seed=0, **env_kwargs):\n",
" \"\"\"Train a single model to play as each agent in a zero-sum game environment using invalid action masking.\"\"\"\n",
" env = env_fn.env(**env_kwargs)\n",
"\n",
" print(f\"Starting training on {str(env.metadata['name'])}.\")\n",
"\n",
" # Custom wrapper to convert PettingZoo envs to work with SB3 action masking\n",
" env = SB3ActionMaskWrapper(env)\n",
"\n",
" env.reset(seed=seed) # Must call reset() in order to re-define the spaces\n",
"\n",
" env = ActionMasker(env, mask_fn) # Wrap to enable masking (SB3 function)\n",
" # MaskablePPO behaves the same as SB3's PPO unless the env is wrapped\n",
" # with ActionMasker. If the wrapper is detected, the masks are automatically\n",
" # retrieved and used when learning. Note that MaskablePPO does not accept\n",
" # a new action_mask_fn kwarg, as it did in an earlier draft.\n",
" model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1)\n",
" model.set_random_seed(seed)\n",
" model.learn(total_timesteps=steps)\n",
"\n",
" model.save(f\"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}\")\n",
"\n",
" print(\"Model has been saved.\")\n",
"\n",
" print(f\"Finished training on {str(env.unwrapped.metadata['name'])}.\\n\")\n",
"\n",
" env.close()\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs):\n",
" # Evaluate a trained agent vs a random agent\n",
" env = env_fn.env(render_mode=render_mode, **env_kwargs)\n",
"\n",
" print(\n",
" f\"Starting evaluation vs a random agent. Trained agent will play as {env.possible_agents[1]}.\"\n",
" )\n",
"\n",
" try:\n",
" latest_policy = max(\n",
" glob.glob(f\"{env.metadata['name']}*.zip\"), key=os.path.getctime\n",
" )\n",
" except ValueError:\n",
" print(\"Policy not found.\")\n",
" exit(0)\n",
"\n",
" model = MaskablePPO.load(latest_policy)\n",
"\n",
" scores = {agent: 0 for agent in env.possible_agents}\n",
" total_rewards = {agent: 0 for agent in env.possible_agents}\n",
" round_rewards = []\n",
"\n",
" for i in range(num_games):\n",
" env.reset(seed=i)\n",
" env.action_space(env.possible_agents[0]).seed(i)\n",
"\n",
" for agent in env.agent_iter():\n",
" obs, reward, termination, truncation, info = env.last()\n",
"\n",
" # Separate observation and action mask\n",
" observation, action_mask = obs.values()\n",
"\n",
" if termination or truncation:\n",
" # If there is a winner, keep track, otherwise don't change the scores (tie)\n",
" if (\n",
" env.rewards[env.possible_agents[0]]\n",
" != env.rewards[env.possible_agents[1]]\n",
" ):\n",
" winner = max(env.rewards, key=env.rewards.get)\n",
" scores[winner] += env.rewards[\n",
" winner\n",
" ] # only tracks the largest reward (winner of game)\n",
" # Also track negative and positive rewards (penalizes illegal moves)\n",
" for a in env.possible_agents:\n",
" total_rewards[a] += env.rewards[a]\n",
" # List of rewards by round, for reference\n",
" round_rewards.append(env.rewards)\n",
" break\n",
" else:\n",
" if agent == env.possible_agents[0]:\n",
" act = env.action_space(agent).sample(action_mask)\n",
" else:\n",
" # Note: PettingZoo expects integer actions # TODO: change chess to cast actions to type int?\n",
" act = int(\n",
" model.predict(\n",
" observation, action_masks=action_mask, deterministic=True\n",
" )[0]\n",
" )\n",
" env.step(act)\n",
" env.close()\n",
"\n",
" # Avoid dividing by zero\n",
" if sum(scores.values()) == 0:\n",
" winrate = 0\n",
" else:\n",
" winrate = scores[env.possible_agents[1]] / sum(scores.values())\n",
" print(\"Rewards by round: \", round_rewards)\n",
" print(\"Total rewards (incl. negative rewards): \", total_rewards)\n",
" print(\"Winrate: \", winrate)\n",
" print(\"Final scores: \", scores)\n",
" return round_rewards, total_rewards, winrate, scores\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"env_fn = connect_four_v3\n",
"\n",
"env_kwargs = {}\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Starting training on connect_four_v3.\n",
"Using cpu device\n",
"Wrapping the env with a `Monitor` wrapper\n",
"Wrapping the env in a DummyVecEnv.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/luis/dev/ai-guild/.venv/lib/python3.11/site-packages/gymnasium/core.py:311: UserWarning: \u001b[33mWARN: env.action_masks to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.action_masks` for environment variables or `env.get_wrapper_attr('action_masks')` that will search the reminding wrappers.\u001b[0m\n",
" logger.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 20.5 |\n",
"| ep_rew_mean | -1 |\n",
"| time/ | |\n",
"| fps | 3098 |\n",
"| iterations | 1 |\n",
"| time_elapsed | 0 |\n",
"| total_timesteps | 2048 |\n",
"---------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 21 |\n",
"| ep_rew_mean | -1 |\n",
"| time/ | |\n",
"| fps | 2282 |\n",
"| iterations | 2 |\n",
"| time_elapsed | 1 |\n",
"| total_timesteps | 4096 |\n",
"| train/ | |\n",
"| approx_kl | 0.009072401 |\n",
"| clip_fraction | 0.0533 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -1.91 |\n",
"| explained_variance | -1.71 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | -0.0253 |\n",
"| n_updates | 10 |\n",
"| policy_gradient_loss | -0.0187 |\n",
"| value_loss | 0.0943 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 22.4 |\n",
"| ep_rew_mean | -1 |\n",
"| time/ | |\n",
"| fps | 2133 |\n",
"| iterations | 3 |\n",
"| time_elapsed | 2 |\n",
"| total_timesteps | 6144 |\n",
"| train/ | |\n",
"| approx_kl | 0.010119117 |\n",
"| clip_fraction | 0.0917 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -1.89 |\n",
"| explained_variance | 0.0214 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 0.00375 |\n",
"| n_updates | 20 |\n",
"| policy_gradient_loss | -0.0275 |\n",
"| value_loss | 0.0154 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 23.1 |\n",
"| ep_rew_mean | -1 |\n",
"| time/ | |\n",
"| fps | 2048 |\n",
"| iterations | 4 |\n",
"| time_elapsed | 3 |\n",
"| total_timesteps | 8192 |\n",
"| train/ | |\n",
"| approx_kl | 0.011137698 |\n",
"| clip_fraction | 0.0876 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -1.85 |\n",
"| explained_variance | -0.589 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | -0.0483 |\n",
"| n_updates | 30 |\n",
"| policy_gradient_loss | -0.0265 |\n",
"| value_loss | 0.00647 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 23.5 |\n",
"| ep_rew_mean | -1 |\n",
"| time/ | |\n",
"| fps | 2018 |\n",
"| iterations | 5 |\n",
"| time_elapsed | 5 |\n",
"| total_timesteps | 10240 |\n",
"| train/ | |\n",
"| approx_kl | 0.011213706 |\n",
"| clip_fraction | 0.0866 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -1.83 |\n",
"| explained_variance | -0.381 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | -0.0295 |\n",
"| n_updates | 40 |\n",
"| policy_gradient_loss | -0.0276 |\n",
"| value_loss | 0.00417 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 24.7 |\n",
"| ep_rew_mean | -1 |\n",
"| time/ | |\n",
"| fps | 1999 |\n",
"| iterations | 6 |\n",
"| time_elapsed | 6 |\n",
"| total_timesteps | 12288 |\n",
"| train/ | |\n",
"| approx_kl | 0.011083022 |\n",
"| clip_fraction | 0.115 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -1.79 |\n",
"| explained_variance | -0.283 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | -0.0398 |\n",
"| n_updates | 50 |\n",
"| policy_gradient_loss | -0.0308 |\n",
"| value_loss | 0.00354 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 24.9 |\n",
"| ep_rew_mean | -1 |\n",
"| time/ | |\n",
"| fps | 1985 |\n",
"| iterations | 7 |\n",
"| time_elapsed | 7 |\n",
"| total_timesteps | 14336 |\n",
"| train/ | |\n",
"| approx_kl | 0.010872293 |\n",
"| clip_fraction | 0.0815 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -1.74 |\n",
"| explained_variance | -0.0488 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | -0.027 |\n",
"| n_updates | 60 |\n",
"| policy_gradient_loss | -0.029 |\n",
"| value_loss | 0.00258 |\n",
"-----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 24.6 |\n",
"| ep_rew_mean | -1 |\n",
"| time/ | |\n",
"| fps | 1977 |\n",
"| iterations | 8 |\n",
"| time_elapsed | 8 |\n",
"| total_timesteps | 16384 |\n",
"| train/ | |\n",
"| approx_kl | 0.01180191 |\n",
"| clip_fraction | 0.108 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -1.74 |\n",
"| explained_variance | 0.00113 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | -0.0171 |\n",
"| n_updates | 70 |\n",
"| policy_gradient_loss | -0.0333 |\n",
"| value_loss | 0.0026 |\n",
"----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 26.2 |\n",
"| ep_rew_mean | -1 |\n",
"| time/ | |\n",
"| fps | 1969 |\n",
"| iterations | 9 |\n",
"| time_elapsed | 9 |\n",
"| total_timesteps | 18432 |\n",
"| train/ | |\n",
"| approx_kl | 0.010825078 |\n",
"| clip_fraction | 0.0967 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -1.71 |\n",
"| explained_variance | 0.316 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | -0.0578 |\n",
"| n_updates | 80 |\n",
"| policy_gradient_loss | -0.0308 |\n",
"| value_loss | 0.00203 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 25.9 |\n",
"| ep_rew_mean | -1 |\n",
"| time/ | |\n",
"| fps | 1956 |\n",
"| iterations | 10 |\n",
"| time_elapsed | 10 |\n",
"| total_timesteps | 20480 |\n",
"| train/ | |\n",
"| approx_kl | 0.012171678 |\n",
"| clip_fraction | 0.133 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -1.65 |\n",
"| explained_variance | 0.342 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | -0.0483 |\n",
"| n_updates | 90 |\n",
"| policy_gradient_loss | -0.0355 |\n",
"| value_loss | 0.00169 |\n",
"-----------------------------------------\n",
"Model has been saved.\n",
"Finished training on connect_four_v3.\n",
"\n"
]
}
],
"source": [
"# Evaluation/training hyperparameter notes:\n",
"# 10k steps: Winrate: 0.76, loss order of 1e-03\n",
"# 20k steps: Winrate: 0.86, loss order of 1e-04\n",
"# 40k steps: Winrate: 0.86, loss order of 7e-06\n",
"train_action_mask(env_fn, steps=20_480, seed=0, **env_kwargs)\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Starting evaluation vs a random agent. Trained agent will play as player_1.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/luis/dev/ai-guild/.venv/lib/python3.11/site-packages/stable_baselines3/common/save_util.py:166: UserWarning: Could not deserialize object clip_range. Consider using `custom_objects` argument to replace this object.\n",
"Exception: code() argument 13 must be str, not int\n",
" warnings.warn(\n",
"/Users/luis/dev/ai-guild/.venv/lib/python3.11/site-packages/stable_baselines3/common/save_util.py:166: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.\n",
"Exception: code() argument 13 must be str, not int\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Rewards by round: [{'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': -1, 'player_1': 1}, {'player_0': -1, 'player_1': 1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}]\n",
"Total rewards (incl. negative rewards): {'player_0': 62, 'player_1': -62}\n",
"Winrate: 0.19\n",
"Final scores: {'player_0': 81, 'player_1': 19}\n"
]
},
{
"data": {
"text/plain": [
"([{'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': -1, 'player_1': 1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': -1, 'player_1': 1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': -1, 'player_1': 1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': -1, 'player_1': 1},\n",
" {'player_0': -1, 'player_1': 1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': -1, 'player_1': 1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': -1, 'player_1': 1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': -1, 'player_1': 1},\n",
" {'player_0': -1, 'player_1': 1},\n",
" {'player_0': -1, 'player_1': 1},\n",
" {'player_0': -1, 'player_1': 1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': -1, 'player_1': 1},\n",
" {'player_0': -1, 'player_1': 1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': -1, 'player_1': 1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': -1, 'player_1': 1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': -1, 'player_1': 1},\n",
" {'player_0': -1, 'player_1': 1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': -1, 'player_1': 1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': 1, 'player_1': -1},\n",
" {'player_0': -1, 'player_1': 1},\n",
" {'player_0': 1, 'player_1': -1}],\n",
" {'player_0': 62, 'player_1': -62},\n",
" 0.19,\n",
" {'player_0': 81, 'player_1': 19})"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs)\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Starting evaluation vs a random agent. Trained agent will play as player_1.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-10-24 17:51:23.930 Python[49699:10381252] ApplePersistenceIgnoreState: Existing state will not be touched. New state will be written to /var/folders/ds/ybrs4jpx7j16cd0skpzk_xfw0000gn/T/org.python.python.savedState\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Rewards by round: [{'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}]\n",
"Total rewards (incl. negative rewards): {'player_0': 2, 'player_1': -2}\n",
"Winrate: 0.0\n",
"Final scores: {'player_0': 2, 'player_1': 0}\n"
]
},
{
"data": {
"text/plain": [
"([{'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}],\n",
" {'player_0': 2, 'player_1': -2},\n",
" 0.0,\n",
" {'player_0': 2, 'player_1': 0})"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"eval_action_mask(env_fn, num_games=2, render_mode=\"human\", **env_kwargs)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}