PG for Cartpole
# mostly from https://github.com/Finspire13/pytorch-policy-gradient-example/blob/master/pg.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Bernoulli
from itertools import count
import matplotlib.pyplot as plt
import numpy as np
import gymnasium as gym
class PolicyNet(nn.Module):
    def __init__(self):
        super(PolicyNet, self).__init__()
        self.fc1 = nn.Linear(4, 24)
        self.fc2 = nn.Linear(24, 36)
        self.fc3 = nn.Linear(36, 1)  # outputs P(action = 1), i.e. push right

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        return x
def main():
    # Plot duration curve:
    # From http://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html
    episode_durations = []

    def plot_durations():
        plt.figure(2)
        plt.clf()
        durations_t = torch.FloatTensor(episode_durations)
        plt.title('Training...')
        plt.xlabel('Episode')
        plt.ylabel('Duration')
        plt.plot(durations_t.numpy())
        # Take 100-episode averages and plot them too
        if len(durations_t) >= 100:
            means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
            means = torch.cat((torch.zeros(99), means))
            plt.plot(means.numpy())
        plt.pause(0.001)  # pause a bit so that plots are updated
    # Parameters
    num_episode = 5000
    batch_size = 5
    learning_rate = 0.01
    gamma = 0.99

    env = gym.make('CartPole-v1', render_mode="rgb_array")
    policy_net = PolicyNet()
    optimizer = torch.optim.RMSprop(policy_net.parameters(), lr=learning_rate)

    # Batch history
    state_pool = []
    action_pool = []
    reward_pool = []
    steps = 0
    for e in range(num_episode):
        state, info = env.reset()
        state = torch.from_numpy(state).float()

        for t in count():
            probs = policy_net(state)
            m = Bernoulli(probs)
            # action is either left (0) or right (1)
            action = m.sample()
            action = int(action.item())

            # reward is 1 per timestep; overwritten with 0 at the terminal step below
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            if e % 100 == 0:
                plt.figure(1)
                plt.clf()
                plt.imshow(env.render())
                plt.pause(0.01)

            # To mark boundaries between episodes
            if done:
                reward = 0

            state_pool.append(state)
            action_pool.append(float(action))
            reward_pool.append(reward)

            state = torch.from_numpy(next_state).float()
            steps += 1

            if done:
                episode_durations.append(t + 1)
                plot_durations()
                break
        # Update policy
        if e > 0 and e % batch_size == 0:
            # Accumulate and discount rewards
            running_add = 0
            for i in reversed(range(steps)):
                if reward_pool[i] == 0:
                    running_add = 0
                else:
                    running_add = running_add * gamma + reward_pool[i]
                    reward_pool[i] = running_add

            # Normalize rewards
            reward_mean = np.mean(reward_pool)
            reward_std = np.std(reward_pool)
            for i in range(steps):
                reward_pool[i] = (reward_pool[i] - reward_mean) / reward_std

            # Gradient descent
            optimizer.zero_grad()
            for i in range(steps):
                state = state_pool[i]
                action = torch.FloatTensor([action_pool[i]])
                reward = reward_pool[i]

                probs = policy_net(state)
                m = Bernoulli(probs)
                # REINFORCE: negative log-probability of the taken action, weighted by its return
                loss = -m.log_prob(action) * reward
                loss.backward()

            optimizer.step()

            # Clear history
            state_pool = []
            action_pool = []
            reward_pool = []
            steps = 0
if __name__ == '__main__':
    main()
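A quick way to check the learned behavior is to roll the policy out greedily instead of sampling. The helper below is not part of the original gist; it is a minimal sketch that assumes a trained PolicyNet instance is passed in, for example by calling evaluate(policy_net) at the end of main().

# Optional sketch (not in the original gist): greedy evaluation of a trained policy.
import torch
import gymnasium as gym

def evaluate(policy_net, episodes=5, render=False):
    env = gym.make('CartPole-v1', render_mode="human" if render else None)
    for _ in range(episodes):
        state, info = env.reset()
        total_reward = 0.0
        done = False
        while not done:
            with torch.no_grad():
                p_right = policy_net(torch.from_numpy(state).float())
            action = int(p_right.item() > 0.5)  # pick the more likely action instead of sampling
            state, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            done = terminated or truncated
        print(f"episode return: {total_reward}")
    env.close()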