How to write a customized backward function in PyTorch
import torch


class MyReLU(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes,
    which operate on Tensors.
    """

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for the backward computation. You can cache arbitrary
        objects for use in the backward pass using the ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input. Where the input was negative, the forward clamp
        blocked it, so the incoming gradient is zeroed there.
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input
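
A custom autograd.Function is invoked through its .apply attribute rather than instantiated and called directly. Here is a minimal usage sketch (the variable names are illustrative, not from the original gist) that runs a forward/backward pass and checks the analytic backward against numerical finite differences with torch.autograd.gradcheck:

    import torch

    x = torch.randn(4, 5, dtype=torch.double, requires_grad=True)
    relu = MyReLU.apply  # use .apply, not MyReLU()

    # Forward and backward through the custom Function.
    y = relu(x)
    y.sum().backward()
    print(x.grad)  # 1.0 where x >= 0, 0.0 elsewhere, per the backward above

    # gradcheck compares backward() against finite differences; double-precision
    # inputs are recommended. It can report a spurious failure if an element of
    # x happens to land almost exactly on the kink at 0.
    print(torch.autograd.gradcheck(relu, (x,)))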