import torch
from torch.autograd import Variable, Function


class Linear(Function):
    # Note that both forward and backward are @staticmethods.
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # This function has only a single output, so it gets only one gradient.
    @staticmethod
    def backward(ctx, grad_output):
        print('hello')  # debug print: confirms that backward is actually invoked
        # saved_tensors replaces the deprecated saved_variables attribute.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        # needs_input_grad lets us skip gradients the caller did not ask for.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)
        return grad_input, grad_weight, grad_bias


linear = Linear.apply
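
# A minimal usage sketch (an addition, not part of the original gist): run one
# forward and one backward pass through the custom op. The shapes (a batch of
# 3 inputs with 4 features, 2 output features) are arbitrary assumptions.
x = Variable(torch.randn(3, 4), requires_grad=True)
w = Variable(torch.randn(2, 4), requires_grad=True)
b = Variable(torch.randn(2), requires_grad=True)
y = linear(x, w, b)   # invokes Linear.forward via Function.apply
y.sum().backward()    # invokes Linear.backward (and fires the debug print)
print(x.grad.size(), w.grad.size(), b.grad.size())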
from torch.autograd import gradcheck

# gradcheck takes a tuple of tensors as input and checks whether the gradient
# evaluated with these tensors is close enough to numerical approximations;
# it returns True if they all verify this condition.
input = (Variable(torch.randn(20, 20).double(), requires_grad=True),
         Variable(torch.randn(30, 20).double(), requires_grad=True))
test = gradcheck(Linear.apply, input, eps=1e-6, atol=1e-4)
print(test)
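
# A hedged extension (not in the original gist): the optional bias path can be
# exercised too by passing a third double-precision tensor to gradcheck. The
# bias length (30) is chosen to match the 30 output features above.
input_with_bias = (Variable(torch.randn(20, 20).double(), requires_grad=True),
                   Variable(torch.randn(30, 20).double(), requires_grad=True),
                   Variable(torch.randn(30).double(), requires_grad=True))
print(gradcheck(Linear.apply, input_with_bias, eps=1e-6, atol=1e-4))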