Policy Gradient Loss function for PyTorch
import torch
import torch.nn as nn
from torch._jit_internal import weak_module, weak_script_method


@weak_module
class PolicyGradientLoss(nn.Module):
    """
    Multiplies an unreduced CrossEntropyLoss by a `q` vector.
    """

    def __init__(self):
        super(PolicyGradientLoss, self).__init__()
        # reduction='none' keeps the per-sample losses so each one can be
        # weighted by its corresponding q value before averaging.
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')

    @weak_script_method
    def forward(self, input_, target, q):
        # Per-sample cross-entropy: -log softmax(input_)[target].
        cel = self.cross_entropy_loss(input_, target)
        return torch.mean(cel * q)
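Because `CrossEntropyLoss` with `reduction='none'` returns the per-sample negative log-probability of the chosen class, weighting it by a vector of returns (or advantages) `q` yields a REINFORCE-style surrogate loss. A minimal usage sketch; the shapes and the `q` values below are illustrative, not from the original gist:

import torch

loss_fn = PolicyGradientLoss()
logits = torch.randn(4, 3, requires_grad=True)   # action scores for 4 states, 3 actions
actions = torch.tensor([0, 2, 1, 1])             # actions actually taken (targets)
q = torch.tensor([1.0, -0.5, 2.0, 0.3])          # hypothetical returns/advantages
loss = loss_fn(logits, actions, q)
loss.backward()                                  # gradients flow into logits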
@xpe (owner) commented on Dec 23, 2018:

I'm not yet clear on the meaning (or necessity) of the decorators @weak_module and @weak_script_method. They are used in PyTorch's own source code, but I don't know whether they matter in end-user code.
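For what it's worth, both decorators come from the private torch._jit_internal module and belong to PyTorch's internal "weak script" TorchScript machinery; in plain eager-mode user code the module appears to behave the same without them. A sketch of the same loss with the decorators omitted (same behavior assumed):

import torch
import torch.nn as nn

class PolicyGradientLoss(nn.Module):
    """Same loss as above, minus the internal TorchScript decorators."""

    def __init__(self):
        super(PolicyGradientLoss, self).__init__()
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')

    def forward(self, input_, target, q):
        # Identical math: per-sample cross-entropy weighted by q, then averaged.
        return torch.mean(self.cross_entropy_loss(input_, target) * q)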

@xpe (owner) commented on Dec 23, 2018:

Note: I use `input_` to appease my editor, since `input` is a Python built-in. See https://stackoverflow.com/a/20670757/109618
