@Hanrui-Wang
Created July 17, 2019 01:54
How to write a customized backward function in PyTorch
import torch


class MyReLU(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes
    which operate on Tensors.
    """

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for backward computation. You can cache arbitrary
        objects for use in the backward pass using the ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # Gradient of ReLU: pass the upstream gradient through where the
        # input was positive, and zero it out where the input was negative.
        grad_input[input < 0] = 0
        return grad_input
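
To use the custom Function, call MyReLU.apply rather than instantiating the class or calling forward() directly. Below is a minimal usage sketch; the tensor names and shapes are illustrative assumptions, not part of the original gist:

# Minimal usage sketch -- tensor shapes and variable names are illustrative.
x = torch.randn(4, 5, requires_grad=True)

# Call the Function through .apply so autograd records it in the graph.
relu = MyReLU.apply
y = relu(x)

# Backpropagate through the custom backward; gradients are zero where x < 0.
y.sum().backward()
print(x.grad)

# Optionally verify the backward implementation numerically with gradcheck
# (gradcheck expects double-precision inputs).
from torch.autograd import gradcheck
test_input = torch.randn(4, 5, dtype=torch.double, requires_grad=True)
print(gradcheck(MyReLU.apply, (test_input,)))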
@inder-preet-kakkar

Nicely done, cleared my doubt, thanks!
