@djbyrne
Created October 7, 2020 07:27
import torch
from torch.nn.functional import log_softmax, softmax


def loss(self, states, actions, scaled_rewards) -> torch.Tensor:
    logits = self.net(states)

    # policy loss: negative log-probability of each taken action,
    # weighted by its scaled return (REINFORCE objective)
    log_prob = log_softmax(logits, dim=1)
    log_prob_actions = scaled_rewards * log_prob[range(self.batch_size), actions[0]]
    policy_loss = -log_prob_actions.mean()

    # entropy bonus: penalise low-entropy (overly deterministic)
    # policies to encourage exploration
    prob = softmax(logits, dim=1)
    entropy = -(prob * log_prob).sum(dim=1).mean()
    entropy_loss = -self.entropy_beta * entropy

    # total loss = policy loss minus entropy bonus (entropy_loss is negative)
    loss = policy_loss + entropy_loss
    return loss
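The same entropy-regularised policy-gradient loss can be sketched as a standalone function, free of the gist's class attributes. The function name `pg_loss`, the default `entropy_beta` value, and the toy batch below are assumptions made for this sketch, not part of the original gist:

```python
import torch
from torch.nn.functional import log_softmax, softmax


def pg_loss(logits, actions, scaled_rewards, entropy_beta=0.01):
    """Entropy-regularised REINFORCE loss for one batch of transitions."""
    batch_size = logits.shape[0]
    log_prob = log_softmax(logits, dim=1)
    # log-prob of each taken action, weighted by its scaled return
    log_prob_actions = scaled_rewards * log_prob[range(batch_size), actions]
    policy_loss = -log_prob_actions.mean()
    # entropy bonus encourages exploration
    prob = softmax(logits, dim=1)
    entropy = -(prob * log_prob).sum(dim=1).mean()
    return policy_loss - entropy_beta * entropy


# toy batch: 4 states, 3 discrete actions
logits = torch.randn(4, 3)
actions = torch.tensor([0, 2, 1, 0])
scaled_rewards = torch.tensor([1.0, -0.5, 0.3, 2.0])
print(pg_loss(logits, actions, scaled_rewards))
```

A sanity check: with uniform logits every action has probability 1/3, so the policy term equals the mean scaled reward times log 3, and the entropy term reaches its maximum of log 3.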