import ray
import random
import numpy as np
from collections import deque
# To make a Python object share-able across processes,
# we just have to add the ray.remote decorator on top.
@ray.remote
class Storage(object):
    # Minimal sketch of the storage body (not shown in the original snippet):
    # a bounded buffer that workers push transitions into and the learner samples from.
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    def push(self, transition):
        self.buffer.append(transition)
    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)
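
# Usage sketch (the names `capacity` and `batch_size` are illustrative, not from
# the original snippet): instantiating the decorated class with .remote() starts
# it as a Ray actor in its own process; methods are invoked with .remote() and
# return object refs that ray.get() resolves.
ray.init()
storage = Storage.remote(capacity=100000)
storage.push.remote(("state", "action", 1.0, "next_state", False))
batch = ray.get(storage.sample.remote(batch_size=1))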
"""
# Initialize alpha & associated variables as such in the beginning:
self.alpha = alpha
self.target_entropy = -torch.prod(torch.Tensor(self.env.action_space.shape).to(self.device)).item()
self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
"""

# Compute log_pi for the new actions sampled. The call
#   new_means, new_stds, new_zs, new_log_pis = self.policy_net.sample(states)
# runs the following inside the policy network's sample() method:
from torch.distributions import Normal
epsilon = 1e-6  # small constant for numerical stability in the log

mean, log_std = self.forward(state)
std = log_std.exp()
normal = Normal(mean, std)
z = normal.rsample()
log_pi = (normal.log_prob(z) - torch.log(1 - torch.tanh(z).pow(2) + epsilon)).sum(1, keepdim=True)
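
# Why the correction term: the action is a = tanh(z) with z ~ Normal(mean, std),
# so by the change-of-variables formula the squashed log-density is
#   log pi(a|s) = log N(z; mean, std) - sum_i log(1 - tanh(z_i)^2),
# and epsilon guards against log(0) when tanh(z) saturates at +/-1.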

# Obtain new actions for the current states given our updated Q network.
new_means, new_stds, new_zs, new_log_pis = self.policy_net.sample(states)
new_actions = torch.tanh(new_zs)

# Compute policy loss.
new_q = self.q_net.forward(states, new_actions)
policy_loss = (self.alpha * new_log_pis - new_q).mean()

# Update policy parameters.
self.policy_optimizer.zero_grad()
policy_loss.backward()
self.policy_optimizer.step()
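
# With new_log_pis in hand, alpha itself can be updated against the target
# entropy set earlier. A hedged sketch, not from the original snippet; it
# assumes the self.log_alpha / self.target_entropy / self.alpha_optimizer
# names introduced above.
alpha_loss = -(self.log_alpha * (new_log_pis + self.target_entropy).detach()).mean()
self.alpha_optimizer.zero_grad()
alpha_loss.backward()
self.alpha_optimizer.step()
self.alpha = self.log_alpha.exp().item()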

# Assume we've already sampled states, actions, rewards, is_terminal_step,
# and next_states from the replay buffer.

# Sample new actions for the next states.
next_means, next_stds, next_zs, next_log_pi = self.policy_net.sample(next_states)
next_actions = torch.tanh(next_zs)

# Compute the bootstrapped (entropy-regularized) Q values.
next_q = self.target_q_net(next_states, next_actions)
next_q_target = next_q - self.alpha * next_log_pi
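
# From here the usual SAC critic update follows. A hedged sketch, not from the
# original snippet: it assumes a discount factor self.gamma, a critic optimizer
# self.q_optimizer, and a soft-update coefficient tau, none of which appear above.
import torch.nn.functional as F

expected_q = rewards + (1 - is_terminal_step) * self.gamma * next_q_target
curr_q = self.q_net.forward(states, actions)
q_loss = F.mse_loss(curr_q, expected_q.detach())

self.q_optimizer.zero_grad()
q_loss.backward()
self.q_optimizer.step()

# Soft (Polyak) update of the target Q network toward the online Q network.
tau = 0.005  # hypothetical value
for target_param, param in zip(self.target_q_net.parameters(), self.q_net.parameters()):
    target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)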