reproducible bug in tianshou
import os
import torch
import argparse
import numpy as np
from copy import deepcopy
from typing import Optional, Tuple
from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv, MultiAgentEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import BasePolicy, DQNPolicy, RandomPolicy, \
    MultiAgentPolicyManager

import gym

def one_hot(num, size):
    arr = np.zeros(size, dtype=np.float32)
    arr[num] = 1
    return arr

class RepeatOpponentsAction(MultiAgentEnv):
    '''
    Each agent's action should be its opponent's previous action + 5
    (mod num_actions), but the reward for that action is delayed.

    How does this play out? With the start action being 0, you would
    usually expect the reward to arrive on the same step as the action
    that earned it:

    agent | action | reward_1 | reward_2
    ----- | ------ | -------- | --------
      1   |    2   |    0     |    0
      2   |    7   |    0     |    1
      1   |   12   |    1     |    0
      2   |    1   |    0     |    1

    In that setup, the agent learns fine.
    However, it should be perfectly valid to delay the reward by one turn:

    agent | action | reward_1 | reward_2
    ----- | ------ | -------- | --------
      1   |    2   |    0     |    0
      2   |    7   |    0     |    0
      1   |   12   |    0     |    1
      2   |    1   |    1     |    0

    and the agent should still learn. However, Tianshou's
    MultiAgentPolicyManager is unable to learn this environment at all.
    '''
    def __init__(self, num_actions=16):
        self.num_actions = num_actions
        self.action_space = gym.spaces.Discrete(num_actions)
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=(num_actions * 2,), dtype=np.float32)
        self.mask = np.ones(num_actions, dtype=bool)

    def observe(self):
        # returns the opponent's previous two actions, one-hot encoded
        return np.concatenate([
            one_hot(self.prev_actions[-3], self.num_actions),
            one_hot(self.prev_actions[-1], self.num_actions)
        ], axis=0)

    def reset(self):
        self.prev_actions = [0, 0, 0, 0]
        self.current_agent = 1
        self.num_steps = 0
        self.prev_reward = 0
        return {
            'agent_id': self.current_agent,
            'obs': self.observe(),
            'mask': self.mask,
        }

    def step(self, action):
        # reward 1 if the action matches the opponent's last action + 5
        reward = int(action == (self.prev_actions[-1] + 5) % self.num_actions)
        done = self.num_steps >= 9
        self.prev_actions.append(action)
        del self.prev_actions[0]
        # switch to the other agent
        self.current_agent = self.current_agent % 2 + 1
        # deliver the reward earned on the *previous* step to the agent
        # that acted then (this is the one-turn delay described above)
        vec_rew = np.zeros(2, dtype=np.float32)
        vec_rew[self.current_agent - 1] = self.prev_reward
        self.num_steps += 1
        self.prev_reward = reward
        obs = {
            'agent_id': self.current_agent,
            'obs': self.observe(),
            'mask': self.mask,
        }
        return obs, vec_rew, np.array(done), {}
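
# Optional sanity check (a minimal sketch, not used by the training code):
# replay the action sequence 2, 7, 12, 1 from the class docstring by hand and
# print each step's reward vector, which should match the delayed-reward
# table above. The helper name `_check_delayed_rewards` is arbitrary.
def _check_delayed_rewards():
    env = RepeatOpponentsAction()
    obs = env.reset()
    for act in [2, 7, 12, 1]:
        acting_agent = obs['agent_id']
        obs, rew, done, info = env.step(act)
        # expected reward vectors: [0,0], [0,0], [0,1], [1,0]
        print(f"agent {acting_agent} played {act}, reward vector {rew}")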

def get_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.9,
                        help='a smaller gamma favors earlier win')
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=20)
    parser.add_argument('--step-per-epoch', type=int, default=5000)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--hidden-sizes', type=int,
                        nargs='*', default=[128, 128, 128, 128])
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.1)
    parser.add_argument('--board-size', type=int, default=6)
    parser.add_argument('--win-size', type=int, default=4)
    parser.add_argument('--win-rate', type=float, default=0.9,
                        help='the expected winning rate')
    parser.add_argument('--watch', default=False, action='store_true',
                        help='no training, '
                             'watch the play of pre-trained models')
    parser.add_argument('--agent-id', type=int, default=2,
                        help='the learned agent plays as the'
                             ' agent_id-th player. Choices are 1 and 2.')
    parser.add_argument('--resume-path', type=str, default='',
                        help='the path of agent pth file '
                             'for resuming from a pre-trained agent')
    parser.add_argument('--opponent-path', type=str, default='',
                        help='the path of opponent agent pth file '
                             'for resuming from a pre-trained agent')
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    return parser


def get_args() -> argparse.Namespace:
    parser = get_parser()
    return parser.parse_known_args()[0]

def get_agents(
    args: argparse.Namespace = get_args(),
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
    optim: Optional[torch.optim.Optimizer] = None,
) -> Tuple[BasePolicy, torch.optim.Optimizer]:
    env = RepeatOpponentsAction()
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    if agent_learn is None:
        # model
        net = Net(args.state_shape, args.action_shape,
                  hidden_sizes=args.hidden_sizes, device=args.device
                  ).to(args.device)
        if optim is None:
            optim = torch.optim.Adam(net.parameters(), lr=args.lr)
        agent_learn = DQNPolicy(
            net, optim, args.gamma, args.n_step,
            target_update_freq=args.target_update_freq)
        if args.resume_path:
            agent_learn.load_state_dict(torch.load(args.resume_path))

    if agent_opponent is None:
        if args.opponent_path:
            agent_opponent = deepcopy(agent_learn)
            agent_opponent.load_state_dict(torch.load(args.opponent_path))
        else:
            agent_opponent = RandomPolicy()

    if args.agent_id == 1:
        agents = [agent_learn, agent_opponent]
    else:
        agents = [agent_opponent, agent_learn]
    policy = MultiAgentPolicyManager(agents)
    return policy, optim

def train_agent(
    args: argparse.Namespace = get_args(),
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
    optim: Optional[torch.optim.Optimizer] = None,
) -> Tuple[dict, BasePolicy]:
    def env_func():
        return RepeatOpponentsAction()
    train_envs = DummyVectorEnv([env_func for _ in range(args.training_num)])
    test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)

    policy, optim = get_agents(
        args, agent_learn=agent_learn,
        agent_opponent=agent_opponent, optim=optim)

    # collector
    train_collector = Collector(
        policy, train_envs,
        VectorReplayBuffer(args.buffer_size, len(train_envs)),
        exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # policy.set_eps(1)
    train_collector.collect(n_step=args.batch_size * args.training_num)

    # log
    log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = BasicLogger(writer)

    def save_fn(policy):
        if hasattr(args, 'model_save_path'):
            model_save_path = args.model_save_path
        else:
            model_save_path = os.path.join(
                args.logdir, 'tic_tac_toe', 'dqn', 'policy.pth')
        torch.save(
            policy.policies[args.agent_id - 1].state_dict(),
            model_save_path)

    def stop_fn(mean_rewards):
        return mean_rewards >= args.win_rate

    def train_fn(epoch, env_step):
        policy.policies[args.agent_id - 1].set_eps(args.eps_train)

    def test_fn(epoch, env_step):
        policy.policies[args.agent_id - 1].set_eps(args.eps_test)

    def reward_metric(rews):
        return rews[:, args.agent_id - 1]

    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.step_per_collect, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step,
        logger=logger, test_in_train=False, reward_metric=reward_metric)

    return result, policy.policies[args.agent_id - 1]

def watch(
    args: argparse.Namespace = get_args(),
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
) -> None:
    env = RepeatOpponentsAction()
    policy, optim = get_agents(
        args, agent_learn=agent_learn, agent_opponent=agent_opponent)
    policy.eval()
    policy.policies[args.agent_id - 1].set_eps(args.eps_test)
    collector = Collector(policy, env, exploration_noise=True)
    result = collector.collect(n_episode=100, render=False)
    rews, lens = result["rews"], result["lens"]
    print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}")

import pprint

if __name__ == '__main__':
    args = get_args()
    if args.watch:
        # no training, just watch a (pre-trained) agent play
        watch(args)
    else:
        result, agent = train_agent(args)
        assert result["best_reward"] >= args.win_rate
        pprint.pprint(result)
        # Let's watch its performance!
        watch(args, agent)
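
# Running this file directly trains the DQN agent against a random opponent.
# Per the class docstring, MultiAgentPolicyManager does not learn the
# delayed-reward environment, so best_reward is expected to stay below
# --win-rate and the assertion above fails; that failure is the bug being
# reproduced.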