Skip to content

Instantly share code, notes, and snippets.

@Naruto-Sasuke
Last active October 22, 2017 03:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Naruto-Sasuke/8a1400382c1a9c0785ad09853d711c8c to your computer and use it in GitHub Desktop.
import torch
from torch.autograd import Variable
class AdvFunction(torch.autograd.Function):
    """Identity autograd Function instrumented to show what forward/backward receive.

    The original snippet never computed ``output`` or ``grad_input`` (both
    were undefined names), so this version passes values straight through:
    ``forward`` returns a copy of its input and ``backward`` returns the
    incoming gradient unchanged. The ``print`` calls are the diagnostic
    probes the snippet was written for.
    """

    @staticmethod
    def forward(ctx, input):
        # Prints the runtime type forward receives. On current PyTorch this is
        # a plain torch.Tensor. The original author observed
        # torch.cuda.FloatTensor. NOTE(review): presumably because 0.2/0.3-era
        # new-style Functions unwrapped Variables into raw tensors before
        # calling forward; confirm against that release's docs.
        print(type(input))
        ctx.save_for_backward(input)
        # TODO: replace the identity with the real forward operation.
        output = input.clone()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # NOTE(review): expected to be False unless the backward pass is itself
        # being recorded (e.g. backward(create_graph=True)); confirm.
        print(grad_output.requires_grad)
        # saved_tensors replaces the deprecated saved_variables accessor.
        saved_input, = ctx.saved_tensors
        # NOTE(review): presumably True because the unpacked tensor is the
        # original input, which itself requires grad; confirm.
        print(saved_input.requires_grad)
        # The derivative of the identity is 1, so the gradient passes through.
        grad_input = grad_output
        return grad_input
# The following is for question three.
class AdvFunction(torch.autograd.Function):
    """Identity autograd Function sketching how to transform a gradient in backward.

    The original forward returned an undefined ``output``. This version
    passes the input through unchanged and leaves a TODO for the real
    operation.
    """

    @staticmethod
    def forward(ctx, input):
        # Saved for the eventual real backward formula; unused by the identity.
        ctx.save_for_backward(input)
        # TODO: replace the identity with the real forward operation.
        output = input.clone()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Work on grad_output directly instead of dropping to ``.data`` and
        # re-wrapping in ``Variable(..., requires_grad=False)``:
        # - ``.data`` bypasses autograd tracking, so in-place edits on it are
        #   not checked and double-backward silently breaks.
        # - ``Variable`` is a deprecated no-op wrapper that returns a Tensor on
        #   PyTorch >= 0.4.
        # TODO: apply the real gradient operations here (non-in-place).
        grad_input = grad_output.clone()
        return grad_input
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment