Created
October 24, 2023 21:55
-
-
Save luis-c465/443d66404053855b92e35dc41f09540b to your computer and use it in GitHub Desktop.
Action Masked PPO for Connect Four in Jupyter Notebook Form
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Uses Stable-Baselines3 to train agents in the Connect Four environment using invalid action masking.\n", | |
"\n", | |
"For information about invalid action masking in PettingZoo, see https://pettingzoo.farama.org/api/aec/#action-masking\n", | |
"For more information about invalid action masking in SB3, see https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html\n", | |
"\n", | |
"Author: Elliot (https://github.com/elliottower)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import glob\n", | |
"import os\n", | |
"import time\n", | |
"\n", | |
"import pettingzoo.utils\n", | |
"from pettingzoo.classic import connect_four_v3\n", | |
"from sb3_contrib import MaskablePPO\n", | |
"from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy\n", | |
"from sb3_contrib.common.wrappers import ActionMasker\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class SB3ActionMaskWrapper(pettingzoo.utils.BaseWrapper):\n", | |
"    \"\"\"Wrapper to allow PettingZoo environments to be used with SB3 illegal action masking.\"\"\"\n", | |
"\n", | |
"    # The wrapped env's observations are dicts with \"observation\" and\n", | |
"    # \"action_mask\" keys: observe() exposes only the raw observation to SB3,\n", | |
"    # while action_mask() surfaces the mask for the ActionMasker hook.\n", | |
"\n", | |
"    def reset(self, seed=None, options=None):\n", | |
"        \"\"\"Gymnasium-like reset function which assigns obs/action spaces to be the same for each agent.\n", | |
"\n", | |
"        This is required as SB3 is designed for single-agent RL and doesn't expect obs/action spaces to be functions\n", | |
"        \"\"\"\n", | |
"        super().reset(seed, options)\n", | |
"\n", | |
"        # Strip the action mask out from the observation space\n", | |
"        self.observation_space = super().observation_space(self.possible_agents[0])[\n", | |
"            \"observation\"\n", | |
"        ]\n", | |
"        self.action_space = super().action_space(self.possible_agents[0])\n", | |
"\n", | |
"        # Return initial observation, info (PettingZoo AEC envs do not by default)\n", | |
"        return self.observe(self.agent_selection), {}\n", | |
"\n", | |
"    def step(self, action):\n", | |
"        \"\"\"Gymnasium-like step function, returning observation, reward, termination, truncation, info.\"\"\"\n", | |
"        # Forward the action, then report the wrapped env's latest 5-tuple via last().\n", | |
"        super().step(action)\n", | |
"        return super().last()\n", | |
"\n", | |
"    def observe(self, agent):\n", | |
"        \"\"\"Return only raw observation, removing action mask.\"\"\"\n", | |
"        return super().observe(agent)[\"observation\"]\n", | |
"\n", | |
"    def action_mask(self):\n", | |
"        \"\"\"Separate function used in order to access the action mask.\"\"\"\n", | |
"        return super().observe(self.agent_selection)[\"action_mask\"]\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def mask_fn(env):\n", | |
"    \"\"\"Return the current action mask of `env` (hook consumed by sb3_contrib's ActionMasker).\"\"\"\n", | |
"    # Do whatever you'd like in this function to return the action mask\n", | |
"    # for the current env. In this example, we assume the env has a\n", | |
"    # helpful method we can rely on.\n", | |
"    return env.action_mask()\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def train_action_mask(env_fn, steps=10_000, seed=0, **env_kwargs):\n", | |
"    \"\"\"Train a single model to play as each agent in a zero-sum game environment using invalid action masking.\n", | |
"\n", | |
"    Args:\n", | |
"        env_fn: PettingZoo module whose .env(**env_kwargs) constructs the AEC environment.\n", | |
"        steps: total PPO timesteps to train for.\n", | |
"        seed: seed applied to the env reset and the model's RNG.\n", | |
"        **env_kwargs: forwarded to env_fn.env().\n", | |
"    \"\"\"\n", | |
"    env = env_fn.env(**env_kwargs)\n", | |
"\n", | |
"    print(f\"Starting training on {str(env.metadata['name'])}.\")\n", | |
"\n", | |
"    # Custom wrapper to convert PettingZoo envs to work with SB3 action masking\n", | |
"    env = SB3ActionMaskWrapper(env)\n", | |
"\n", | |
"    env.reset(seed=seed)  # Must call reset() in order to re-define the spaces\n", | |
"\n", | |
"    env = ActionMasker(env, mask_fn)  # Wrap to enable masking (SB3 function)\n", | |
"    # MaskablePPO behaves the same as SB3's PPO unless the env is wrapped\n", | |
"    # with ActionMasker. If the wrapper is detected, the masks are automatically\n", | |
"    # retrieved and used when learning. Note that MaskablePPO does not accept\n", | |
"    # a new action_mask_fn kwarg, as it did in an earlier draft.\n", | |
"    model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1)\n", | |
"    model.set_random_seed(seed)\n", | |
"    model.learn(total_timesteps=steps)\n", | |
"\n", | |
"    # Timestamped filename; eval_action_mask() globs for '{name}*.zip' and loads the newest.\n", | |
"    model.save(f\"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}\")\n", | |
"\n", | |
"    print(\"Model has been saved.\")\n", | |
"\n", | |
"    print(f\"Finished training on {str(env.unwrapped.metadata['name'])}.\\n\")\n", | |
"\n", | |
"    env.close()\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs):\n", | |
"    \"\"\"Evaluate the most recently saved policy against a random agent.\n", | |
"\n", | |
"    The random agent plays as the first agent in env.possible_agents; the\n", | |
"    trained model plays as the second. Returns a tuple of\n", | |
"    (round_rewards, total_rewards, winrate, scores).\n", | |
"    \"\"\"\n", | |
"    env = env_fn.env(render_mode=render_mode, **env_kwargs)\n", | |
"\n", | |
"    print(\n", | |
"        f\"Starting evaluation vs a random agent. Trained agent will play as {env.possible_agents[1]}.\"\n", | |
"    )\n", | |
"\n", | |
"    try:\n", | |
"        # Newest saved model matching this env's name (written by train_action_mask)\n", | |
"        latest_policy = max(\n", | |
"            glob.glob(f\"{env.metadata['name']}*.zip\"), key=os.path.getctime\n", | |
"        )\n", | |
"    except ValueError:\n", | |
"        # max() raises ValueError on an empty glob result\n", | |
"        print(\"Policy not found.\")\n", | |
"        exit(0)\n", | |
"\n", | |
"    model = MaskablePPO.load(latest_policy)\n", | |
"\n", | |
"    scores = {agent: 0 for agent in env.possible_agents}\n", | |
"    total_rewards = {agent: 0 for agent in env.possible_agents}\n", | |
"    round_rewards = []\n", | |
"\n", | |
"    for i in range(num_games):\n", | |
"        env.reset(seed=i)\n", | |
"        env.action_space(env.possible_agents[0]).seed(i)\n", | |
"\n", | |
"        for agent in env.agent_iter():\n", | |
"            obs, reward, termination, truncation, info = env.last()\n", | |
"\n", | |
"            # Separate observation and action mask explicitly by key, rather\n", | |
"            # than relying on dict insertion order via obs.values() unpacking.\n", | |
"            observation = obs[\"observation\"]\n", | |
"            action_mask = obs[\"action_mask\"]\n", | |
"\n", | |
"            if termination or truncation:\n", | |
"                # If there is a winner, keep track, otherwise don't change the scores (tie)\n", | |
"                if (\n", | |
"                    env.rewards[env.possible_agents[0]]\n", | |
"                    != env.rewards[env.possible_agents[1]]\n", | |
"                ):\n", | |
"                    winner = max(env.rewards, key=env.rewards.get)\n", | |
"                    scores[winner] += env.rewards[\n", | |
"                        winner\n", | |
"                    ] # only tracks the largest reward (winner of game)\n", | |
"                # Also track negative and positive rewards (penalizes illegal moves)\n", | |
"                for a in env.possible_agents:\n", | |
"                    total_rewards[a] += env.rewards[a]\n", | |
"                # List of rewards by round, for reference; snapshot the dict so\n", | |
"                # later env mutation cannot alias the stored entry.\n", | |
"                round_rewards.append(dict(env.rewards))\n", | |
"                break\n", | |
"            else:\n", | |
"                if agent == env.possible_agents[0]:\n", | |
"                    # Random agent: sample uniformly among currently legal moves\n", | |
"                    act = env.action_space(agent).sample(action_mask)\n", | |
"                else:\n", | |
"                    # Note: PettingZoo expects integer actions # TODO: change chess to cast actions to type int?\n", | |
"                    act = int(\n", | |
"                        model.predict(\n", | |
"                            observation, action_masks=action_mask, deterministic=True\n", | |
"                        )[0]\n", | |
"                    )\n", | |
"            env.step(act)\n", | |
"    env.close()\n", | |
"\n", | |
"    # Avoid dividing by zero\n", | |
"    if sum(scores.values()) == 0:\n", | |
"        winrate = 0\n", | |
"    else:\n", | |
"        winrate = scores[env.possible_agents[1]] / sum(scores.values())\n", | |
"    print(\"Rewards by round: \", round_rewards)\n", | |
"    print(\"Total rewards (incl. negative rewards): \", total_rewards)\n", | |
"    print(\"Winrate: \", winrate)\n", | |
"    print(\"Final scores: \", scores)\n", | |
"    return round_rewards, total_rewards, winrate, scores\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Environment factory and kwargs shared by the training and evaluation cells below\n", | |
"env_fn = connect_four_v3\n", | |
"\n", | |
"env_kwargs = {}\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Starting training on connect_four_v3.\n", | |
"Using cpu device\n", | |
"Wrapping the env with a `Monitor` wrapper\n", | |
"Wrapping the env in a DummyVecEnv.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/luis/dev/ai-guild/.venv/lib/python3.11/site-packages/gymnasium/core.py:311: UserWarning: \u001b[33mWARN: env.action_masks to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.action_masks` for environment variables or `env.get_wrapper_attr('action_masks')` that will search the reminding wrappers.\u001b[0m\n", | |
" logger.warn(\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"---------------------------------\n", | |
"| rollout/ | |\n", | |
"| ep_len_mean | 20.5 |\n", | |
"| ep_rew_mean | -1 |\n", | |
"| time/ | |\n", | |
"| fps | 3098 |\n", | |
"| iterations | 1 |\n", | |
"| time_elapsed | 0 |\n", | |
"| total_timesteps | 2048 |\n", | |
"---------------------------------\n", | |
"-----------------------------------------\n", | |
"| rollout/ | |\n", | |
"| ep_len_mean | 21 |\n", | |
"| ep_rew_mean | -1 |\n", | |
"| time/ | |\n", | |
"| fps | 2282 |\n", | |
"| iterations | 2 |\n", | |
"| time_elapsed | 1 |\n", | |
"| total_timesteps | 4096 |\n", | |
"| train/ | |\n", | |
"| approx_kl | 0.009072401 |\n", | |
"| clip_fraction | 0.0533 |\n", | |
"| clip_range | 0.2 |\n", | |
"| entropy_loss | -1.91 |\n", | |
"| explained_variance | -1.71 |\n", | |
"| learning_rate | 0.0003 |\n", | |
"| loss | -0.0253 |\n", | |
"| n_updates | 10 |\n", | |
"| policy_gradient_loss | -0.0187 |\n", | |
"| value_loss | 0.0943 |\n", | |
"-----------------------------------------\n", | |
"-----------------------------------------\n", | |
"| rollout/ | |\n", | |
"| ep_len_mean | 22.4 |\n", | |
"| ep_rew_mean | -1 |\n", | |
"| time/ | |\n", | |
"| fps | 2133 |\n", | |
"| iterations | 3 |\n", | |
"| time_elapsed | 2 |\n", | |
"| total_timesteps | 6144 |\n", | |
"| train/ | |\n", | |
"| approx_kl | 0.010119117 |\n", | |
"| clip_fraction | 0.0917 |\n", | |
"| clip_range | 0.2 |\n", | |
"| entropy_loss | -1.89 |\n", | |
"| explained_variance | 0.0214 |\n", | |
"| learning_rate | 0.0003 |\n", | |
"| loss | 0.00375 |\n", | |
"| n_updates | 20 |\n", | |
"| policy_gradient_loss | -0.0275 |\n", | |
"| value_loss | 0.0154 |\n", | |
"-----------------------------------------\n", | |
"-----------------------------------------\n", | |
"| rollout/ | |\n", | |
"| ep_len_mean | 23.1 |\n", | |
"| ep_rew_mean | -1 |\n", | |
"| time/ | |\n", | |
"| fps | 2048 |\n", | |
"| iterations | 4 |\n", | |
"| time_elapsed | 3 |\n", | |
"| total_timesteps | 8192 |\n", | |
"| train/ | |\n", | |
"| approx_kl | 0.011137698 |\n", | |
"| clip_fraction | 0.0876 |\n", | |
"| clip_range | 0.2 |\n", | |
"| entropy_loss | -1.85 |\n", | |
"| explained_variance | -0.589 |\n", | |
"| learning_rate | 0.0003 |\n", | |
"| loss | -0.0483 |\n", | |
"| n_updates | 30 |\n", | |
"| policy_gradient_loss | -0.0265 |\n", | |
"| value_loss | 0.00647 |\n", | |
"-----------------------------------------\n", | |
"-----------------------------------------\n", | |
"| rollout/ | |\n", | |
"| ep_len_mean | 23.5 |\n", | |
"| ep_rew_mean | -1 |\n", | |
"| time/ | |\n", | |
"| fps | 2018 |\n", | |
"| iterations | 5 |\n", | |
"| time_elapsed | 5 |\n", | |
"| total_timesteps | 10240 |\n", | |
"| train/ | |\n", | |
"| approx_kl | 0.011213706 |\n", | |
"| clip_fraction | 0.0866 |\n", | |
"| clip_range | 0.2 |\n", | |
"| entropy_loss | -1.83 |\n", | |
"| explained_variance | -0.381 |\n", | |
"| learning_rate | 0.0003 |\n", | |
"| loss | -0.0295 |\n", | |
"| n_updates | 40 |\n", | |
"| policy_gradient_loss | -0.0276 |\n", | |
"| value_loss | 0.00417 |\n", | |
"-----------------------------------------\n", | |
"-----------------------------------------\n", | |
"| rollout/ | |\n", | |
"| ep_len_mean | 24.7 |\n", | |
"| ep_rew_mean | -1 |\n", | |
"| time/ | |\n", | |
"| fps | 1999 |\n", | |
"| iterations | 6 |\n", | |
"| time_elapsed | 6 |\n", | |
"| total_timesteps | 12288 |\n", | |
"| train/ | |\n", | |
"| approx_kl | 0.011083022 |\n", | |
"| clip_fraction | 0.115 |\n", | |
"| clip_range | 0.2 |\n", | |
"| entropy_loss | -1.79 |\n", | |
"| explained_variance | -0.283 |\n", | |
"| learning_rate | 0.0003 |\n", | |
"| loss | -0.0398 |\n", | |
"| n_updates | 50 |\n", | |
"| policy_gradient_loss | -0.0308 |\n", | |
"| value_loss | 0.00354 |\n", | |
"-----------------------------------------\n", | |
"-----------------------------------------\n", | |
"| rollout/ | |\n", | |
"| ep_len_mean | 24.9 |\n", | |
"| ep_rew_mean | -1 |\n", | |
"| time/ | |\n", | |
"| fps | 1985 |\n", | |
"| iterations | 7 |\n", | |
"| time_elapsed | 7 |\n", | |
"| total_timesteps | 14336 |\n", | |
"| train/ | |\n", | |
"| approx_kl | 0.010872293 |\n", | |
"| clip_fraction | 0.0815 |\n", | |
"| clip_range | 0.2 |\n", | |
"| entropy_loss | -1.74 |\n", | |
"| explained_variance | -0.0488 |\n", | |
"| learning_rate | 0.0003 |\n", | |
"| loss | -0.027 |\n", | |
"| n_updates | 60 |\n", | |
"| policy_gradient_loss | -0.029 |\n", | |
"| value_loss | 0.00258 |\n", | |
"-----------------------------------------\n", | |
"----------------------------------------\n", | |
"| rollout/ | |\n", | |
"| ep_len_mean | 24.6 |\n", | |
"| ep_rew_mean | -1 |\n", | |
"| time/ | |\n", | |
"| fps | 1977 |\n", | |
"| iterations | 8 |\n", | |
"| time_elapsed | 8 |\n", | |
"| total_timesteps | 16384 |\n", | |
"| train/ | |\n", | |
"| approx_kl | 0.01180191 |\n", | |
"| clip_fraction | 0.108 |\n", | |
"| clip_range | 0.2 |\n", | |
"| entropy_loss | -1.74 |\n", | |
"| explained_variance | 0.00113 |\n", | |
"| learning_rate | 0.0003 |\n", | |
"| loss | -0.0171 |\n", | |
"| n_updates | 70 |\n", | |
"| policy_gradient_loss | -0.0333 |\n", | |
"| value_loss | 0.0026 |\n", | |
"----------------------------------------\n", | |
"-----------------------------------------\n", | |
"| rollout/ | |\n", | |
"| ep_len_mean | 26.2 |\n", | |
"| ep_rew_mean | -1 |\n", | |
"| time/ | |\n", | |
"| fps | 1969 |\n", | |
"| iterations | 9 |\n", | |
"| time_elapsed | 9 |\n", | |
"| total_timesteps | 18432 |\n", | |
"| train/ | |\n", | |
"| approx_kl | 0.010825078 |\n", | |
"| clip_fraction | 0.0967 |\n", | |
"| clip_range | 0.2 |\n", | |
"| entropy_loss | -1.71 |\n", | |
"| explained_variance | 0.316 |\n", | |
"| learning_rate | 0.0003 |\n", | |
"| loss | -0.0578 |\n", | |
"| n_updates | 80 |\n", | |
"| policy_gradient_loss | -0.0308 |\n", | |
"| value_loss | 0.00203 |\n", | |
"-----------------------------------------\n", | |
"-----------------------------------------\n", | |
"| rollout/ | |\n", | |
"| ep_len_mean | 25.9 |\n", | |
"| ep_rew_mean | -1 |\n", | |
"| time/ | |\n", | |
"| fps | 1956 |\n", | |
"| iterations | 10 |\n", | |
"| time_elapsed | 10 |\n", | |
"| total_timesteps | 20480 |\n", | |
"| train/ | |\n", | |
"| approx_kl | 0.012171678 |\n", | |
"| clip_fraction | 0.133 |\n", | |
"| clip_range | 0.2 |\n", | |
"| entropy_loss | -1.65 |\n", | |
"| explained_variance | 0.342 |\n", | |
"| learning_rate | 0.0003 |\n", | |
"| loss | -0.0483 |\n", | |
"| n_updates | 90 |\n", | |
"| policy_gradient_loss | -0.0355 |\n", | |
"| value_loss | 0.00169 |\n", | |
"-----------------------------------------\n", | |
"Model has been saved.\n", | |
"Finished training on connect_four_v3.\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"# Evaluation/training hyperparameter notes:\n", | |
"# 10k steps: Winrate: 0.76, loss order of 1e-03\n", | |
"# 20k steps: Winrate: 0.86, loss order of 1e-04\n", | |
"# 40k steps: Winrate: 0.86, loss order of 7e-06\n", | |
"# Saves a timestamped .zip that the eval cells below pick up via glob.\n", | |
"train_action_mask(env_fn, steps=20_480, seed=0, **env_kwargs)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Starting evaluation vs a random agent. Trained agent will play as player_1.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/luis/dev/ai-guild/.venv/lib/python3.11/site-packages/stable_baselines3/common/save_util.py:166: UserWarning: Could not deserialize object clip_range. Consider using `custom_objects` argument to replace this object.\n", | |
"Exception: code() argument 13 must be str, not int\n", | |
" warnings.warn(\n", | |
"/Users/luis/dev/ai-guild/.venv/lib/python3.11/site-packages/stable_baselines3/common/save_util.py:166: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.\n", | |
"Exception: code() argument 13 must be str, not int\n", | |
" warnings.warn(\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Rewards by round: [{'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': -1, 'player_1': 1}, {'player_0': -1, 'player_1': 1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, 
{'player_0': -1, 'player_1': 1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}]\n", | |
"Total rewards (incl. negative rewards): {'player_0': 62, 'player_1': -62}\n", | |
"Winrate: 0.19\n", | |
"Final scores: {'player_0': 81, 'player_1': 19}\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"([{'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': -1, 'player_1': 1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': -1, 'player_1': 1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': -1, 'player_1': 1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': -1, 'player_1': 1},\n", | |
" {'player_0': -1, 'player_1': 1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': -1, 'player_1': 1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': -1, 'player_1': 1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': -1, 'player_1': 1},\n", | |
" {'player_0': -1, 'player_1': 1},\n", | |
" {'player_0': -1, 'player_1': 1},\n", | |
" {'player_0': -1, 'player_1': 1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': -1, 'player_1': 1},\n", | |
" {'player_0': -1, 'player_1': 1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': -1, 'player_1': 1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': -1, 'player_1': 1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': -1, 'player_1': 1},\n", | |
" {'player_0': -1, 'player_1': 1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': -1, 'player_1': 1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': 1, 'player_1': -1},\n", | |
" {'player_0': -1, 'player_1': 1},\n", | |
" {'player_0': 1, 'player_1': -1}],\n", | |
" {'player_0': 62, 'player_1': -62},\n", | |
" 0.19,\n", | |
" {'player_0': 81, 'player_1': 19})" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Evaluate the latest saved model over 100 games (no rendering)\n", | |
"eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Starting evaluation vs a random agent. Trained agent will play as player_1.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"2023-10-24 17:51:23.930 Python[49699:10381252] ApplePersistenceIgnoreState: Existing state will not be touched. New state will be written to /var/folders/ds/ybrs4jpx7j16cd0skpzk_xfw0000gn/T/org.python.python.savedState\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Rewards by round: [{'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}]\n", | |
"Total rewards (incl. negative rewards): {'player_0': 2, 'player_1': -2}\n", | |
"Winrate: 0.0\n", | |
"Final scores: {'player_0': 2, 'player_1': 0}\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"([{'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}],\n", | |
" {'player_0': 2, 'player_1': -2},\n", | |
" 0.0,\n", | |
" {'player_0': 2, 'player_1': 0})" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Watch two games against the random agent with human rendering\n", | |
"eval_action_mask(env_fn, num_games=2, render_mode=\"human\", **env_kwargs)\n" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": ".venv", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment