# Gist by @ziyan0302, created May 10, 2024
import numpy as np
import scipy.signal
from gym.spaces import Box, Discrete
import torch
import torch.nn as nn
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical
import torch.nn.functional as F
import pdb
def combined_shape(length, shape=None):
    if shape is None:
        return (length,)
    return (length, shape) if np.isscalar(shape) else (length, *shape)
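# For example: combined_shape(100) -> (100,), combined_shape(100, 3) -> (100, 3),
# and combined_shape(100, (4, 2)) -> (100, 4, 2), i.e. buffer shapes for 100 timesteps.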
def mlp(sizes, activation, output_activation=nn.Identity):
    layers = []
    for j in range(len(sizes)-1):
        act = activation if j < len(sizes)-2 else output_activation
        layers += [nn.Linear(sizes[j], sizes[j+1]), act()]
    return nn.Sequential(*layers)
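# For example, mlp([4, 64, 64, 2], nn.Tanh) builds
# Linear(4, 64) -> Tanh -> Linear(64, 64) -> Tanh -> Linear(64, 2) -> Identity.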
def count_vars(module):
    return sum([np.prod(p.shape) for p in module.parameters()])
def discount_cumsum(x, discount):
    """
    Magic from rllab for computing discounted cumulative sums of vectors.

    input:
        vector x,
        [x0,
         x1,
         x2]

    output:
        [x0 + discount * x1 + discount^2 * x2,
         x1 + discount * x2,
         x2]
    """
    return scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1]
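# Worked example: discount_cumsum(np.array([1., 1., 1.]), 0.9)
# -> [1 + 0.9*1 + 0.81*1, 1 + 0.9*1, 1] = [2.71, 1.9, 1.0].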
class Actor(nn.Module):

    def _distribution(self, obs):
        raise NotImplementedError

    def _log_prob_from_distribution(self, pi, act):
        raise NotImplementedError

    def forward(self, obs, act=None):
        # Produce action distributions for given observations, and
        # optionally compute the log likelihood of given actions under
        # those distributions.
        pi = self._distribution(obs)
        logp_a = None
        if act is not None:
            logp_a = self._log_prob_from_distribution(pi, act)
        return pi, logp_a
class MLPCategoricalActor(Actor):

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        self.logits_net = mlp([obs_dim] + list(hidden_sizes) + [act_dim], activation)

    def _distribution(self, obs):
        logits = self.logits_net(obs)
        return Categorical(logits=logits)

    def _log_prob_from_distribution(self, pi, act):
        return pi.log_prob(act)
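# Example: for a 3-action discrete task, MLPCategoricalActor(obs_dim, 3, (64, 64), nn.Tanh)
# maps an observation to 3 logits; Categorical(logits=...) then samples integer actions,
# and pi.log_prob(act) already returns one log-probability per action (no summing needed).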
class MLPGaussianActor(Actor):

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, fixed_var=True):
        super().__init__()
        # The network outputs both the mean and the log-std of the Gaussian policy,
        # so the final layer has 2 * act_dim units. (Alternatives kept from earlier
        # drafts: a state-independent parameter
        #   self.log_std = torch.nn.Parameter(torch.as_tensor(-0.5 * np.ones(act_dim, dtype=np.float32)))
        # or a separate std head with softplus for a learned variance.)
        self.net = mlp([obs_dim] + list(hidden_sizes) + [act_dim*2], activation)
        self.fixed_var = fixed_var

    def _distribution(self, obs):
        mu_std = self.net(obs)
        if len(mu_std.shape) == 1:
            # Single observation: first half of the output is the mean,
            # second half is the log-std.
            act_dim = mu_std.shape[0] // 2
            mu = mu_std[:act_dim]
            std = torch.exp(mu_std[act_dim:])
        else:
            # Batched observations: split along the last dimension.
            act_dim = mu_std.shape[-1] // 2
            mu = mu_std[..., :act_dim]
            std = torch.exp(mu_std[..., act_dim:])
        return Normal(mu, std)

    def _log_prob_from_distribution(self, pi, act):
        return pi.log_prob(act).sum(axis=-1)    # Last axis sum needed for Torch Normal distribution
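# Shape note: for a batch of observations with shape (B, obs_dim), the returned
# Normal has mean/std of shape (B, act_dim); log_prob(act).sum(axis=-1) therefore
# yields one log-probability per batch element, i.e. shape (B,).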
class MLPCritic(nn.Module):

    def __init__(self, obs_dim, hidden_sizes, activation):
        super().__init__()
        self.v_net = mlp([obs_dim] + list(hidden_sizes) + [1], activation)

    def forward(self, obs):
        return torch.squeeze(self.v_net(obs), -1)   # Critical to ensure v has right shape.
class MLPActorCritic(nn.Module):

    def __init__(self, obs_dim, act_dim,
                 hidden_sizes=(64, 64), activation=nn.Tanh):
        super().__init__()

        # Policy builder. This version always uses a Gaussian policy for
        # continuous actions; the original dispatch on the action space is
        # kept below for reference:
        # if isinstance(action_space, Box):
        self.pi = MLPGaussianActor(obs_dim, act_dim, hidden_sizes, activation, True)
        # elif isinstance(action_space, Discrete):
        #     self.pi = MLPCategoricalActor(obs_dim, action_space.n, hidden_sizes, activation)

        # build value function
        self.v = MLPCritic(obs_dim, hidden_sizes, activation)

    def step(self, obs):
        with torch.no_grad():
            pi = self.pi._distribution(obs)
            ## TODO: del sample()
            a = pi.sample()
            logp_a = self.pi._log_prob_from_distribution(pi, a)
            v = self.v(obs)
        return a.cpu().numpy(), v.cpu().numpy(), logp_a.cpu().numpy()

    def act(self, obs):
        return self.step(obs)[0]
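

# --- Minimal usage sketch (not part of the original gist) ---
# Illustrative smoke test, assuming a continuous-control task with a
# 4-dimensional observation and a 2-dimensional action (both made up here).
if __name__ == "__main__":
    torch.manual_seed(0)
    ac = MLPActorCritic(obs_dim=4, act_dim=2)
    print("number of parameters (pi, v):", count_vars(ac.pi), count_vars(ac.v))

    obs = torch.randn(4)                    # single observation
    a, v, logp_a = ac.step(obs)             # sampled action, value estimate, log-prob
    print("action:", a, "value:", v, "logp:", logp_a)

    batch_obs = torch.randn(8, 4)           # batch of 8 observations
    pi, logp = ac.pi(batch_obs, torch.randn(8, 2))
    print("batched log-probs shape:", logp.shape)   # -> torch.Size([8])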