@BarclayII · Created November 1, 2017 03:59
PyTorch's `reinforce()` function sucks, so I keep the alternative solution here.
import torch as T
import numpy as np

# Logits for 5 independent categorical distributions over 8 classes.
x = T.autograd.Variable(T.randn(5, 8), requires_grad=True)
p = T.nn.functional.softmax(x)

# Old stochastic-node API (PyTorch <= 0.3): sample, attach a reward via
# reinforce(), then backprop the REINFORCE gradient through the sample.
y = p.multinomial()
y.reinforce(T.ones(y.size()))
y.backward()
d = x.grad.data.clone().numpy()
x.grad.data.zero_()

# Equivalent surrogate-loss formulation: gather the log-probabilities of the
# sampled actions and backprop -reward, since maximizing E[r] means
# descending on -r * log p(y).
logp = T.nn.functional.log_softmax(x)
logp_selected = logp.gather(1, T.autograd.Variable(y.data))
logp_selected.backward(-T.ones(y.size()))  # notice the minus

d2 = x.grad.data.clone().numpy()
assert np.all(np.abs(d - d2) < 1e-3)
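For reference, the stochastic-node API (`reinforce()` and friends) was removed around PyTorch 0.4, and the same gradient, grad E[r] = E[r * grad log p(y)], is expressed today through `torch.distributions`. A minimal sketch of the equivalent surrogate loss, assuming a modern PyTorch with the `Categorical` distribution:

import torch

x = torch.randn(5, 8, requires_grad=True)         # logits
dist = torch.distributions.Categorical(logits=x)
y = dist.sample()                                  # shape (5,)
reward = torch.ones(5)                             # per-sample rewards

# Surrogate loss: minimizing -r * log p(y) descends the REINFORCE gradient,
# matching the minus sign in the log_softmax version above.
loss = -(dist.log_prob(y) * reward).sum()
loss.backward()                                    # x.grad holds the gradient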