
@bartvm
Created July 15, 2017 03:34
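from torch import Tensor
from torch.autograd import Function, Variable


# Old-style (PyTorch 0.x) autograd.Function with instance-method forward/backward.
class Foo(Function):
    def forward(self, x):
        # Identity op.
        return x

    def backward(self, dz_star):
        # Start a second, unrelated backward pass from inside this backward.
        dphidz = Variable(Tensor(1), requires_grad=True)
        dphidz.backward()
        # Propagate no gradient to x.
        return None


x = Variable(Tensor(1), requires_grad=True)  # Tensor(1): uninitialized 1-element tensor
Foo()(x).backward()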
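The snippet uses the old-style `Function` API (instance methods, `Foo()(x)`), which current PyTorch releases no longer accept. A minimal sketch of the same pattern, a backward pass that launches a nested (reentrant) backward, written against the static-method API (`ctx`, `Foo.apply`) might look like this. The `clone()` in `forward` and the module-level `captured` dict used to expose the nested gradient are illustrative assumptions, not part of the original gist.

```python
import torch
from torch.autograd import Function

# Hypothetical holder so the nested backward's result can be inspected afterwards.
captured = {}


class Foo(Function):
    @staticmethod
    def forward(ctx, x):
        # Identity op; clone() so the output is a distinct tensor from the input.
        return x.clone()

    @staticmethod
    def backward(ctx, dz_star):
        # Create an unrelated leaf and run a second backward pass from inside
        # this backward call (a reentrant backward).
        dphidz = torch.ones(1, requires_grad=True)
        dphidz.backward()
        captured["dphidz_grad"] = dphidz.grad
        # Propagate no gradient to x, as in the original snippet.
        return None


x = torch.ones(1, requires_grad=True)
Foo.apply(x).backward()

print(x.grad)                   # None: backward returned no gradient for x
print(captured["dphidz_grad"])  # tensor([1.])
```

Because `dphidz` holds a single element, `dphidz.backward()` can create the implicit gradient of ones without an explicit `gradient` argument.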