Created November 23, 2019 16:54
-
-
Save cyoon1729/8d53762194861dcfb1b6c809fd41cec5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
# Initialize alpha & associated variables as such in the beginning: | |
self.alpha = alpha | |
self.target_entropy = -torch.prod(torch.Tensor(self.env.action_space.shape).to(self.device)).item() | |
self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device) | |
""" | |
# Compute log_pi for the new actions sampled. | |
new_means, new_stds, new_zs, new_log_pis = self.policy_net.sample(states) | |
# Compute alpha loss. | |
alpha_loss = (self.log_alpha * (-new_log_pis - self.target_entropy).detach()).mean() | |
# Update alpha. | |
self.alpha_optim.zero_grad() | |
alpha_loss.backward() | |
self.alpha_optim.step() | |
self.alpha = self.log_alpha.exp() |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.