Skip to content

Instantly share code, notes, and snippets.

@chrischoy
Created July 5, 2017 21:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save chrischoy/5bc1cc48f6118260652b41ac81969ac4 to your computer and use it in GitHub Desktop.
Save chrischoy/5bc1cc48f6118260652b41ac81969ac4 to your computer and use it in GitHub Desktop.
import torch
from torch.autograd import Variable, Function
class MultiplyAdd(Function):
    """Autograd function computing ``input1 + scalar * input2``.

    Also demonstrates that arbitrary Python objects (here a plain ``dict``)
    can be stored on ``ctx`` in ``forward`` and read back in ``backward``.
    """

    @staticmethod
    def forward(ctx, input1, scalar, input2):
        """Return ``input1 + scalar * input2``.

        Args:
            ctx: Autograd context used to pass state to ``backward``.
            input1: First operand.
            scalar: Multiplier applied to ``input2``; saved on ``ctx`` for use
                in ``backward``. Presumably a plain Python number (the demo
                passes ``2``); no gradient is ever returned for it.
            input2: Second operand, scaled by ``scalar``.
        """
        ctx.scalar = scalar
        # Non-tensor Python state can ride along on ctx to backward.
        ctx.test = {'test': 3}
        return input1 + scalar * input2

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients w.r.t. ``(input1, scalar, input2)``.

        The gradient for ``input1`` is ``grad_output``, ``scalar`` always
        gets ``None``, and the gradient for ``input2`` is
        ``scalar * grad_output``. Entries for inputs that autograd reports
        as not needing a gradient are returned as ``None``.
        """
        # Demo side effect: the dict stored in forward is visible here.
        print(ctx.test)
        grad_input1 = grad_input2 = None
        # Skip work for inputs that do not require grad.
        if ctx.needs_input_grad[0]:
            grad_input1 = grad_output
        if ctx.needs_input_grad[2]:
            grad_input2 = ctx.scalar * grad_output
        return grad_input1, None, grad_input2
# Demo: combine two leaf tensors with the custom Function and backpropagate.
# After backward, a.grad should be all ones and b.grad all ``scalar``.
a = Variable(torch.randn(3, 5), requires_grad=True)
b = Variable(torch.randn(3, 5), requires_grad=True)
scalar = 2
# ``apply`` is a classmethod: call it on the class itself. Newer PyTorch
# releases warn when an autograd Function is instantiated.
result = MultiplyAdd.apply(a, scalar, b)
# ``result.sum()`` has a single element, so ``backward()`` seeds the gradient
# with ones implicitly. Passing an explicit shape-(1,) gradient fails on
# PyTorch >= 0.4, where ``sum()`` returns a 0-dim tensor (shape mismatch).
result.sum().backward()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment