Skip to content

Instantly share code, notes, and snippets.

@jskDr
Last active October 3, 2019 15:04
Show Gist options
  • Save jskDr/9173432937ca00755be849f14672e98c to your computer and use it in GitHub Desktop.
Save jskDr/9173432937ca00755be849f14672e98c to your computer and use it in GitHub Desktop.
Actor-Critic implemented in PyTorch, with separate loss formulations for the actor and critic networks.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"from IPython.display import clear_output\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"import torch.nn.functional as F\n",
"import torch.nn as nn\n",
"from torch.distributions import Categorical, Bernoulli\n",
"import torch.optim as optim\n",
"from torch.autograd import Variable\n",
"\n",
"import gym\n",
"\n",
"from itertools import count\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class ActorNet(nn.Module):\n",
"    \"\"\"Policy network: map a state to a categorical distribution over actions.\n",
"\n",
"    Two ReLU hidden layers (24 and 36 units) feed a linear head that emits\n",
"    one logit per discrete action.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, state_size: int, action_size: int):\n",
"        super().__init__()\n",
"        self.fc1 = nn.Linear(state_size, 24)\n",
"        self.fc2 = nn.Linear(24, 36)\n",
"        self.fc3 = nn.Linear(36, action_size)\n",
"\n",
"    def forward(self, x: torch.Tensor) -> Categorical:\n",
"        \"\"\"Return the action distribution for state(s) ``x`` (last dim = state_size).\"\"\"\n",
"        x = F.relu(self.fc1(x))\n",
"        x = F.relu(self.fc2(x))\n",
"        # Passing raw logits lets Categorical use a numerically stable\n",
"        # log-softmax in log_prob() instead of log(softmax(...)).\n",
"        return Categorical(logits=self.fc3(x))\n",
"\n",
"\n",
"class CriticNet(nn.Module):\n",
"    \"\"\"State-value network: map a state to a scalar estimate V(s).\"\"\"\n",
"\n",
"    def __init__(self, state_size: int):\n",
"        super().__init__()\n",
"        self.fc1 = nn.Linear(state_size, 24)\n",
"        self.fc2 = nn.Linear(24, 36)\n",
"        self.fc3 = nn.Linear(36, 1)\n",
"\n",
"    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
"        \"\"\"Return V(x) with a trailing size-1 dimension, i.e. shape ``(*batch, 1)``.\"\"\"\n",
"        x = F.relu(self.fc1(x))\n",
"        x = F.relu(self.fc2(x))\n",
"        return self.fc3(x)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def calc_discount_reward(reward_buff, gamma=0.99, normalize=True, done_flag=True):\n",
"    \"\"\"Convert per-step rewards in ``reward_buff`` to discounted returns, in place.\n",
"\n",
"    The buffer is walked backwards accumulating G_t = r_t + gamma * G_{t+1}.\n",
"    With ``done_flag`` set, a reward of exactly 0 is treated as an episode\n",
"    boundary (the training loop stores 0 on the terminal step): the running\n",
"    return restarts there and that entry stays 0. Note this also means a\n",
"    genuine zero reward in the middle of an episode would be read as a boundary.\n",
"\n",
"    If ``normalize`` is true, the returns are then standardised to zero mean and\n",
"    unit standard deviation. When all returns are equal (std == 0) they are only\n",
"    centred, instead of dividing by zero and producing NaNs.\n",
"\n",
"    Returns the same (mutated) list, for convenience.\n",
"    \"\"\"\n",
"    running = 0.0\n",
"    for i in reversed(range(len(reward_buff))):\n",
"        if done_flag and reward_buff[i] == 0:\n",
"            running = 0.0  # episode boundary: never bootstrap across it\n",
"        else:\n",
"            running = reward_buff[i] + gamma * running\n",
"            reward_buff[i] = running\n",
"\n",
"    if normalize and reward_buff:\n",
"        mean, std = np.mean(reward_buff), np.std(reward_buff)\n",
"        scale = std if std > 0 else 1.0\n",
"        # Slice assignment keeps the caller's list object (it is mutated in place).\n",
"        reward_buff[:] = [(r - mean) / scale for r in reward_buff]\n",
"    return reward_buff"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Episode:499\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"num_episodes = 5000\n",
"durations_list = []  # length of every finished episode, across all run() calls\n",
"render_flag = False\n",
"lr = 3e-2\n",
"num_batch = 5  # episodes collected per actor/critic update\n",
"\n",
"\n",
"def init_buff():\n",
"    \"\"\"Reset the module-level rollout buffers that run() appends to.\"\"\"\n",
"    global state_buff, action_buff, reward_buff, value_buff\n",
"    state_buff = []   # state tensors, one per step\n",
"    action_buff = []  # sampled actions (0-d tensors)\n",
"    reward_buff = []  # rewards; turned into normalised returns before each update\n",
"    value_buff = []   # critic outputs V(s), kept with their autograd graph\n",
"\n",
"\n",
"def _update(actor, critic, optimizer_actor, optimizer_critic):\n",
"    \"\"\"Take one gradient step for the actor and one for the critic on the buffers.\"\"\"\n",
"    calc_discount_reward(reward_buff)\n",
"    states = torch.stack(state_buff)\n",
"    actions = torch.stack(action_buff)\n",
"    returns = torch.tensor(reward_buff, dtype=torch.float32)\n",
"    values = torch.cat(value_buff)  # each entry has shape (1,), so this is (N,)\n",
"\n",
"    advantage = returns - values\n",
"    # Summing the per-step losses yields the same gradients as calling backward()\n",
"    # once per step, but needs only one batched forward/backward pass.\n",
"    critic_loss = advantage.pow(2).sum()\n",
"    # detach(): the actor loss must not push gradients into the critic.\n",
"    actor_loss = -(actor(states).log_prob(actions) * advantage.detach()).sum()\n",
"\n",
"    optimizer_actor.zero_grad()\n",
"    optimizer_critic.zero_grad()\n",
"    actor_loss.backward()\n",
"    critic_loss.backward()\n",
"    optimizer_actor.step()\n",
"    optimizer_critic.step()\n",
"\n",
"\n",
"def run(env):\n",
"    \"\"\"Train an actor-critic agent on ``env`` and plot episode lengths.\n",
"\n",
"    Episodes are rolled out with the current policy; after every ``num_batch``\n",
"    episodes the buffered transitions drive one update of each network.\n",
"    Assumes the classic gym API used below: reset() -> obs, step() -> 4-tuple.\n",
"    \"\"\"\n",
"    n_states = env.observation_space.shape[0]\n",
"    actor = ActorNet(n_states, env.action_space.n)\n",
"    critic = CriticNet(n_states)\n",
"    optimizer_actor = optim.Adam(actor.parameters(), lr=lr)\n",
"    optimizer_critic = optim.Adam(critic.parameters(), lr=lr)\n",
"\n",
"    init_buff()\n",
"    for ep in range(num_episodes):\n",
"        state_numpy = env.reset()\n",
"        for t in count():\n",
"            state_torch = torch.FloatTensor(state_numpy)\n",
"            # Acting needs no actor graph (the update recomputes the policy);\n",
"            # the critic output keeps its graph because the critic loss uses it.\n",
"            with torch.no_grad():\n",
"                action_torch = actor(state_torch).sample()\n",
"            value_torch = critic(state_torch)\n",
"            next_state_numpy, reward, done, _ = env.step(action_torch.item())\n",
"            if render_flag:\n",
"                env.render()\n",
"            if done:\n",
"                # 0 is the episode-boundary marker calc_discount_reward looks for.\n",
"                reward = 0\n",
"\n",
"            state_buff.append(state_torch)\n",
"            action_buff.append(action_torch)\n",
"            reward_buff.append(reward)\n",
"            value_buff.append(value_torch)\n",
"\n",
"            if done:\n",
"                durations_list.append(t + 1)\n",
"                break\n",
"            state_numpy = next_state_numpy\n",
"\n",
"        if ep % num_batch == num_batch - 1:\n",
"            _update(actor, critic, optimizer_actor, optimizer_critic)\n",
"            init_buff()\n",
"\n",
"        if ep % 100 == 99:\n",
"            clear_output()\n",
"            print(f'Episode:{ep}')\n",
"            plt.plot(durations_list)\n",
"            plt.show()\n",
"\n",
"    print('Final results')\n",
"    plt.plot(durations_list)\n",
"    plt.show()\n",
"\n",
"\n",
"env = gym.make('CartPole-v0')\n",
"try:\n",
"    run(env)\n",
"finally:\n",
"    # Close the environment (and any render window) even if training is interrupted.\n",
"    env.close()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pytorch",
"language": "python",
"name": "pytorch"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment