Skip to content

Instantly share code, notes, and snippets.

@jskDr
Last active October 3, 2019 14:08
Show Gist options
  • Save jskDr/3e3db8c2c1b67a7dc02935bd6bb84265 to your computer and use it in GitHub Desktop.
Save jskDr/3e3db8c2c1b67a7dc02935bd6bb84265 to your computer and use it in GitHub Desktop.
Policy-gradient (REINFORCE) code written in PyTorch, where each update uses a batch of more than one episode
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"from IPython.display import clear_output\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"import torch.nn.functional as F\n",
"import torch.nn as nn\n",
"from torch.distributions import Categorical, Bernoulli\n",
"import torch.optim as optim\n",
"from torch.autograd import Variable\n",
"\n",
"import gym\n",
"\n",
"from itertools import count\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class ActorNet(nn.Module):\n",
"    \"\"\"Policy network for a binary (two-action) discrete action space.\n",
"\n",
"    The last layer emits one logit per state; its sigmoid is the probability\n",
"    of choosing action 1, and the result is wrapped in a ``Bernoulli``\n",
"    distribution that callers can ``sample()`` or query with ``log_prob()``.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, state_size: int, action_size: int):\n",
"        # ``action_size`` is kept for interface compatibility but is unused:\n",
"        # a single Bernoulli output can only represent exactly two actions.\n",
"        super(ActorNet, self).__init__()\n",
"        self.fc1 = nn.Linear(state_size, 24)\n",
"        self.fc2 = nn.Linear(24, 36)\n",
"        self.fc3 = nn.Linear(36, 1)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return a Bernoulli policy over actions {0, 1} for state(s) ``x``.\"\"\"\n",
"        x = F.relu(self.fc1(x))\n",
"        x = F.relu(self.fc2(x))\n",
"        prob = torch.sigmoid(self.fc3(x))\n",
"        return Bernoulli(prob)\n",
"\n",
"\n",
"class CriticNet(nn.Module):\n",
"    \"\"\"State-value network mapping each state to one scalar estimate.\n",
"\n",
"    NOTE(review): nothing in this notebook optimizes this network; it is\n",
"    presumably a placeholder for an actor-critic baseline.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, state_size: int):\n",
"        super(CriticNet, self).__init__()\n",
"        self.fc1 = nn.Linear(state_size, 128)\n",
"        self.fc2 = nn.Linear(128, 256)\n",
"        self.fc3 = nn.Linear(256, 1)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return the value estimate(s) for state(s) ``x`` (last dim of size 1).\"\"\"\n",
"        x = F.relu(self.fc1(x))\n",
"        x = F.relu(self.fc2(x))\n",
"        return self.fc3(x)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def calc_discount_reward(reward_buff, gamma=0.99, normalize=True, done_flag=True):\n",
"    \"\"\"Turn per-step rewards into discounted returns, in place.\n",
"\n",
"    The buffer may hold several episodes back to back. With ``done_flag``\n",
"    set, a reward of exactly 0 is treated as the terminal step of an\n",
"    episode: that step keeps a return of 0 and the running return is reset,\n",
"    so returns never leak across episode boundaries.\n",
"\n",
"    Args:\n",
"        reward_buff: mutable sequence of per-step rewards; overwritten.\n",
"        gamma: discount factor applied per step.\n",
"        normalize: if True, standardize the returns to zero mean and unit\n",
"            standard deviation. When every return is identical (std == 0)\n",
"            they are only centred instead of being turned into NaN.\n",
"        done_flag: treat zero rewards as episode boundaries.\n",
"\n",
"    Returns:\n",
"        ``reward_buff`` itself, for convenient chaining.\n",
"    \"\"\"\n",
"    running = 0\n",
"    for ii in reversed(range(len(reward_buff))):\n",
"        if done_flag and reward_buff[ii] == 0:\n",
"            running = 0\n",
"        else:\n",
"            reward_buff[ii] += running * gamma\n",
"            running = reward_buff[ii]\n",
"\n",
"    if normalize and len(reward_buff) > 0:\n",
"        mean, std = np.mean(reward_buff), np.std(reward_buff)\n",
"        # A zero std would divide by zero and poison every gradient with NaN.\n",
"        scale = std if std > 0 else 1.0\n",
"        reward_buff[:] = [(r - mean) / scale for r in reward_buff]\n",
"    return reward_buff"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Episode:1999\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Final results\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"num_episodes = 2000\n",
"durations_list = []  # length of every finished episode, across the whole run\n",
"render_flag = False\n",
"lr = 3e-2\n",
"num_batch = 5  # episodes collected per policy-gradient update\n",
"\n",
"\n",
"def init_buff():\n",
"    \"\"\"Reset the module-level rollout buffers (kept for backward compatibility).\"\"\"\n",
"    global state_buff, action_buff, reward_buff\n",
"    state_buff = []\n",
"    action_buff = []\n",
"    reward_buff = []\n",
"\n",
"\n",
"def _update_policy(actor, optimizer, states, actions, rewards):\n",
"    \"\"\"Apply one REINFORCE step to ``actor`` from a batch of rollout steps.\n",
"\n",
"    ``rewards`` is converted in place to (normalized) discounted returns.\n",
"    \"\"\"\n",
"    calc_discount_reward(rewards)\n",
"    state_batch = torch.stack(states)  # (T, state_size)\n",
"    action_batch = torch.stack(actions)  # (T, 1)\n",
"    return_batch = torch.tensor(rewards, dtype=torch.float32).unsqueeze(1)  # (T, 1)\n",
"\n",
"    optimizer.zero_grad()\n",
"    policy = actor(state_batch)\n",
"    # One batched backward pass; its gradient equals the sum of the per-step\n",
"    # losses -log pi(a_t | s_t) * G_t over the batch.\n",
"    loss = -(policy.log_prob(action_batch) * return_batch).sum()\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"\n",
"\n",
"def run(env):\n",
"    \"\"\"Train a REINFORCE policy on ``env`` and plot the episode durations.\n",
"\n",
"    Steps from ``num_batch`` consecutive episodes are pooled into a single\n",
"    policy update. Episode lengths are appended to the module-level\n",
"    ``durations_list``, which is re-plotted every 100 episodes.\n",
"\n",
"    NOTE(review): assumes the classic Gym API (``reset()`` returns only the\n",
"    observation, ``step()`` returns a 4-tuple) and a two-action space.\n",
"    \"\"\"\n",
"    actor = ActorNet(env.observation_space.shape[0], env.action_space.n)\n",
"    optimizer_actor = optim.Adam(actor.parameters(), lr=lr)\n",
"\n",
"    states, actions, rewards = [], [], []\n",
"    for ep in range(num_episodes):\n",
"        state_numpy = env.reset()\n",
"        for t in count():\n",
"            state_torch = torch.FloatTensor(state_numpy)\n",
"            # Acting needs no graph; log-probs are recomputed in _update_policy.\n",
"            with torch.no_grad():\n",
"                action_torch = actor(state_torch).sample()\n",
"            next_state_numpy, reward, done, _ = env.step(int(action_torch.item()))\n",
"            if render_flag:\n",
"                env.render()\n",
"            if done:\n",
"                # Zero reward marks the episode boundary for calc_discount_reward.\n",
"                reward = 0\n",
"\n",
"            states.append(state_torch)\n",
"            actions.append(action_torch)\n",
"            rewards.append(reward)\n",
"\n",
"            if done:\n",
"                durations_list.append(t + 1)\n",
"                break\n",
"            state_numpy = next_state_numpy\n",
"\n",
"        if ep % num_batch == num_batch - 1:\n",
"            _update_policy(actor, optimizer_actor, states, actions, rewards)\n",
"            states, actions, rewards = [], [], []\n",
"\n",
"        if ep % 100 == 99:\n",
"            clear_output()\n",
"            print(f'Episode:{ep}')\n",
"            plt.plot(durations_list)\n",
"            plt.show()\n",
"\n",
"    print('Final results')\n",
"    plt.plot(durations_list)\n",
"    plt.show()\n",
"\n",
"\n",
"env = gym.make('CartPole-v0')\n",
"try:\n",
"    run(env)\n",
"finally:\n",
"    env.close()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pytorch",
"language": "python",
"name": "pytorch"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment