reproducible bug in tianshou
import os
import argparse
from copy import deepcopy
from typing import Optional, Tuple

import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv, MultiAgentEnv
from tianshou.policy import BasePolicy, DQNPolicy, RandomPolicy, \
    MultiAgentPolicyManager
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import BasicLogger
from tianshou.utils.net.common import Net


def one_hot(num, size):
    # one-hot encode an action index as a float32 vector of length `size`
    arr = np.zeros(size, dtype=np.float32)
    arr[num] = 1
    return arr
class RepeatOpponentsAction(MultiAgentEnv):
    '''
    Each agent's correct action is its opponent's previous action + 5
    (mod num_actions), but the reward for a correct action is delayed.

    Normally you would expect the reward to be delivered on the same turn
    as the action that earned it (starting action = 0):

    agent | action | reward_1 | reward_2
    ----- | ------ | -------- | --------
      1   |    2   |    0     |    0
      2   |    7   |    0     |    1
      1   |   12   |    1     |    0
      2   |    1   |    0     |    1

    In that setup, the agent learns fine.

    However, it should be perfectly valid to delay the reward by one turn,
    which is what this environment does:

    agent | action | reward_1 | reward_2
    ----- | ------ | -------- | --------
      1   |    2   |    0     |    0
      2   |    7   |    0     |    0
      1   |   12   |    0     |    1
      2   |    1   |    1     |    0

    The agent should still be able to learn this, but Tianshou's
    MultiAgentPolicyManager is unable to learn this environment at all.
    '''
    def __init__(self, num_actions=16):
        self.num_actions = num_actions
        self.action_space = gym.spaces.Discrete(num_actions)
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=(num_actions * 2,), dtype=np.float32)
        self.mask = np.ones(num_actions, dtype=bool)

    def observe(self):
        # returns the opponent's previous 2 actions, one-hot encoded
        return np.concatenate([
            one_hot(self.prev_actions[-3], self.num_actions),
            one_hot(self.prev_actions[-1], self.num_actions)
        ], axis=0)

    def reset(self):
        self.prev_actions = [0, 0, 0, 0]
        self.current_agent = 1
        self.num_steps = 0
        self.prev_reward = 0
        return {
            'agent_id': self.current_agent,
            'obs': self.observe(),
            'mask': self.mask,
        }

    def step(self, action):
        # the correct action is the opponent's last action + 5 (mod num_actions)
        reward = int(action == (self.prev_actions[-1] + 5) % self.num_actions)
        done = self.num_steps >= 9
        self.prev_actions.append(action)
        del self.prev_actions[0]
        self.current_agent = self.current_agent % 2 + 1
        # deliver the reward from the *previous* step, i.e. delayed by one turn
        vec_rew = np.zeros(2, dtype=np.float32)
        vec_rew[self.current_agent - 1] = self.prev_reward
        self.num_steps += 1
        self.prev_reward = reward
        obs = {
            'agent_id': self.current_agent,
            'obs': self.observe(),
            'mask': self.mask,
        }
        return obs, vec_rew, np.array(done), {}
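# Illustrative sanity check, not part of the original repro: a minimal sketch that
# steps RepeatOpponentsAction by hand with the action sequence from the docstring
# tables (2, 7, 12, 1). With the delayed-reward behaviour implemented above, the
# printed reward vectors should be [0, 0], [0, 0], [0, 1], [1, 0]. The helper name
# `_demo_delayed_reward` is hypothetical and is never called by the training script.
def _demo_delayed_reward():
    env = RepeatOpponentsAction()
    env.reset()
    for action in [2, 7, 12, 1]:
        obs, vec_rew, done, info = env.step(action)
        print(f"next to act: agent {obs['agent_id']}, reward vector: {vec_rew}")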
def get_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.9,
                        help='a smaller gamma favors earlier wins')
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=20)
    parser.add_argument('--step-per-epoch', type=int, default=5000)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--hidden-sizes', type=int,
                        nargs='*', default=[128, 128, 128, 128])
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.1)
    parser.add_argument('--board-size', type=int, default=6)
    parser.add_argument('--win-size', type=int, default=4)
    parser.add_argument('--win-rate', type=float, default=0.9,
                        help='the expected winning rate')
    parser.add_argument('--watch', default=False, action='store_true',
                        help='no training, '
                             'watch the play of pre-trained models')
    parser.add_argument('--agent-id', type=int, default=2,
                        help='the learned agent plays as the'
                             ' agent_id-th player. Choices are 1 and 2.')
    parser.add_argument('--resume-path', type=str, default='',
                        help='the path of agent pth file '
                             'for resuming from a pre-trained agent')
    parser.add_argument('--opponent-path', type=str, default='',
                        help='the path of opponent agent pth file '
                             'for resuming from a pre-trained agent')
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    return parser


def get_args() -> argparse.Namespace:
    parser = get_parser()
    return parser.parse_known_args()[0]
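# Example invocations, for reference only: the flags come from get_parser above,
# `repro.py` is a placeholder file name for this script, and the policy path is
# the default location used by save_fn below.
#   python repro.py                  # train a DQN agent against a random opponent
#   python repro.py --watch --resume-path log/tic_tac_toe/dqn/policy.pth
#                                    # watch a previously saved agent, no training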
def get_agents(
    args: argparse.Namespace = get_args(),
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
    optim: Optional[torch.optim.Optimizer] = None,
) -> Tuple[BasePolicy, torch.optim.Optimizer]:
    env = RepeatOpponentsAction()
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    if agent_learn is None:
        # model
        net = Net(args.state_shape, args.action_shape,
                  hidden_sizes=args.hidden_sizes, device=args.device
                  ).to(args.device)
        if optim is None:
            optim = torch.optim.Adam(net.parameters(), lr=args.lr)
        agent_learn = DQNPolicy(
            net, optim, args.gamma, args.n_step,
            target_update_freq=args.target_update_freq)
        if args.resume_path:
            agent_learn.load_state_dict(torch.load(args.resume_path))
    if agent_opponent is None:
        if args.opponent_path:
            agent_opponent = deepcopy(agent_learn)
            agent_opponent.load_state_dict(torch.load(args.opponent_path))
        else:
            agent_opponent = RandomPolicy()
    if args.agent_id == 1:
        agents = [agent_learn, agent_opponent]
    else:
        agents = [agent_opponent, agent_learn]
    policy = MultiAgentPolicyManager(agents)
    return policy, optim
def train_agent(
    args: argparse.Namespace = get_args(),
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
    optim: Optional[torch.optim.Optimizer] = None,
) -> Tuple[dict, BasePolicy]:
    def env_func():
        return RepeatOpponentsAction()

    train_envs = DummyVectorEnv([env_func for _ in range(args.training_num)])
    test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    policy, optim = get_agents(
        args, agent_learn=agent_learn,
        agent_opponent=agent_opponent, optim=optim)
    # collector
    train_collector = Collector(
        policy, train_envs,
        VectorReplayBuffer(args.buffer_size, len(train_envs)),
        exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # policy.set_eps(1)
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # log
    log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = BasicLogger(writer)

    def save_fn(policy):
        if hasattr(args, 'model_save_path'):
            model_save_path = args.model_save_path
        else:
            model_save_path = os.path.join(
                args.logdir, 'tic_tac_toe', 'dqn', 'policy.pth')
        torch.save(
            policy.policies[args.agent_id - 1].state_dict(),
            model_save_path)

    def stop_fn(mean_rewards):
        return mean_rewards >= args.win_rate

    def train_fn(epoch, env_step):
        policy.policies[args.agent_id - 1].set_eps(args.eps_train)

    def test_fn(epoch, env_step):
        policy.policies[args.agent_id - 1].set_eps(args.eps_test)

    def reward_metric(rews):
        return rews[:, args.agent_id - 1]

    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.step_per_collect, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step,
        logger=logger, test_in_train=False, reward_metric=reward_metric)
    return result, policy.policies[args.agent_id - 1]
def watch(
    args: argparse.Namespace = get_args(),
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
) -> None:
    env = RepeatOpponentsAction()
    policy, optim = get_agents(
        args, agent_learn=agent_learn, agent_opponent=agent_opponent)
    policy.eval()
    policy.policies[args.agent_id - 1].set_eps(args.eps_test)
    collector = Collector(policy, env, exploration_noise=True)
    result = collector.collect(n_episode=100, render=False)
    rews, lens = result["rews"], result["lens"]
    print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}")
if __name__ == '__main__':
    import pprint

    args = get_args()
    if args.watch:
        # no training, just watch a (pre-trained) agent play
        watch(args)
    else:
        result, agent = train_agent(args)
        assert result["best_reward"] >= args.win_rate
        pprint.pprint(result)
        # Let's watch its performance!
        watch(args, agent)