# Gist by @willtryagain, created November 12, 2021 10:15
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal
import numpy as np
import gym  # requires OpenAI Gym
import sklearn
import sklearn.preprocessing

# MountainCarContinuous-v0: 2-D observation (position, velocity) and a
# 1-D continuous action in [-1, 1].
env = gym.envs.make("MountainCarContinuous-v0")
class ValueFunction(nn.Module):
    """Critic: maps a state to a scalar value estimate V(s)."""

    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(env.observation_space.shape[0], 400)
        self.hidden2 = nn.Linear(400, 400)
        # The critic outputs a single scalar; env.action_space.n does not
        # exist for a continuous (Box) action space.
        self.V = nn.Linear(400, 1)
        # action & reward buffer
        self.saved_actions = []
        self.rewards = []

    def forward(self, x):
        """Return the estimated state value V(x)."""
        x = F.elu(self.hidden1(x))
        x = F.elu(self.hidden2(x))
        x = self.V(x)
        return x
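
# Hedged sanity check (ours, not part of the original gist): the critic maps
# a batch of 2-D states to one scalar value per state.
_demo_states = torch.randn(4, env.observation_space.shape[0])
assert ValueFunction()(_demo_states).shape == (4, 1)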
class Policy(nn.Module):
    """Actor: maps a state to a Gaussian policy over the 1-D action."""

    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(env.observation_space.shape[0], 40)
        self.hidden2 = nn.Linear(40, 40)
        self.mu = nn.Linear(40, 1)
        self.sigma = nn.Linear(40, 1)
        # action & reward buffer
        self.saved_actions = []
        self.rewards = []

    def forward(self, x):
        """Sample a clamped action and return it with its Normal distribution."""
        x = F.elu(self.hidden1(x))
        x = F.elu(self.hidden2(x))
        mu = self.mu(x)
        # F.softplus keeps sigma positive; nn.Softplus is a module class and
        # cannot be applied to a tensor directly.
        sigma = F.softplus(self.sigma(x)) + 1e-5
        # Build a Normal distribution object (torch.normal would only draw a
        # sample, which is useless for computing log-probabilities later).
        norm_dist = Normal(mu, sigma)
        action = norm_dist.sample()
        action = torch.clamp(action, env.action_space.low[0],
                             env.action_space.high[0])
        return action, norm_dist
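
# Hedged sanity check (ours): a single state yields one action inside the
# environment's bounds, plus the Normal distribution it was drawn from.
_a, _d = Policy()(torch.randn(1, env.observation_space.shape[0]))
assert env.action_space.low[0] <= float(_a) <= env.action_space.high[0]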
def init_weights(m):
    if isinstance(m, nn.Linear):
        # xavier_uniform_ is the in-place initializer; the trailing-underscore
        # form replaces the deprecated xavier_uniform.
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
# Fit a standard scaler on sampled states so network inputs are normalized.
state_space_samples = np.array(
    [env.observation_space.sample() for x in range(10000)])
scaler = sklearn.preprocessing.StandardScaler()
scaler.fit(state_space_samples)

def scale_state(state):  # expects input of shape (2,)
    scaled = scaler.transform([state])
    return scaled
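
# Hedged usage example (ours), assuming the classic Gym API where env.reset()
# returns only the observation: scale_state returns shape (1, 2), ready to be
# wrapped with torch.as_tensor.
assert scale_state(env.reset()).shape == (1, env.observation_space.shape[0])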
lr_actor = 0.00002  # actor learning rate
lr_critic = 0.001   # critic learning rate
gamma = 0.99        # discount factor
num_episodes = 300

value_func = ValueFunction()
value_func.apply(init_weights)
policy = Policy()
policy.apply(init_weights)
episode_history = []
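
# The gist ends here, before the training loop. Below is a hedged sketch
# (ours, not from the original author) of a one-step actor-critic loop
# consistent with the pieces above: the critic's TD error serves as the
# advantage for the policy-gradient update. The optimizer choice (Adam) and
# loss weighting are assumptions, and the classic Gym API (4-tuple step,
# observation-only reset) is assumed throughout.
actor_optimizer = optim.Adam(policy.parameters(), lr=lr_actor)
critic_optimizer = optim.Adam(value_func.parameters(), lr=lr_critic)

for episode in range(num_episodes):
    state = scale_state(env.reset())
    episode_reward = 0.0
    done = False
    while not done:
        state_t = torch.as_tensor(state, dtype=torch.float32)
        action, norm_dist = policy(state_t)
        next_state, reward, done, _ = env.step(action.detach().numpy().ravel())
        next_state = scale_state(next_state)
        episode_reward += reward

        # One-step TD target and advantage (bootstrap from the next state).
        value = value_func(state_t)
        next_value = value_func(torch.as_tensor(next_state, dtype=torch.float32))
        target = reward + gamma * next_value.detach() * (0.0 if done else 1.0)
        advantage = target - value

        # Critic update: regress V(s) toward the TD target.
        critic_loss = advantage.pow(2).mean()
        critic_optimizer.zero_grad()
        critic_loss.backward()
        critic_optimizer.step()

        # Actor update: policy gradient weighted by the detached TD error.
        actor_loss = -norm_dist.log_prob(action) * advantage.detach()
        actor_optimizer.zero_grad()
        actor_loss.mean().backward()
        actor_optimizer.step()

        state = next_state

    episode_history.append(episode_reward)
    print(f"episode {episode}: reward {episode_reward:.2f}")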