@cyoon1729
Created November 23, 2019 16:54
"""
# Initialize alpha & associated variables as such in the beginning:
self.alpha = alpha
self.target_entropy = -torch.prod(torch.Tensor(self.env.action_space.shape).to(self.device)).item()
self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
"""
# Sample new actions from the current policy and get their log-probabilities.
new_means, new_stds, new_zs, new_log_pis = self.policy_net.sample(states)
# Temperature loss: detach the entropy term so gradients flow only into
# log_alpha, not back through the policy's log-probabilities.
alpha_loss = (self.log_alpha * (-new_log_pis - self.target_entropy).detach()).mean()
# Gradient step on log_alpha, then recover alpha = exp(log_alpha) > 0.
self.alpha_optim.zero_grad()
alpha_loss.backward()
self.alpha_optim.step()
self.alpha = self.log_alpha.exp()
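Since the loss is linear in `log_alpha` (the entropy term is detached), its gradient is just the mean of `-log_pi - target_entropy`. A torch-free sketch of one plain-SGD step makes the feedback loop visible; `alpha_step` and its arguments are illustrative names, not part of the gist:

```python
import math

def alpha_step(log_alpha, log_pis, target_entropy, lr=1e-3):
    # d(alpha_loss)/d(log_alpha) = mean(-log_pi - target_entropy),
    # because the detached entropy term is a constant w.r.t. log_alpha.
    grad = sum(-lp - target_entropy for lp in log_pis) / len(log_pis)
    log_alpha = log_alpha - lr * grad  # one SGD step (the gist uses an optimizer)
    return log_alpha, math.exp(log_alpha)
```

When the policy's entropy (`-mean(log_pi)`) falls below the target, the gradient is negative, so `log_alpha` and hence `alpha` grow, strengthening the entropy bonus; when entropy exceeds the target, `alpha` shrinks.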