@rohit-gupta
Created April 7, 2020 20:21
Reverse Gradients in PyTorch
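from torch.autograd import Function


class GradientReversal(Function):
    """Identity in the forward pass; negates the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        # Nothing is saved for backward: the reversed gradient does not depend on x.
        # view_as returns a new tensor object that shares x's data.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        # Only compute the gradient if the input actually requires one.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.neg()
        return grad_input


# Functional alias: y = reverse_gradients(x)
reverse_gradients = GradientReversal.apply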
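
A quick way to check the behavior is to push a small tensor through `reverse_gradients` and inspect the gradient afterwards. The sketch below is illustrative only: it redefines the function so it runs on its own, and the tensor values and the factor of 2 in the loss are arbitrary choices, not part of the original gist.

```python
import torch
from torch.autograd import Function


class GradientReversal(Function):
    @staticmethod
    def forward(ctx, x):
        # Forward pass is the identity.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Backward pass flips the sign of the incoming gradient.
        return grad_output.neg()


reverse_gradients = GradientReversal.apply

# Arbitrary example input (assumption for illustration).
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = reverse_gradients(x)

# d(loss)/dy is 2 everywhere, so the reversed gradient at x should be -2.
loss = (2.0 * y).sum()
loss.backward()

print(y)       # same values as x
print(x.grad)  # sign-flipped gradient
```

Without the reversal, `x.grad` would be `+2` for every element; with it, the values reaching `x` (and anything upstream of it) have the opposite sign, while the forward output is unchanged.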