import ray
import random
import numpy as np
from collections import deque

# To make a Python object shareable across Ray workers,
# we just have to add the ray.remote decorator on top.
@ray.remote
class Storage(object):
    # Minimal replay-buffer body; the exact methods here are a sketch, not the full class.
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    def push(self, transition):
        self.buffer.append(transition)
    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)
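To show how the actor might be used (a sketch only; push and sample mirror the minimal body above and are assumptions, not necessarily the post's exact API), note that every method call on a Ray actor goes through .remote() and returns a future that ray.get resolves:

ray.init()

# One Storage actor lives in its own process; every worker holds the same handle.
storage = Storage.remote(capacity=100_000)

# Workers push transitions without blocking on each other; .remote() returns a future immediately.
for step in range(500):
    transition = (np.zeros(3), np.zeros(1), 0.0, np.zeros(3), False)  # dummy (s, a, r, s', done)
    storage.push.remote(transition)

# The learner pulls a batch; ray.get blocks until the remote call has finished on the actor.
batch = ray.get(storage.sample.remote(256))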
""" | |
# Initialize alpha & associated variables as such in the beginning: | |
self.alpha = alpha | |
self.target_entropy = -torch.prod(torch.Tensor(self.env.action_space.shape).to(self.device)).item() | |
self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device) | |
""" | |
# Compute log_pi for the new actions sampled. | |
new_means, new_stds, new_zs, new_log_pis = self.policy_net.sample(states) |
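The snippet stops right after the policy sample. For completeness, here is a rough sketch of the standard SAC temperature update from this point (self.alpha_optimizer, assumed to be built over [self.log_alpha], is not shown above):

# Temperature loss: push log_alpha so that the policy's entropy tracks target_entropy.
# (Sketch of the standard SAC update; self.alpha_optimizer wrapping [self.log_alpha] is an assumption.)
alpha_loss = -(self.log_alpha * (new_log_pis + self.target_entropy).detach()).mean()
self.alpha_optimizer.zero_grad()
alpha_loss.backward()
self.alpha_optimizer.step()
self.alpha = self.log_alpha.exp().item()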
mean, log_std = self.forward(state)
std = log_std.exp()
normal = Normal(mean, std)
# Reparameterization trick: z = mean + std * noise, so gradients flow through the sample.
z = normal.rsample()
# Gaussian log-prob, corrected for the tanh squashing applied to z (change of variables).
log_pi = (normal.log_prob(z) - torch.log(1 - torch.tanh(z).pow(2) + epsilon)).sum(1, keepdim=True)
return mean, std, z, log_pi
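The subtracted term is the change-of-variables correction: the executed action is a = tanh(z), so the squashed policy's log-density picks up the log-derivative of tanh, and epsilon only guards against log(0). In equation form:

\log \pi(a \mid s) = \log \mathcal{N}(z;\, \mu, \sigma^2) - \sum_i \log\!\bigl(1 - \tanh^2(z_i) + \epsilon\bigr)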
# Obtain new actions for the current states given our updated Q network.
new_means, new_stds, new_zs, new_log_pis = self.policy_net.sample(states)
new_actions = torch.tanh(new_zs)
# Compute policy loss.
new_q = self.q_net.forward(states, new_actions)
policy_loss = (self.alpha * new_log_pis - new_q).mean()
# Update policy parameters.
self.policy_optimizer.zero_grad()
policy_loss.backward()
self.policy_optimizer.step()
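Written out, this is the usual SAC policy objective, minimized over the policy parameters:

J_\pi = \mathbb{E}_{s \sim \mathcal{D},\, a \sim \pi}\bigl[\alpha \log \pi(a \mid s) - Q(s, a)\bigr]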
# Assume we've already sampled states, actions, rewards, is_terminal_step, next_states
# from the replay buffer.
# Sample new actions for the next states.
next_means, next_stds, next_zs, next_log_pi = self.policy_net.sample(next_states)
next_actions = torch.tanh(next_zs)
# Compute bootstrapped Q values, with the entropy bonus folded into the target.
next_q = self.target_q_net(next_states, next_actions)
next_q_target = next_q - self.alpha * next_log_pi
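The gist cuts off at the entropy-regularized target. A hedged sketch of how this typically feeds into the critic update (names like self.gamma and self.q_optimizer are assumptions, not necessarily the post's):

# Bellman target: r + gamma * (Q_target(s', a') - alpha * log pi(a'|s')), zeroed at terminal steps.
expected_q = rewards + (1 - is_terminal_step) * self.gamma * next_q_target
# Critic loss: regress the online Q network toward the detached target.
q = self.q_net(states, actions)
q_loss = ((q - expected_q.detach()) ** 2).mean()
self.q_optimizer.zero_grad()
q_loss.backward()
self.q_optimizer.step()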