Skip to content

Instantly share code, notes, and snippets.

@jskDr
Created September 29, 2019 13:36
Show Gist options
  • Save jskDr/9b9d7ce98d25632a0b1d5184bdbafc7c to your computer and use it in GitHub Desktop.
Save jskDr/9b9d7ce98d25632a0b1d5184bdbafc7c to your computer and use it in GitHub Desktop.
Policy Gradient with PyTorch and Python Class Structure
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"# Using ipython\n",
"from IPython.display import clear_output\n",
"\n",
"import gym\n",
"import random\n",
"from itertools import count\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Using torch\n",
"import torch as TC\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.autograd import Variable\n",
"from torch.distributions import Bernoulli"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def list_norm_inplace(buff):\n",
"    \"\"\"Standardize the numbers in ``buff`` in place to zero mean and unit std.\n",
"\n",
"    ``buff`` (the list of rewards in this notebook) is overwritten element by\n",
"    element; nothing is returned. An empty buffer is left untouched. If all\n",
"    values are identical the standard deviation is 0, so the values are only\n",
"    mean-centred instead of being divided by zero, which would otherwise fill\n",
"    the buffer with NaN and corrupt the following gradient update.\n",
"    \"\"\"\n",
"    if len(buff) == 0:\n",
"        return\n",
"    r_mean = np.mean(buff)\n",
"    r_std = np.std(buff)\n",
"    if r_std == 0:\n",
"        r_std = 1.0\n",
"    for ii, value in enumerate(buff):\n",
"        buff[ii] = (value - r_mean) / r_std\n",
"\n",
" \n",
"def plot_durations(episode_durations):\n",
"    \"\"\"Redraw figure 2 with per-episode durations and their 100-episode mean.\n",
"\n",
"    The moving-mean curve is drawn only once at least 100 durations exist;\n",
"    it is left-padded with 99 zeros so it lines up with the raw curve.\n",
"    \"\"\"\n",
"    window = 100  # episodes per moving-average window\n",
"    plt.figure(2)\n",
"    plt.clf()\n",
"    durations = TC.FloatTensor(episode_durations)\n",
"    plt.title('Training...')\n",
"    plt.xlabel('Episode')\n",
"    plt.ylabel('Duration')\n",
"    plt.plot(durations.numpy())\n",
"    if len(durations) >= window:\n",
"        # unfold(0, window, 1) yields every length-`window` slice; average each one\n",
"        moving_mean = durations.unfold(0, window, 1).mean(1).view(-1)\n",
"        padded = TC.cat((TC.zeros(window - 1), moving_mean))\n",
"        plt.plot(padded.numpy())\n",
"    plt.show()\n",
" \n",
" \n",
"def plot_durations_ii(ii, episode_durations, ee, ee_duration=100):\n",
"    \"\"\"Record one episode's length and periodically refresh the plot.\n",
"\n",
"    Args:\n",
"        ii: zero-based step index at which the episode ended (length ii + 1).\n",
"        episode_durations: list of episode lengths; appended to in place.\n",
"        ee: zero-based episode index.\n",
"        ee_duration: clear the cell output and redraw every ``ee_duration`` episodes.\n",
"    \"\"\"\n",
"    episode_durations.append(ii + 1)\n",
"    is_refresh_episode = (ee + 1) % ee_duration == 0\n",
"    if not is_refresh_episode:\n",
"        return\n",
"    clear_output()\n",
"    plot_durations(episode_durations)\n",
" \n",
" \n",
"class PGNET(nn.Module):\n",
"    \"\"\"Plain policy-gradient network mapping a state vector to one value in (0, 1).\n",
"\n",
"    Linear(num_state, 24) -> ReLU -> Linear(24, 36) -> ReLU -> Linear(36, 1)\n",
"    -> sigmoid. PGNET_MACHINE uses the output as a Bernoulli action probability.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_state):\n",
"        super().__init__()\n",
"        self.fc_in = nn.Linear(num_state, 24)\n",
"        self.fc_hidden = nn.Linear(24, 36)\n",
"        self.fc_out = nn.Linear(36, 1)\n",
"\n",
"    def forward(self, x):\n",
"        hidden = F.relu(self.fc_hidden(F.relu(self.fc_in(x))))\n",
"        return TC.sigmoid(self.fc_out(hidden))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class PGNET_MACHINE(PGNET):\n",
"    \"\"\"PGNET plus the plumbing needed to train it with REINFORCE.\n",
"\n",
"    Transitions (reward, state, action) are buffered step by step. After every\n",
"    ``num_batch`` completed episodes the buffered rewards are converted into\n",
"    normalized discounted returns and one RMSprop step is taken on the summed\n",
"    policy-gradient loss of the batch.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_state, render_flag=False):\n",
"        # Initialize nn.Module first so every later attribute assignment goes\n",
"        # through a fully constructed module.\n",
"        super(PGNET_MACHINE, self).__init__(num_state)\n",
"        self.forget_factor = 0.99  # discount factor (gamma) for future rewards\n",
"        self.learning_rate = 0.01\n",
"        self.num_episode = 5000\n",
"        self.num_batch = 5  # episodes collected per parameter update\n",
"        self.plot_every = 100  # refresh the duration plot every N episodes\n",
"        self.render_flag = render_flag\n",
"        self.steps_in_batch = 0\n",
"        self.episode_durations = []\n",
"\n",
"        self.optimizer = TC.optim.RMSprop(self.parameters(), lr=self.learning_rate)\n",
"        self.init_buff()\n",
"\n",
"    def forward(self, state):\n",
"        \"\"\"Return the Bernoulli action distribution for a single observation.\n",
"\n",
"        ``state`` is converted to a float32 tensor and fed to PGNET.forward; the\n",
"        sigmoid output ``p`` is the probability of sampling action 1.\n",
"        ``log_prob(a)`` of the result is ``a*log(p) + (1-a)*log(1-p)``,\n",
"        i.e. the log-probability of the action that was taken.\n",
"        \"\"\"\n",
"        state_var = TC.as_tensor(state, dtype=TC.float32)\n",
"        prob = super(PGNET_MACHINE, self).forward(state_var)\n",
"        return Bernoulli(prob)\n",
"\n",
"    def push_buff_done(self, reward, state, action, done_flag=False):\n",
"        \"\"\"Buffer one transition; a terminal step is stored with reward 0.\"\"\"\n",
"        # The stored 0 divides episodes for transform_discount_reward, so the\n",
"        # terminal step's real reward is intentionally discarded.\n",
"        self.reward_buff.append(0 if done_flag else reward)\n",
"        self.state_buff.append(state)\n",
"        self.action_buff.append(action)\n",
"\n",
"    def pull_buff(self, ii):\n",
"        \"\"\"Return the (reward, state, action) stored at buffer index ``ii``.\"\"\"\n",
"        return self.reward_buff[ii], self.state_buff[ii], self.action_buff[ii]\n",
"\n",
"    def init_buff(self):\n",
"        \"\"\"Reset the reward, state and action buffers to empty lists.\"\"\"\n",
"        self.reward_buff = []\n",
"        self.state_buff = []\n",
"        self.action_buff = []\n",
"\n",
"    def transform_discount_reward(self, steps):\n",
"        \"\"\"Turn the first ``steps`` rewards into discounted returns, then standardize.\n",
"\n",
"        Walking backwards, a stored 0 (episode boundary) resets the running\n",
"        return. NOTE(review): this assumes genuine per-step rewards are never 0,\n",
"        which presumably holds for CartPole's per-step reward; confirm before\n",
"        reusing with other environments.\n",
"        \"\"\"\n",
"        future_reward = 0\n",
"        for ii in reversed(range(steps)):\n",
"            if self.reward_buff[ii] == 0:\n",
"                future_reward = 0\n",
"            else:\n",
"                future_reward = future_reward * self.forget_factor + self.reward_buff[ii]\n",
"            self.reward_buff[ii] = future_reward\n",
"        list_norm_inplace(self.reward_buff)\n",
"\n",
"    def train(self, steps=True):\n",
"        \"\"\"Run one policy-gradient update over the first ``steps`` buffered transitions.\n",
"\n",
"        This name shadows ``nn.Module.train(mode)``, which ``nn.Module.eval()``\n",
"        invokes as ``self.train(False)``. A bool argument is therefore forwarded\n",
"        to ``nn.Module.train`` so ``.train()`` / ``.eval()`` keep toggling the\n",
"        training mode instead of consuming and discarding the buffer.\n",
"        \"\"\"\n",
"        if isinstance(steps, bool):\n",
"            return super(PGNET_MACHINE, self).train(steps)\n",
"\n",
"        self.transform_discount_reward(steps)\n",
"\n",
"        self.optimizer.zero_grad()\n",
"        for ii in range(steps):\n",
"            reward, state, action = self.pull_buff(ii)\n",
"            action_var = TC.FloatTensor([float(action)])\n",
"            policy = self.forward(state)\n",
"            # REINFORCE: minimize -log pi(a|s) * normalized return\n",
"            loss = -policy.log_prob(action_var) * reward\n",
"            loss.backward()  # gradients accumulate over the whole batch\n",
"        self.optimizer.step()\n",
"\n",
"        self.init_buff()\n",
"\n",
"    def step(self, env, state, ee, ii, ee_duration=100):\n",
"        \"\"\"Sample an action, advance ``env`` one step and buffer the transition.\n",
"\n",
"        ``ee``, ``ii`` and ``ee_duration`` are accepted for interface\n",
"        compatibility but unused. Expects ``env.step`` to return the 4-tuple\n",
"        (obs, reward, done, info). Returns ``(next_state, done_flag)``.\n",
"        \"\"\"\n",
"        policy = self.forward(state)\n",
"        action = int(policy.sample().item())\n",
"\n",
"        next_state, reward, done_flag, _ = env.step(action)\n",
"        if self.render_flag:\n",
"            env.render()\n",
"        self.push_buff_done(reward, state, action, done_flag)\n",
"\n",
"        self.steps_in_batch += 1\n",
"        return next_state, done_flag\n",
"\n",
"    def run_episode(self, env, ee):\n",
"        \"\"\"Play one episode to termination, then record (and maybe plot) its length.\"\"\"\n",
"        state = env.reset()\n",
"        for ii in count():\n",
"            state, done_flag = self.step(env, state, ee, ii, ee_duration=self.plot_every)\n",
"            if done_flag:\n",
"                plot_durations_ii(ii, self.episode_durations, ee, ee_duration=self.plot_every)\n",
"                break\n",
"\n",
"    def train_episode(self, ee):\n",
"        \"\"\"Update the policy after every ``num_batch`` completed episodes.\n",
"\n",
"        Counting completed episodes (``ee + 1``) makes every batch exactly\n",
"        ``num_batch`` episodes long (0..4, 5..9, ...) and, when ``num_episode``\n",
"        is a multiple of ``num_batch``, also trains on the final batch.\n",
"        \"\"\"\n",
"        if (ee + 1) % self.num_batch == 0:\n",
"            self.train(self.steps_in_batch)\n",
"            self.steps_in_batch = 0\n",
"\n",
"    def run(self, env):\n",
"        \"\"\"Run ``num_episode`` episodes, updating the policy after each full batch.\"\"\"\n",
"        for ee in range(self.num_episode):\n",
"            self.run_episode(env, ee)\n",
"            self.train_episode(ee)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def main():\n",
"    \"\"\"Train a PGNET_MACHINE agent on CartPole-v0, always closing the environment.\"\"\"\n",
"    env = gym.make('CartPole-v0')\n",
"    try:\n",
"        mypgnet = PGNET_MACHINE(env.observation_space.shape[0], render_flag=False)\n",
"        mypgnet.run(env)\n",
"    finally:\n",
"        # Release the environment even if training raises or is interrupted\n",
"        # (e.g. a KeyboardInterrupt while the training loop is running).\n",
"        env.close()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"main()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pytorch",
"language": "python",
"name": "pytorch"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment