Created
May 10, 2024 21:02
-
-
Save ziyan0302/e90d8329323df3da458055199f409b9e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import scipy.signal | |
from gym.spaces import Box, Discrete | |
import torch | |
import torch.nn as nn | |
from torch.distributions.normal import Normal | |
from torch.distributions.categorical import Categorical | |
import torch.nn.functional as F | |
import pdb | |
def combined_shape(length, shape=None):
    """Build a buffer shape tuple: `length` rows, each of shape `shape`.

    Returns (length,) when shape is None, (length, shape) for a scalar
    shape, and (length, *shape) for a tuple/list shape.
    """
    if shape is None:
        return (length,)
    if np.isscalar(shape):
        return (length, shape)
    return (length, *shape)
def mlp(sizes, activation, output_activation=nn.Identity):
    """Build a feed-forward stack of Linear layers per `sizes`.

    Hidden layers are followed by `activation`; the final layer by
    `output_activation` (identity by default).
    """
    modules = []
    last_idx = len(sizes) - 2
    for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        modules.append(nn.Linear(n_in, n_out))
        modules.append(activation() if idx < last_idx else output_activation())
    return nn.Sequential(*modules)
def count_vars(module):
    """Total number of scalar parameters in `module`."""
    return sum(np.prod(p.shape) for p in module.parameters())
def discount_cumsum(x, discount):
    """Discounted cumulative sums along axis 0 (rllab trick).

    input:
        vector x = [x0, x1, x2]
    output:
        [x0 + discount * x1 + discount^2 * x2,
         x1 + discount * x2,
         x2]

    Implemented as an IIR filter over the reversed sequence, then
    reversed back, so the whole pass runs in C.
    """
    reversed_x = x[::-1]
    filtered = scipy.signal.lfilter([1], [1, float(-discount)], reversed_x, axis=0)
    return filtered[::-1]
class Actor(nn.Module):
    """Abstract policy head: subclasses supply the action distribution."""

    def _distribution(self, obs):
        """Return a torch distribution over actions for `obs`."""
        raise NotImplementedError

    def _log_prob_from_distribution(self, pi, act):
        """Return the log-probability of `act` under distribution `pi`."""
        raise NotImplementedError

    def forward(self, obs, act=None):
        """Build the action distribution for `obs`.

        When `act` is given, also compute its log-likelihood under
        that distribution; otherwise the second return value is None.
        """
        dist = self._distribution(obs)
        if act is None:
            return dist, None
        return dist, self._log_prob_from_distribution(dist, act)
class MLPCategoricalActor(Actor):
    """Actor with a categorical (discrete-action) policy over logits."""

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        # MLP mapping an observation to one logit per discrete action.
        self.logits_net = mlp([obs_dim] + list(hidden_sizes) + [act_dim], activation)

    def _distribution(self, obs):
        return Categorical(logits=self.logits_net(obs))

    def _log_prob_from_distribution(self, pi, act):
        # Categorical.log_prob already yields one scalar per sample.
        return pi.log_prob(act)
class MLPGaussianActor(Actor):
    """Actor with a diagonal Gaussian policy.

    A single MLP outputs 2*act_dim values per observation: the first
    act_dim entries are the mean, the last act_dim entries are the
    log-std (exponentiated to obtain the std).
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, fixed_var=True):
        """
        Args:
            obs_dim: observation dimension.
            act_dim: action dimension.
            hidden_sizes: iterable of hidden-layer widths.
            activation: hidden activation class (e.g. nn.Tanh).
            fixed_var: kept for backward compatibility; currently unused —
                the std is always predicted by the network.
        """
        super().__init__()
        # Final layer emits act_dim means followed by act_dim log-stds.
        self.net = mlp([obs_dim] + list(hidden_sizes) + [act_dim * 2], activation)
        self.fixed_var = fixed_var

    def _distribution(self, obs, fixed_var=True):
        mu_std = self.net(obs)
        # Split the last dimension in half: means, then log-stds.
        # Ellipsis indexing works for any batch rank (the original code
        # only handled 1-D and 2-D outputs and raised UnboundLocalError
        # for anything else).
        act_dim = mu_std.shape[-1] // 2
        mu = mu_std[..., :act_dim]
        std = torch.exp(mu_std[..., act_dim:])
        return Normal(mu, std)

    def _log_prob_from_distribution(self, pi, act):
        # Sum per-dimension log-probs so each action gets one scalar;
        # needed because torch's Normal is elementwise.
        return pi.log_prob(act).sum(axis=-1)
class MLPCritic(nn.Module):
    """State-value function approximated by an MLP with a scalar head."""

    def __init__(self, obs_dim, hidden_sizes, activation):
        super().__init__()
        # MLP mapping an observation to a single value estimate.
        self.v_net = mlp([obs_dim] + list(hidden_sizes) + [1], activation)

    def forward(self, obs):
        # Drop the trailing unit dim so values have shape (batch,),
        # matching what downstream loss code expects.
        return torch.squeeze(self.v_net(obs), -1)
class MLPActorCritic(nn.Module):
    """Gaussian-policy actor plus MLP value critic behind one interface."""

    def __init__(self, obs_dim, act_dim,
                 hidden_sizes=(64, 64), activation=nn.Tanh):
        super().__init__()
        # Continuous-action policy head (diagonal Gaussian).
        self.pi = MLPGaussianActor(obs_dim, act_dim, hidden_sizes, activation, True)
        # State-value estimator.
        self.v = MLPCritic(obs_dim, hidden_sizes, activation)

    def step(self, obs):
        """Sample an action for `obs`; also return value and log-prob.

        Runs without gradient tracking; every output is a numpy array.
        """
        with torch.no_grad():
            dist = self.pi._distribution(obs)
            action = dist.sample()
            logp = self.pi._log_prob_from_distribution(dist, action)
            value = self.v(obs)
        return action.cpu().numpy(), value.cpu().numpy(), logp.cpu().numpy()

    def act(self, obs):
        """Convenience wrapper: sample an action only."""
        return self.step(obs)[0]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment