Policy Gradient Loss function for PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch._jit_internal import weak_module, weak_script_method


@weak_module
class PolicyGradientLoss(nn.Module):
    """
    Multiplies an unreduced CrossEntropyLoss by a `q` vector.
    """

    def __init__(self):
        super(PolicyGradientLoss, self).__init__()
        # Keep per-sample losses (reduction='none') so each one can be
        # weighted by its corresponding q value before averaging.
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')

    @weak_script_method
    def forward(self, input_, target, q):
        # Call the module directly rather than .forward() so that any
        # registered hooks still run.
        cel = self.cross_entropy_loss(input_, target)
        return torch.mean(cel * q)
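A minimal usage sketch follows. The toy policy network, optimizer settings, and batch shapes are illustrative assumptions, not part of the gist:

    # Toy policy: 4 observation features -> logits over 2 actions.
    policy = nn.Linear(4, 2)
    optimizer = optim.Adam(policy.parameters(), lr=1e-3)
    loss_fn = PolicyGradientLoss()

    states = torch.randn(8, 4)           # batch of 8 observations
    actions = torch.randint(0, 2, (8,))  # actions sampled during rollout
    q = torch.randn(8)                   # e.g. returns or advantage estimates

    logits = policy(states)
    loss = loss_fn(logits, actions, q)   # mean of q-weighted per-sample CE

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()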
xpe commented Dec 23, 2018

I'm not yet clear on the meaning (or necessity) of the @weak_module and @weak_script_method annotations. They appear throughout the PyTorch source code, but I don't know whether they matter in end-user code.
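As a sketch, here is the same loss written as a plain nn.Module, assuming the decorators are only needed for PyTorch's internal TorchScript machinery (an assumption I haven't verified):

    class PlainPolicyGradientLoss(nn.Module):
        """Same loss as above, without the @weak_module and
        @weak_script_method annotations. Assumes the decorators are
        only needed inside PyTorch's own internals."""

        def __init__(self):
            super(PlainPolicyGradientLoss, self).__init__()
            self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')

        def forward(self, input_, target, q):
            cel = self.cross_entropy_loss(input_, target)
            return torch.mean(cel * q)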

xpe commented Dec 23, 2018

Note: I use input_ to appease my editor, since input is a Python built-in. See https://stackoverflow.com/a/20670757/109618
