@janfreyberg
Created April 1, 2019 20:46
Gist for blog post: revgrad
# A gradient reversal layer as an nn.Module, wrapping the functional form below.
from torch.nn import Module

from .functional import revgrad


class RevGrad(Module):
    def __init__(self, *args, **kwargs):
        """
        A gradient reversal layer.

        This layer has no parameters, and simply reverses the gradient
        in the backward pass.
        """
        super().__init__(*args, **kwargs)

    def forward(self, input_):
        return revgrad(input_)
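
A minimal usage sketch (not part of the gist): because RevGrad is an ordinary nn.Module, it can be dropped anywhere in a model, and every layer *before* it then receives negated gradients during backpropagation. The layer sizes and the toy loss here are made up for illustration.

import torch
from pytorch_revgrad import RevGrad

model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),   # this layer's gradients are negated
    RevGrad(),                # identity in the forward pass, sign flip in the backward pass
    torch.nn.Linear(5, 2),    # this layer's gradients are unchanged
)
out = model(torch.randn(4, 10))
out.sum().backward()          # model[0].weight.grad now holds the reversed gradient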
# The autograd Function that implements the reversal (imported above as .functional.revgrad).
from torch.autograd import Function


class RevGrad(Function):
    @staticmethod
    def forward(ctx, input_):
        ctx.save_for_backward(input_)
        output = input_  # the forward pass is the identity
        return output

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        grad_input = None
        if ctx.needs_input_grad[0]:
            grad_input = -grad_output  # flip the sign of the incoming gradient
        return grad_input


revgrad = RevGrad.apply
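
Since revgrad is just the Function's apply method, it can also be called directly on a tensor. A small, hedged check of the sign flip (not part of the gist; the pytorch_revgrad.functional import path is assumed from the relative import above):

import torch
from pytorch_revgrad.functional import revgrad  # assumed module path

x = torch.ones(3, requires_grad=True)
y = revgrad(x)      # identity in the forward pass
y.sum().backward()  # without the reversal, x.grad would be all +1
print(x.grad)       # tensor([-1., -1., -1.])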
# A test: gradients through the reversed network are the exact negation of the originals.
import copy

import torch

from pytorch_revgrad import RevGrad


def test_gradients_inverted():
    network = torch.nn.Sequential(torch.nn.Linear(5, 3), torch.nn.Linear(3, 1))
    revnetwork = torch.nn.Sequential(copy.deepcopy(network), RevGrad())

    inp = torch.randn(8, 5)
    outp = torch.randn(8, 1)  # target shaped to match the (8, 1) network output

    criterion = torch.nn.MSELoss()
    criterion(network(inp), outp).backward()
    criterion(revnetwork(inp), outp).backward()

    # Every parameter gradient in the reversed copy is the negative of the original.
    assert all(
        (p1.grad == -p2.grad).all()
        for p1, p2 in zip(network.parameters(), revnetwork.parameters())
    )