Skip to content

Instantly share code, notes, and snippets.

@rdednl
Last active September 7, 2023 15:01
Show Gist options
  • Save rdednl/64e8fb4b7d4a0e4d047f91188cbfaaed to your computer and use it in GitHub Desktop.
Save rdednl/64e8fb4b7d4a0e4d047f91188cbfaaed to your computer and use it in GitHub Desktop.
batch norm is bad (td3/sac)
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "registered-packet",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import gym\n",
"from gym import spaces"
]
},
{
"cell_type": "markdown",
"id": "23febdcb",
"metadata": {},
"source": [
"# This simple environment has a single constant state, and the reward is just the action taken. Quite simple to learn, right?"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "completed-private",
"metadata": {},
"outputs": [],
"source": [
"class RandomEnv(gym.Env):\n",
" def __init__(self, reward_func=lambda a: a):\n",
" super(RandomEnv, self).__init__()\n",
" self.num_steps = 0\n",
" self.max_steps = 30\n",
" self.reward_func = reward_func\n",
" \n",
" self.action_space = spaces.Box(low=-1, high=+1, shape=(1,))\n",
" self.observation_space = spaces.Box(low=-1, high=+1, shape=(3,))\n",
"\n",
" def reset(self):\n",
" self.num_steps = 0\n",
" return self.get_state()\n",
"\n",
" def step(self, action):\n",
" self.num_steps += 1\n",
" action = action[0]\n",
" return self.get_state(), self.reward_func(action), self.get_terminal(), {}\n",
"\n",
" def get_state(self):\n",
" return np.array([0.,0.,0.])\n",
"# return self.observation_space.sample()\n",
"\n",
" def get_terminal(self):\n",
" return True if self.num_steps >= self.max_steps else False"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8f987136",
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import collections\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.distributions import Normal\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
]
},
{
"cell_type": "markdown",
"id": "6d7cac66",
"metadata": {},
"source": [
"# Let's define a replay buffer and the TD3 and SAC training code. Each algorithm gets two architectures: one with batch norm and one without it."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b83074b5",
"metadata": {},
"outputs": [],
"source": [
"# REPLAY BUFFER\n",
"Transition = collections.namedtuple('Transition', ('state', 'action', 'next_state', 'reward', 'done'))\n",
"\n",
"class ReplayBuffer():\n",
" def __init__(self, max_size=10000):\n",
" self.data = collections.deque([], maxlen=max_size)\n",
"\n",
" def push(self, state, action, next_state, reward, done):\n",
" self.data.append(Transition(state, action, next_state, reward, done))\n",
"\n",
" def sample(self, batch_size):\n",
" transitions = random.sample(self.data, batch_size)\n",
"\n",
" batch = Transition(*zip(*transitions))\n",
" state = torch.tensor(batch.state, device=device, dtype=torch.float32)\n",
" next_state = torch.tensor(batch.next_state, device=device, dtype=torch.float32)\n",
" action = torch.tensor(batch.action, device=device)\n",
" reward = torch.tensor(batch.reward, device=device)\n",
" done = torch.tensor(batch.done, device=device)\n",
" return state, action, next_state, reward, done\n",
"\n",
" def __len__(self):\n",
" return len(self.data)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "355ef3c4",
"metadata": {},
"outputs": [],
"source": [
"class TrainerBase:\n",
" def __init__(self,\n",
" policy_class,\n",
" batch_size=256,\n",
" random_exploration=10,\n",
" gamma=0.99,\n",
" target_update_factor=0.005,\n",
" target_update_interval=1,\n",
" test_episodes=5,\n",
" test_interval=10,\n",
" print_interval=-1,\n",
" num_episodes=100,\n",
" verbose=0\n",
" ):\n",
" self.batch_size = batch_size\n",
" self.random_exploration = random_exploration\n",
" self.gamma = gamma\n",
" self.target_update_factor=target_update_factor\n",
" self.target_update_interval=target_update_interval\n",
" self.test_episodes=test_episodes\n",
" self.test_interval=test_interval\n",
" self.print_interval=print_interval\n",
" self.num_episodes=num_episodes\n",
" self.verbose = verbose\n",
" \n",
" self.env = RandomEnv()\n",
" self.test_env = RandomEnv()\n",
"\n",
" self.model = policy_class(self.env.observation_space.shape[0], self.env.action_space.shape[0])\n",
" self.target = policy_class(self.env.observation_space.shape[0], self.env.action_space.shape[0])\n",
" self.target.load_state_dict(self.model.state_dict())\n",
" self.target.eval()\n",
"\n",
" self.model.to(device)\n",
" self.target.to(device)\n",
"\n",
" self.memory = ReplayBuffer()\n",
" \n",
" self.training_steps = 0\n",
" self.episode_num = 0\n",
" \n",
" self.actor_optimizer = torch.optim.Adam(self.model.actor.parameters(), lr=3e-4)\n",
" self.critic_optimizer_1 = torch.optim.Adam(self.model.critic_1.parameters(), lr=3e-4)\n",
" self.critic_optimizer_2 = torch.optim.Adam(self.model.critic_2.parameters(), lr=3e-4)\n",
" \n",
" self.run_stats = []\n",
"\n",
" def select_action(self, state, test_mode=False):\n",
" raise NotImplementedError()\n",
"\n",
" def optimize_model(self):\n",
" raise NotImplementedError()\n",
"\n",
" def test(self):\n",
" test_rewards = []\n",
" for i in range(self.test_episodes):\n",
" state = self.test_env.reset()\n",
" episode_total_reward = 0\n",
" done = False\n",
" while not done:\n",
" with torch.no_grad():\n",
" action = self.select_action(state, test_mode=True)\n",
" next_state, reward, done, info = self.test_env.step(action)\n",
" episode_total_reward += reward\n",
" state = next_state\n",
" test_rewards.append(episode_total_reward)\n",
" score = sum(test_rewards)/self.test_episodes\n",
" self.run_stats.append((self.training_steps, score))\n",
" print(f\"[TESTING] [Total Steps: {self.training_steps}] [Average Reward {score:.3f}]\")\n",
" \n",
" def train(self):\n",
" self.test()\n",
" for self.episode_num in range(1, self.num_episodes+1):\n",
" episode_total_reward = 0\n",
" state = self.env.reset()\n",
" done = False\n",
" while not done:\n",
" action = self.select_action(state)\n",
" next_state, reward, done, info = self.env.step(action)\n",
" self.training_steps += 1\n",
" episode_total_reward += reward\n",
" self.memory.push(state, action, next_state, reward, int(done))\n",
" state = next_state\n",
"\n",
" if self.training_steps > self.random_exploration:\n",
" self.optimize_model()\n",
" \n",
" if self.verbose and self.episode_num % self.print_interval == 0:\n",
" print(f'[Episode {self.episode_num}][Reward {episode_total_reward:.3f}]')\n",
" \n",
" if self.episode_num % self.test_interval == 0:\n",
" self.test()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "85acc97a",
"metadata": {},
"outputs": [],
"source": [
"class SAC(TrainerBase):\n",
" def __init__(self,\n",
" policy_class,\n",
" batch_size=256,\n",
" random_exploration=10,\n",
" gamma=0.99,\n",
" learnable_temperature=True,\n",
" init_temperature=0.1,\n",
" target_update_factor=0.005,\n",
" target_update_interval=1,\n",
" test_episodes=5,\n",
" test_interval=10,\n",
" print_interval=1,\n",
" num_episodes=100,\n",
" verbose=0\n",
" ):\n",
" \n",
" super().__init__(\n",
" policy_class,\n",
" batch_size,\n",
" random_exploration,\n",
" gamma,\n",
" target_update_factor,\n",
" target_update_interval,\n",
" test_episodes,\n",
" test_interval,\n",
" print_interval,\n",
" num_episodes,\n",
" verbose\n",
" )\n",
" \n",
" if learnable_temperature:\n",
" self.log_alpha = torch.log(torch.tensor(init_temperature, device=device))\n",
" self.log_alpha.requires_grad = True\n",
" self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=3e-4)\n",
" else:\n",
" self.log_alpha = torch.tensor(init_temperature, device=device)\n",
" self.log_alpha_optimizer = None\n",
"\n",
" self.target_entropy = -self.env.action_space.shape[0]\n",
"\n",
" def select_action(self, state, test_mode=False):\n",
" if self.training_steps < self.random_exploration:\n",
" return self.env.action_space.sample()\n",
" \n",
" state = torch.tensor([state], device=device, dtype=torch.float32)\n",
" with torch.no_grad():\n",
" action = self.model.select_action(state, test_mode=test_mode)\n",
"\n",
" return torch.clamp(action, -1, 1).cpu().numpy()[0]\n",
"\n",
" def optimize_model(self):\n",
" if self.batch_size > len(self.memory):\n",
" batch_size = len(self.memory)\n",
" else:\n",
" batch_size = self.batch_size\n",
"\n",
" state, action, next_state, reward, done = self.memory.sample(batch_size)\n",
"\n",
" q_1 = self.model.critic_1(state, action).squeeze(-1)\n",
" q_2 = self.model.critic_2(state, action).squeeze(-1)\n",
"\n",
" with torch.no_grad():\n",
" next_action, next_log_pi = self.model.sample_action(next_state)\n",
" value_function_1 = self.target.critic_1(next_state, next_action).squeeze(-1)\n",
" value_function_2 = self.target.critic_2(next_state, next_action).squeeze(-1)\n",
" value_function = torch.min(value_function_1, value_function_2) - self.log_alpha.exp().detach() * next_log_pi\n",
" next_value = (1 - done).float() * value_function\n",
" q_target = reward + self.gamma * next_value\n",
"\n",
" critic_loss_1 = (q_1 - q_target).pow(2).mean()\n",
" critic_loss_2 = (q_2 - q_target).pow(2).mean()\n",
"\n",
" self.critic_optimizer_1.zero_grad()\n",
" critic_loss_1.backward()\n",
" self.critic_optimizer_1.step()\n",
"\n",
" self.critic_optimizer_2.zero_grad()\n",
" critic_loss_2.backward()\n",
" self.critic_optimizer_2.step()\n",
"\n",
" action, log_pi = self.model.sample_action(state)\n",
" q_1 = self.model.critic_1(state, action)\n",
" q_2 = self.model.critic_2(state, action)\n",
"\n",
" actor_loss = ((self.log_alpha.exp().detach() * log_pi) - torch.min(q_1, q_2)).mean()\n",
" self.actor_optimizer.zero_grad()\n",
" actor_loss.backward()\n",
" self.actor_optimizer.step()\n",
"\n",
" if self.log_alpha_optimizer:\n",
" self.log_alpha_optimizer.zero_grad()\n",
" alpha_loss = (self.log_alpha * (-log_pi - self.target_entropy).detach()).mean()\n",
" alpha_loss.backward()\n",
" self.log_alpha_optimizer.step()\n",
" \n",
" if self.training_steps % self.target_update_interval == 0:\n",
" with torch.no_grad():\n",
" for model_param, target_param in zip(self.model.parameters(), self.target.parameters()):\n",
" target_param.mul_(1 - self.target_update_factor)\n",
" target_param.add_(self.target_update_factor * model_param.data)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "c77f15d4",
"metadata": {},
"outputs": [],
"source": [
"class TD3(TrainerBase):\n",
" def __init__(self,\n",
" policy_class,\n",
" batch_size=256,\n",
" random_exploration=10,\n",
" gamma=0.99,\n",
" actor_update_interval=2,\n",
" exploration_noise=0.1,\n",
" target_update_factor=0.005,\n",
" target_update_interval=1,\n",
" test_episodes=5,\n",
" test_interval=10,\n",
" print_interval=1,\n",
" num_episodes=100,\n",
" verbose=0\n",
" ):\n",
" \n",
" super().__init__(\n",
" policy_class,\n",
" batch_size,\n",
" random_exploration,\n",
" gamma,\n",
" target_update_factor,\n",
" target_update_interval,\n",
" test_episodes,\n",
" test_interval,\n",
" print_interval,\n",
" num_episodes,\n",
" verbose\n",
" )\n",
" \n",
" self.actor_update_interval = actor_update_interval\n",
" self.exploration_noise = exploration_noise\n",
"\n",
" def select_action(self, state, test_mode=False):\n",
" if self.training_steps < self.random_exploration:\n",
" return self.env.action_space.sample()\n",
"\n",
" state = torch.tensor([state], device=device, dtype=torch.float32)\n",
" with torch.no_grad():\n",
" action = self.model.select_action(state)\n",
" \n",
" if not test_mode:\n",
" action += torch.randn_like(action) * self.exploration_noise\n",
" action = torch.clamp(action, min=-1, max=+1)\n",
" \n",
" return action.cpu().numpy()[0]\n",
"\n",
" def optimize_model(self):\n",
" if self.batch_size > len(self.memory):\n",
" batch_size = len(self.memory)\n",
" else:\n",
" batch_size = self.batch_size\n",
"\n",
" state, action, next_state, reward, done = self.memory.sample(batch_size)\n",
"\n",
" q_1 = self.model.critic_1(state, action).squeeze(-1)\n",
" q_2 = self.model.critic_2(state, action).squeeze(-1)\n",
" \n",
" with torch.no_grad():\n",
" noise = (torch.randn_like(action) * 0.2).clamp(-0.5, 0.5)\n",
" next_action = (self.target.actor(next_state) + noise).clamp(-1.0, 1.0)\n",
" target_q_1 = self.target.critic_1(next_state, next_action).squeeze(-1)\n",
" target_q_2 = self.target.critic_2(next_state, next_action).squeeze(-1)\n",
" target_q = torch.min(target_q_1, target_q_2)\n",
" target_q = reward + self.gamma * (1 - done).float() * target_q\n",
"\n",
" critic_loss_1 = (q_1 - target_q).pow(2).mean()\n",
" critic_loss_2 = (q_2 - target_q).pow(2).mean()\n",
"\n",
" self.critic_optimizer_1.zero_grad()\n",
" critic_loss_1.backward()\n",
" self.critic_optimizer_1.step()\n",
"\n",
" self.critic_optimizer_2.zero_grad()\n",
" critic_loss_2.backward()\n",
" self.critic_optimizer_2.step()\n",
"\n",
" if self.training_steps % self.actor_update_interval == 0:\n",
" actor_loss = - torch.mean(self.model.critic_1(state, self.model.actor(state)))\n",
" \n",
" self.actor_optimizer.zero_grad()\n",
" actor_loss.backward()\n",
" self.actor_optimizer.step()\n",
" \n",
" with torch.no_grad():\n",
" for model_param, target_param in zip(self.model.parameters(), self.target.parameters()):\n",
" target_param.mul_(1 - self.target_update_factor)\n",
" target_param.add_(self.target_update_factor * model_param.data)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e098c955",
"metadata": {},
"outputs": [],
"source": [
"class SACActor(nn.Module):\n",
" def __init__(self, state_size, actions, min_log=-20, max_log=2, use_batch_norm=False):\n",
" super(SACActor, self).__init__()\n",
" if use_batch_norm:\n",
" self.layers = nn.Sequential(\n",
" nn.Linear(state_size, 64),\n",
" nn.BatchNorm1d(64),\n",
" nn.ReLU(),\n",
" nn.Linear(64, 64),\n",
" nn.BatchNorm1d(64),\n",
" nn.ReLU()\n",
" )\n",
" else:\n",
" self.layers = nn.Sequential(\n",
" nn.Linear(state_size, 64),\n",
" nn.ReLU(),\n",
" nn.Linear(64, 64),\n",
" nn.ReLU()\n",
" )\n",
" \n",
" self.mean = nn.Linear(64, actions)\n",
" self.log_std = nn.Linear(64, actions)\n",
"\n",
" self.min_log = min_log\n",
" self.max_log = max_log\n",
"\n",
" def forward(self, state):\n",
" out = self.layers(state)\n",
" mean = self.mean(out)\n",
" log_std = self.log_std(out)\n",
" log_std = torch.clamp(log_std, min=self.min_log, max=self.max_log)\n",
" return mean, log_std\n",
"\n",
"class SACCritic(nn.Module):\n",
" def __init__(self, state_size, actions, use_batch_norm=False):\n",
" super(SACCritic, self).__init__()\n",
" if use_batch_norm:\n",
" self.layers = nn.Sequential(\n",
" nn.Linear(state_size + actions, 64),\n",
" nn.BatchNorm1d(64),\n",
" nn.ReLU(),\n",
" nn.Linear(64, 64),\n",
" nn.BatchNorm1d(64),\n",
" nn.ReLU(),\n",
" nn.Linear(64, 1)\n",
" )\n",
" else:\n",
" self.layers = nn.Sequential(\n",
" nn.Linear(state_size + actions, 64),\n",
" nn.ReLU(),\n",
" nn.Linear(64, 64),\n",
" nn.ReLU(),\n",
" nn.Linear(64, 1)\n",
" )\n",
" \n",
" def forward(self, state, action):\n",
" return self.layers(torch.cat([state, action], dim=1))\n",
"\n",
"class SACNet(nn.Module):\n",
" def __init__(self, state_size, actions, use_batch_norm=False):\n",
" super(SACNet, self).__init__()\n",
" self.actor = SACActor(state_size, actions, use_batch_norm)\n",
" self.critic_1 = SACCritic(state_size, actions, use_batch_norm)\n",
" self.critic_2 = SACCritic(state_size, actions, use_batch_norm)\n",
"\n",
" def select_action(self, state, test_mode=False):\n",
" self.eval()\n",
" mean, log_std = self.actor(state)\n",
" if test_mode:\n",
" action = torch.tanh(mean)\n",
" else:\n",
" normal = Normal(mean, log_std.exp())\n",
" action = torch.tanh(normal.rsample())\n",
" self.train()\n",
" return action\n",
"\n",
" def sample_action(self, state):\n",
" mean, log_std = self.actor(state)\n",
" std = log_std.exp()\n",
"\n",
" normal = Normal(mean, std)\n",
" xi = normal.rsample()\n",
" action = torch.tanh(xi)\n",
" log_pi = normal.log_prob(xi)\n",
"\n",
" log_pi -= torch.log(1 - action.pow(2) + 1e-6)\n",
" log_pi = log_pi.sum(1, keepdim=True)\n",
"\n",
" return action, log_pi"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "b3efb765",
"metadata": {},
"outputs": [],
"source": [
"class TD3Actor(nn.Module):\n",
" def __init__(self, state_size, actions, use_batch_norm=False):\n",
" super(TD3Actor, self).__init__()\n",
" if use_batch_norm:\n",
" self.layers = nn.Sequential(\n",
" nn.Linear(state_size, 64),\n",
" nn.BatchNorm1d(64),\n",
" nn.ReLU(),\n",
" nn.Linear(64, 64),\n",
" nn.BatchNorm1d(64),\n",
" nn.ReLU(),\n",
" nn.Linear(64, actions),\n",
" nn.Tanh(),\n",
" )\n",
" else:\n",
" self.layers = nn.Sequential(\n",
" nn.Linear(state_size, 64),\n",
" nn.ReLU(),\n",
" nn.Linear(64, 64),\n",
" nn.ReLU(),\n",
" nn.Linear(64, actions),\n",
" nn.Tanh(),\n",
" )\n",
"\n",
" def forward(self, state):\n",
" return self.layers(state)\n",
"\n",
"class TD3Critic(nn.Module):\n",
" def __init__(self, state_size, actions, use_batch_norm=False):\n",
" super(TD3Critic, self).__init__()\n",
" if use_batch_norm:\n",
" self.layers = nn.Sequential(\n",
" nn.Linear(state_size + actions, 64),\n",
" nn.BatchNorm1d(64),\n",
" nn.ReLU(),\n",
" nn.Linear(64, 64),\n",
" nn.BatchNorm1d(64),\n",
" nn.ReLU(),\n",
" nn.Linear(64, 1),\n",
" )\n",
" else:\n",
" self.layers = nn.Sequential(\n",
" nn.Linear(state_size + actions, 64),\n",
" nn.ReLU(),\n",
" nn.Linear(64, 64),\n",
" nn.ReLU(),\n",
" nn.Linear(64, 1),\n",
" )\n",
" \n",
" def forward(self, state, action):\n",
" return self.layers(torch.cat([state, action], dim=1))\n",
"\n",
"class TD3Net(nn.Module):\n",
" def __init__(self, state_size, actions, use_batch_norm=True):\n",
" super(TD3Net, self).__init__()\n",
" self.actor = TD3Actor(state_size, actions, use_batch_norm)\n",
" self.critic_1 = TD3Critic(state_size, actions, use_batch_norm)\n",
" self.critic_2 = TD3Critic(state_size, actions, use_batch_norm)\n",
"\n",
" def select_action(self, state):\n",
" self.eval()\n",
" a = self.actor.forward(state)\n",
" self.train()\n",
" return a"
]
},
{
"cell_type": "markdown",
"id": "caaf501c",
"metadata": {},
"source": [
"# Let's run TD3 on this simple environment, using the architecture with batch norm"
]
},
{
"cell_type": "markdown",
"id": "cc9d7c1f",
"metadata": {},
"source": [
"## Each experiment is run 5 times, so the results can't be dismissed with \"oh, but it's just the random seed...\""
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "fcabc894",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------\n",
"0\n",
"[TESTING] [Total Steps: 0] [Average Reward 0.725]\n",
"[TESTING] [Total Steps: 300] [Average Reward 2.932]\n",
"[TESTING] [Total Steps: 600] [Average Reward 2.933]\n",
"[TESTING] [Total Steps: 900] [Average Reward 2.933]\n",
"[TESTING] [Total Steps: 1200] [Average Reward 2.933]\n",
"[TESTING] [Total Steps: 1500] [Average Reward 2.933]\n",
"[TESTING] [Total Steps: 1800] [Average Reward 2.933]\n",
"[TESTING] [Total Steps: 2100] [Average Reward 2.933]\n",
"[TESTING] [Total Steps: 2400] [Average Reward 2.933]\n",
"[TESTING] [Total Steps: 2700] [Average Reward 2.933]\n",
"[TESTING] [Total Steps: 3000] [Average Reward 2.933]\n",
"----------\n",
"1\n",
"[TESTING] [Total Steps: 0] [Average Reward -1.444]\n",
"[TESTING] [Total Steps: 300] [Average Reward -3.604]\n",
"[TESTING] [Total Steps: 600] [Average Reward -3.604]\n",
"[TESTING] [Total Steps: 900] [Average Reward -3.604]\n",
"[TESTING] [Total Steps: 1200] [Average Reward -3.604]\n",
"[TESTING] [Total Steps: 1500] [Average Reward -3.604]\n",
"[TESTING] [Total Steps: 1800] [Average Reward -3.604]\n",
"[TESTING] [Total Steps: 2100] [Average Reward -3.604]\n",
"[TESTING] [Total Steps: 2400] [Average Reward -3.604]\n",
"[TESTING] [Total Steps: 2700] [Average Reward -3.604]\n",
"[TESTING] [Total Steps: 3000] [Average Reward -3.604]\n",
"----------\n",
"2\n",
"[TESTING] [Total Steps: 0] [Average Reward 2.215]\n",
"[TESTING] [Total Steps: 300] [Average Reward -0.534]\n",
"[TESTING] [Total Steps: 600] [Average Reward -0.535]\n",
"[TESTING] [Total Steps: 900] [Average Reward -0.535]\n",
"[TESTING] [Total Steps: 1200] [Average Reward -0.535]\n",
"[TESTING] [Total Steps: 1500] [Average Reward -0.535]\n",
"[TESTING] [Total Steps: 1800] [Average Reward -0.535]\n",
"[TESTING] [Total Steps: 2100] [Average Reward -0.535]\n",
"[TESTING] [Total Steps: 2400] [Average Reward -0.535]\n",
"[TESTING] [Total Steps: 2700] [Average Reward -0.535]\n",
"[TESTING] [Total Steps: 3000] [Average Reward -0.535]\n",
"----------\n",
"3\n",
"[TESTING] [Total Steps: 0] [Average Reward -1.536]\n",
"[TESTING] [Total Steps: 300] [Average Reward -0.795]\n",
"[TESTING] [Total Steps: 600] [Average Reward -0.789]\n",
"[TESTING] [Total Steps: 900] [Average Reward -0.789]\n",
"[TESTING] [Total Steps: 1200] [Average Reward -0.789]\n",
"[TESTING] [Total Steps: 1500] [Average Reward -0.789]\n",
"[TESTING] [Total Steps: 1800] [Average Reward -0.789]\n",
"[TESTING] [Total Steps: 2100] [Average Reward -0.789]\n",
"[TESTING] [Total Steps: 2400] [Average Reward -0.789]\n",
"[TESTING] [Total Steps: 2700] [Average Reward -0.789]\n",
"[TESTING] [Total Steps: 3000] [Average Reward -0.789]\n",
"----------\n",
"4\n",
"[TESTING] [Total Steps: 0] [Average Reward 1.792]\n",
"[TESTING] [Total Steps: 300] [Average Reward -0.675]\n",
"[TESTING] [Total Steps: 600] [Average Reward -0.672]\n",
"[TESTING] [Total Steps: 900] [Average Reward -0.672]\n",
"[TESTING] [Total Steps: 1200] [Average Reward -0.672]\n",
"[TESTING] [Total Steps: 1500] [Average Reward -0.672]\n",
"[TESTING] [Total Steps: 1800] [Average Reward -0.672]\n",
"[TESTING] [Total Steps: 2100] [Average Reward -0.672]\n",
"[TESTING] [Total Steps: 2400] [Average Reward -0.672]\n",
"[TESTING] [Total Steps: 2700] [Average Reward -0.672]\n",
"[TESTING] [Total Steps: 3000] [Average Reward -0.672]\n"
]
}
],
"source": [
"td3_stats_batch_norm = []\n",
"for i in range(5):\n",
" print(\"-\" * 10)\n",
" print(i)\n",
" model = TD3(policy_class=lambda s, a: TD3Net(s, a, use_batch_norm=True))\n",
" model.train()\n",
" td3_stats_batch_norm.append(model.run_stats)"
]
},
{
"cell_type": "markdown",
"id": "1b28a4f2",
"metadata": {},
"source": [
"# It doesn't work... but if we remove batch norm..."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "47431284",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------\n",
"0\n",
"[TESTING] [Total Steps: 0] [Average Reward -0.037]\n",
"[TESTING] [Total Steps: 300] [Average Reward 19.322]\n",
"[TESTING] [Total Steps: 600] [Average Reward 28.569]\n",
"[TESTING] [Total Steps: 900] [Average Reward 29.586]\n",
"[TESTING] [Total Steps: 1200] [Average Reward 29.817]\n",
"[TESTING] [Total Steps: 1500] [Average Reward 29.900]\n",
"[TESTING] [Total Steps: 1800] [Average Reward 29.938]\n",
"[TESTING] [Total Steps: 2100] [Average Reward 29.959]\n",
"[TESTING] [Total Steps: 2400] [Average Reward 29.971]\n",
"[TESTING] [Total Steps: 2700] [Average Reward 29.979]\n",
"[TESTING] [Total Steps: 3000] [Average Reward 29.993]\n",
"----------\n",
"1\n",
"[TESTING] [Total Steps: 0] [Average Reward 1.559]\n",
"[TESTING] [Total Steps: 300] [Average Reward 26.659]\n",
"[TESTING] [Total Steps: 600] [Average Reward 29.388]\n",
"[TESTING] [Total Steps: 900] [Average Reward 29.765]\n",
"[TESTING] [Total Steps: 1200] [Average Reward 29.877]\n",
"[TESTING] [Total Steps: 1500] [Average Reward 29.925]\n",
"[TESTING] [Total Steps: 1800] [Average Reward 29.950]\n",
"[TESTING] [Total Steps: 2100] [Average Reward 29.964]\n",
"[TESTING] [Total Steps: 2400] [Average Reward 29.974]\n",
"[TESTING] [Total Steps: 2700] [Average Reward 29.980]\n",
"[TESTING] [Total Steps: 3000] [Average Reward 29.984]\n",
"----------\n",
"2\n",
"[TESTING] [Total Steps: 0] [Average Reward -2.776]\n",
"[TESTING] [Total Steps: 300] [Average Reward 24.281]\n",
"[TESTING] [Total Steps: 600] [Average Reward 29.199]\n",
"[TESTING] [Total Steps: 900] [Average Reward 29.748]\n",
"[TESTING] [Total Steps: 1200] [Average Reward 29.921]\n",
"[TESTING] [Total Steps: 1500] [Average Reward 29.964]\n",
"[TESTING] [Total Steps: 1800] [Average Reward 29.980]\n",
"[TESTING] [Total Steps: 2100] [Average Reward 29.988]\n",
"[TESTING] [Total Steps: 2400] [Average Reward 29.992]\n",
"[TESTING] [Total Steps: 2700] [Average Reward 29.994]\n",
"[TESTING] [Total Steps: 3000] [Average Reward 29.996]\n",
"----------\n",
"3\n",
"[TESTING] [Total Steps: 0] [Average Reward 0.407]\n",
"[TESTING] [Total Steps: 300] [Average Reward 27.670]\n",
"[TESTING] [Total Steps: 600] [Average Reward 29.705]\n",
"[TESTING] [Total Steps: 900] [Average Reward 29.894]\n",
"[TESTING] [Total Steps: 1200] [Average Reward 29.948]\n",
"[TESTING] [Total Steps: 1500] [Average Reward 29.970]\n",
"[TESTING] [Total Steps: 1800] [Average Reward 29.981]\n",
"[TESTING] [Total Steps: 2100] [Average Reward 29.987]\n",
"[TESTING] [Total Steps: 2400] [Average Reward 29.990]\n",
"[TESTING] [Total Steps: 2700] [Average Reward 29.992]\n",
"[TESTING] [Total Steps: 3000] [Average Reward 29.994]\n",
"----------\n",
"4\n",
"[TESTING] [Total Steps: 0] [Average Reward 1.229]\n",
"[TESTING] [Total Steps: 300] [Average Reward 29.227]\n",
"[TESTING] [Total Steps: 600] [Average Reward 29.884]\n",
"[TESTING] [Total Steps: 900] [Average Reward 29.957]\n",
"[TESTING] [Total Steps: 1200] [Average Reward 29.978]\n",
"[TESTING] [Total Steps: 1500] [Average Reward 29.986]\n",
"[TESTING] [Total Steps: 1800] [Average Reward 29.991]\n",
"[TESTING] [Total Steps: 2100] [Average Reward 29.993]\n",
"[TESTING] [Total Steps: 2400] [Average Reward 29.994]\n",
"[TESTING] [Total Steps: 2700] [Average Reward 29.996]\n",
"[TESTING] [Total Steps: 3000] [Average Reward 29.996]\n"
]
}
],
"source": [
"td3_stats_no_batch_norm = []\n",
"for i in range(5):\n",
" print(\"-\" * 10)\n",
" print(i)\n",
" model = TD3(policy_class=lambda s, a: TD3Net(s, a, use_batch_norm=False))\n",
" model.train()\n",
" td3_stats_no_batch_norm.append(model.run_stats)"
]
},
{
"cell_type": "markdown",
"id": "a90745c6",
"metadata": {},
"source": [
"# The same thing happens with SAC. Here it is with batch norm:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "59771363",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------\n",
"0\n",
"[TESTING] [Total Steps: 0] [Average Reward -2.863]\n",
"[TESTING] [Total Steps: 300] [Average Reward -2.637]\n",
"[TESTING] [Total Steps: 600] [Average Reward 1.463]\n",
"[TESTING] [Total Steps: 900] [Average Reward 7.295]\n",
"[TESTING] [Total Steps: 1200] [Average Reward 9.499]\n",
"[TESTING] [Total Steps: 1500] [Average Reward 11.115]\n",
"[TESTING] [Total Steps: 1800] [Average Reward 13.719]\n",
"[TESTING] [Total Steps: 2100] [Average Reward 14.046]\n",
"[TESTING] [Total Steps: 2400] [Average Reward 17.136]\n",
"[TESTING] [Total Steps: 2700] [Average Reward 17.300]\n",
"[TESTING] [Total Steps: 3000] [Average Reward 18.709]\n",
"----------\n",
"1\n",
"[TESTING] [Total Steps: 0] [Average Reward -1.422]\n",
"[TESTING] [Total Steps: 300] [Average Reward 2.362]\n",
"[TESTING] [Total Steps: 600] [Average Reward 6.293]\n",
"[TESTING] [Total Steps: 900] [Average Reward 6.952]\n",
"[TESTING] [Total Steps: 1200] [Average Reward 8.212]\n",
"[TESTING] [Total Steps: 1500] [Average Reward 10.403]\n",
"[TESTING] [Total Steps: 1800] [Average Reward 10.594]\n",
"[TESTING] [Total Steps: 2100] [Average Reward 13.462]\n",
"[TESTING] [Total Steps: 2400] [Average Reward 12.795]\n",
"[TESTING] [Total Steps: 2700] [Average Reward 12.896]\n",
"[TESTING] [Total Steps: 3000] [Average Reward 14.561]\n",
"----------\n",
"2\n",
"[TESTING] [Total Steps: 0] [Average Reward 2.287]\n",
"[TESTING] [Total Steps: 300] [Average Reward -3.066]\n",
"[TESTING] [Total Steps: 600] [Average Reward 4.481]\n",
"[TESTING] [Total Steps: 900] [Average Reward 8.786]\n",
"[TESTING] [Total Steps: 1200] [Average Reward 10.877]\n",
"[TESTING] [Total Steps: 1500] [Average Reward 12.028]\n",
"[TESTING] [Total Steps: 1800] [Average Reward 13.635]\n",
"[TESTING] [Total Steps: 2100] [Average Reward 15.625]\n",
"[TESTING] [Total Steps: 2400] [Average Reward 15.664]\n",
"[TESTING] [Total Steps: 2700] [Average Reward 16.005]\n",
"[TESTING] [Total Steps: 3000] [Average Reward 18.392]\n",
"----------\n",
"3\n",
"[TESTING] [Total Steps: 0] [Average Reward 0.901]\n",
"[TESTING] [Total Steps: 300] [Average Reward -13.923]\n",
"[TESTING] [Total Steps: 600] [Average Reward -7.597]\n",
"[TESTING] [Total Steps: 900] [Average Reward -1.609]\n",
"[TESTING] [Total Steps: 1200] [Average Reward 2.220]\n",
"[TESTING] [Total Steps: 1500] [Average Reward 4.721]\n",
"[TESTING] [Total Steps: 1800] [Average Reward 6.056]\n",
"[TESTING] [Total Steps: 2100] [Average Reward 9.206]\n",
"[TESTING] [Total Steps: 2400] [Average Reward 9.931]\n",
"[TESTING] [Total Steps: 2700] [Average Reward 13.146]\n",
"[TESTING] [Total Steps: 3000] [Average Reward 13.121]\n",
"----------\n",
"4\n",
"[TESTING] [Total Steps: 0] [Average Reward 3.851]\n",
"[TESTING] [Total Steps: 300] [Average Reward 0.013]\n",
"[TESTING] [Total Steps: 600] [Average Reward 6.251]\n",
"[TESTING] [Total Steps: 900] [Average Reward 9.068]\n",
"[TESTING] [Total Steps: 1200] [Average Reward 11.889]\n",
"[TESTING] [Total Steps: 1500] [Average Reward 14.886]\n",
"[TESTING] [Total Steps: 1800] [Average Reward 16.303]\n",
"[TESTING] [Total Steps: 2100] [Average Reward 16.039]\n",
"[TESTING] [Total Steps: 2400] [Average Reward 15.649]\n",
"[TESTING] [Total Steps: 2700] [Average Reward 18.159]\n",
"[TESTING] [Total Steps: 3000] [Average Reward 19.134]\n"
]
}
],
"source": [
"sac_stats_batch_norm = []\n",
"for i in range(5):\n",
" print(\"-\" * 10)\n",
" print(i)\n",
" model = SAC(policy_class=lambda s, a: SACNet(s, a, use_batch_norm=True))\n",
" model.train()\n",
" sac_stats_batch_norm.append(model.run_stats)"
]
},
{
"cell_type": "markdown",
"id": "47cd0dd0",
"metadata": {},
"source": [
"# But after removing batch norm:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "d38f556f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------\n",
"0\n",
"[TESTING] [Total Steps: 0] [Average Reward 1.242]\n",
"[TESTING] [Total Steps: 300] [Average Reward 29.018]\n",
"[TESTING] [Total Steps: 600] [Average Reward 29.112]\n",
"[TESTING] [Total Steps: 900] [Average Reward 29.085]\n",
"[TESTING] [Total Steps: 1200] [Average Reward 28.996]\n",
"[TESTING] [Total Steps: 1500] [Average Reward 28.908]\n",
"[TESTING] [Total Steps: 1800] [Average Reward 28.725]\n",
"[TESTING] [Total Steps: 2100] [Average Reward 28.561]\n",
"[TESTING] [Total Steps: 2400] [Average Reward 28.505]\n",
"[TESTING] [Total Steps: 2700] [Average Reward 28.409]\n",
"[TESTING] [Total Steps: 3000] [Average Reward 28.648]\n",
"----------\n",
"1\n",
"[TESTING] [Total Steps: 0] [Average Reward 1.220]\n",
"[TESTING] [Total Steps: 300] [Average Reward 29.047]\n",
"[TESTING] [Total Steps: 600] [Average Reward 29.077]\n",
"[TESTING] [Total Steps: 900] [Average Reward 29.077]\n",
"[TESTING] [Total Steps: 1200] [Average Reward 28.946]\n",
"[TESTING] [Total Steps: 1500] [Average Reward 28.853]\n",
"[TESTING] [Total Steps: 1800] [Average Reward 28.770]\n",
"[TESTING] [Total Steps: 2100] [Average Reward 28.747]\n",
"[TESTING] [Total Steps: 2400] [Average Reward 28.660]\n",
"[TESTING] [Total Steps: 2700] [Average Reward 28.728]\n",
"[TESTING] [Total Steps: 3000] [Average Reward 28.611]\n",
"----------\n",
"2\n",
"[TESTING] [Total Steps: 0] [Average Reward -0.713]\n",
"[TESTING] [Total Steps: 300] [Average Reward 29.007]\n",
"[TESTING] [Total Steps: 600] [Average Reward 29.101]\n",
"[TESTING] [Total Steps: 900] [Average Reward 29.087]\n",
"[TESTING] [Total Steps: 1200] [Average Reward 28.976]\n",
"[TESTING] [Total Steps: 1500] [Average Reward 28.872]\n",
"[TESTING] [Total Steps: 1800] [Average Reward 28.827]\n",
"[TESTING] [Total Steps: 2100] [Average Reward 28.865]\n",
"[TESTING] [Total Steps: 2400] [Average Reward 28.670]\n",
"[TESTING] [Total Steps: 2700] [Average Reward 28.517]\n",
"[TESTING] [Total Steps: 3000] [Average Reward 28.591]\n",
"----------\n",
"3\n",
"[TESTING] [Total Steps: 0] [Average Reward 2.078]\n",
"[TESTING] [Total Steps: 300] [Average Reward 29.010]\n",
"[TESTING] [Total Steps: 600] [Average Reward 29.154]\n",
"[TESTING] [Total Steps: 900] [Average Reward 29.040]\n",
"[TESTING] [Total Steps: 1200] [Average Reward 29.015]\n",
"[TESTING] [Total Steps: 1500] [Average Reward 28.910]\n",
"[TESTING] [Total Steps: 1800] [Average Reward 28.943]\n",
"[TESTING] [Total Steps: 2100] [Average Reward 28.848]\n",
"[TESTING] [Total Steps: 2400] [Average Reward 28.736]\n",
"[TESTING] [Total Steps: 2700] [Average Reward 28.749]\n",
"[TESTING] [Total Steps: 3000] [Average Reward 28.723]\n",
"----------\n",
"4\n",
"[TESTING] [Total Steps: 0] [Average Reward -0.218]\n",
"[TESTING] [Total Steps: 300] [Average Reward 29.017]\n",
"[TESTING] [Total Steps: 600] [Average Reward 29.062]\n",
"[TESTING] [Total Steps: 900] [Average Reward 28.899]\n",
"[TESTING] [Total Steps: 1200] [Average Reward 28.882]\n",
"[TESTING] [Total Steps: 1500] [Average Reward 28.789]\n",
"[TESTING] [Total Steps: 1800] [Average Reward 28.786]\n",
"[TESTING] [Total Steps: 2100] [Average Reward 28.641]\n",
"[TESTING] [Total Steps: 2400] [Average Reward 28.682]\n",
"[TESTING] [Total Steps: 2700] [Average Reward 28.485]\n",
"[TESTING] [Total Steps: 3000] [Average Reward 28.640]\n"
]
}
],
"source": [
"sac_stats_no_batch_norm = []\n",
"for i in range(5):\n",
" print(\"-\" * 10)\n",
" print(i)\n",
" model = SAC(policy_class=lambda s, a: SACNet(s, a, use_batch_norm=False))\n",
" model.train()\n",
" sac_stats_no_batch_norm.append(model.run_stats)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "065d8e1b",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 400x400 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"def get_min_max_mean(arr):\n",
" tmp_dict = {}\n",
" for i in arr:\n",
" for num, j in enumerate(i):\n",
" val = str((num)*300)\n",
" if val in tmp_dict:\n",
" tmp_dict[val].append(j[1])\n",
" else:\n",
" tmp_dict[val] = [j[1]]\n",
"\n",
" arr_mean = {}\n",
" arr_max = {}\n",
" arr_min = {}\n",
"\n",
" for k, v in tmp_dict.items():\n",
" v = np.array(v)\n",
" arr_mean[k] = v.mean()\n",
" arr_max[k] = v.max()\n",
" arr_min[k] = v.min()\n",
" \n",
" return arr_min, arr_max, arr_mean\n",
" \n",
"ssbn_min, ssbn_max, ssbn_mean = get_min_max_mean(td3_stats_batch_norm)\n",
"ss_min, ss_max, ss_mean = get_min_max_mean(td3_stats_no_batch_norm)\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"plt.figure(figsize=(5,5), dpi= 80)\n",
"plt.title(\"TD3 Results\")\n",
"plt.ylabel(\"Test Reward\", fontsize=16) \n",
"plt.xlabel(\"Training Steps\", fontsize=16) \n",
"x = list(ssbn_mean.keys())\n",
"plt.plot(x, list(ssbn_mean.values()), color='blue', label=\"TD3 with Batch Norm\") \n",
"plt.plot(x, list(ss_mean.values()), color='red', label=\"TD3 without Batch Norm\")\n",
"plt.axhline(y=30, color='green', linestyle='-', label=\"optimal solution\")\n",
"plt.fill_between(x, list(ssbn_min.values()), list(ssbn_max.values()), color=\"lightblue\")\n",
"plt.fill_between(x, list(ss_min.values()), list(ss_max.values()), color=\"lightcoral\")\n",
"plt.legend()\n",
"plt.savefig(\"td3_batch_norm.png\", format='png')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "a923ac68",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 400x400 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"ssbn_min, ssbn_max, ssbn_mean = get_min_max_mean(sac_stats_batch_norm)\n",
"ss_min, ss_max, ss_mean = get_min_max_mean(sac_stats_no_batch_norm)\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"plt.figure(figsize=(5,5), dpi= 80)\n",
"plt.title(\"SAC Results\")\n",
"plt.ylabel(\"Test Reward\", fontsize=16) \n",
"plt.xlabel(\"Training Steps\", fontsize=16) \n",
"x = list(ssbn_mean.keys())\n",
"plt.plot(x, list(ssbn_mean.values()), color='blue', label=\"SAC with Batch Norm\") \n",
"plt.plot(x, list(ss_mean.values()), color='red', label=\"SAC without Batch Norm\")\n",
"plt.axhline(y=30, color='green', linestyle='-', label=\"optimal solution\")\n",
"plt.fill_between(x, list(ssbn_min.values()), list(ssbn_max.values()), color=\"lightblue\")\n",
"plt.fill_between(x, list(ss_min.values()), list(ss_max.values()), color=\"lightcoral\")\n",
"plt.legend()\n",
"plt.savefig(\"sac_batch_norm.png\", format='png')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "59748169",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@honglu2875
Copy link

honglu2875 commented Aug 17, 2022

Well, when you updated the target net, you used parameters(). I can see the code is from stable-baselines, which is not designed for batch norm. Batch norm in fact has two more variables (the running mean and running variance) that are not included in parameters(). What you had was in fact a partially updated target net, which behaves dramatically differently at inference time.

@honglu2875
Copy link

Also, you abused .eval() and .train(). I suggest you learn how batch norm works and what the implications of those methods are. I didn't look into the exact logic of each algorithm, but I also have a vague feeling that something is not exactly as in the paper (though it might still work).

PS: I was just able to fix your TD3 and made the batch norm model run as well as the one without. I don't have time to do SAC for you, but most likely you will be able to fix it yourself too if you understand batch norm correctly.

@rdednl
Copy link
Author

rdednl commented Sep 29, 2022

@honglu2875 Hi. My code is not from stable baselines. Also, the learnable batch norm parameters that have to be updated on the target are returned by the parameters() method:

> for model_param in model.model.actor.layers[1].parameters():
>    print(model_param.shape)

torch.Size([64])
torch.Size([64])

What are the variables that are missing?

Also, what do you mean when you say that I abused .eval() and .train()?

@honglu2875
Copy link

honglu2875 commented Sep 29, 2022

Check out the properties whose names start with "running_" (either in your batch norm layer or in its state_dict). They are "learnable" in the sense that they change during training, but not through gradients. They are not present in parameters().

All learnable parameters are in state_dict(). parameters() returns only those that are updated by gradients.

@honglu2875
Copy link

honglu2875 commented Sep 29, 2022

My code is not from stable baselines.

Ahh... So this misunderstanding has spread wider than I thought. Maybe there is a chain of misuse, and people never bother checking.
When stable-baseline came out there was no such thing as batch norm, by the way. The code is great and should indeed be our implementation baseline. But we, "the later generations", really have more responsibility when working on earlier code.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment