Gist: rohan-paul/dc0633271db72d9a2ea7fd15a7807e01
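import torch


def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
    """PPO clipped-surrogate policy loss, averaged over unmasked tokens."""
    ## policy gradient loss
    # Ratio r = pi_new / pi_old, computed in log space; masked tokens get r = exp(0) = 1.
    log_ratio = (logprobs - old_logprobs) * mask
    ratio = torch.exp(log_ratio)
    # Unclipped and clipped surrogates, negated so the objective can be minimized.
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
                                         1.0 + self.cliprange)
    # Elementwise max of the negated terms = min of the surrogates (pessimistic bound),
    # then average over valid (mask == 1) tokens only.
    pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
    return pg_loss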
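To see what the clipping and masking do token by token, here is a dependency-free sketch of the same masked clipped-surrogate computation on plain Python lists instead of torch tensors. The name `clipped_pg_loss` is hypothetical, and `cliprange=0.2` is an assumed default (a commonly used PPO value); the gist itself reads `self.cliprange` from the enclosing trainer object.

```python
import math


def clipped_pg_loss(logprobs, old_logprobs, advantages, mask, cliprange=0.2):
    """Masked PPO clipped-surrogate loss on flat lists (hypothetical helper)."""
    total = 0.0
    for lp, old_lp, adv, m in zip(logprobs, old_logprobs, advantages, mask):
        # As in the gist, the log-ratio is masked, so padded tokens get ratio exp(0) = 1.
        ratio = math.exp((lp - old_lp) * m)
        unclipped = -adv * ratio
        clipped = -adv * min(max(ratio, 1.0 - cliprange), 1.0 + cliprange)
        # Keep the larger (pessimistic) loss per token, then zero out padded positions.
        total += max(unclipped, clipped) * m
    # Average over valid tokens only, mirroring `/ mask.sum()`.
    return total / sum(mask)


# Token 0: positive advantage, ratio 1.5 -> clipped to 1.2, per-token loss -1.2.
# Token 1: negative advantage, ratio 0.5 -> clipped to 0.8, per-token loss 0.8.
# Token 2: padding (mask 0), contributes nothing despite its large advantage.
loss = clipped_pg_loss(
    logprobs=[math.log(1.5), math.log(0.5), 5.0],
    old_logprobs=[0.0, 0.0, 0.0],
    advantages=[1.0, -1.0, 100.0],
    mask=[1, 1, 0],
)
print(round(loss, 6))  # -0.2
```

Token 0 shows the clip removing any extra reward for pushing an already-favored action past 1 + ε; token 1 shows the same for a negative advantage, where the clipped term dominates once the ratio drops below 1 − ε, so in the tensor version no further gradient flows through that token.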