Skip to content

Instantly share code, notes, and snippets.

@ciela
Last active Jan 28, 2020
Embed
What would you like to do?
Getting familiar with the Gradient Reversal Layer (GRL).
from typing import Tuple, Union

import torch
import torch.nn as nn
class GradientReversalFunction(torch.autograd.Function):
    """Identity in the forward pass; gradient multiplied by ``-scale`` in backward.

    This is the autograd primitive behind a Gradient Reversal Layer (GRL).

    ``scale`` may be a tensor (as passed by ``GradientReversal``) or a plain
    Python number. No gradient is ever propagated to ``scale``.
    """

    @staticmethod
    def forward(ctx, input_forward: torch.Tensor,
                scale: Union[torch.Tensor, float]) -> torch.Tensor:
        """Return ``input_forward`` unchanged and remember ``scale`` for backward."""
        if isinstance(scale, torch.Tensor):
            # Tensors must be saved through save_for_backward.
            ctx.save_for_backward(scale)
            ctx.scale = None
        else:
            # save_for_backward only accepts tensors; plain numbers live on ctx.
            ctx.scale = scale
        # Return a view (same data) rather than the input object itself: the
        # customary identity-forward pattern, giving autograd a distinct output
        # tensor to attach this Function to.
        return input_forward.view_as(input_forward)

    @staticmethod
    def backward(ctx, grad_backward: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Negate and scale the incoming gradient; ``scale`` gets no gradient."""
        scale = ctx.saved_tensors[0] if ctx.scale is None else ctx.scale
        return scale * -grad_backward, None
class GradientReversal(nn.Module):
    """Module wrapper around ``GradientReversalFunction`` with a fixed scale.

    Args:
        scale: factor applied, together with a sign flip, to the gradient
            flowing back through this module. Kept as a 0-dim tensor attribute
            (not a parameter or buffer, so it does not appear in state_dict).
    """

    def __init__(self, scale: float):
        super().__init__()
        scale_tensor = torch.tensor(scale)
        self.scale = scale_tensor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through unchanged; reverse its gradient in backward."""
        reverse = GradientReversalFunction.apply
        return reverse(x, self.scale)
class BranchReversalSiameseNet(nn.Module):
    """Toy two-path network in which both paths share every layer.

    Both paths apply ``shared`` and then the same ``trunk_branch`` layer
    (``nn.Linear(1, 1, bias=False)``, so inputs need a last dimension of 1).
    When ``use_grl`` is true, the branch path first passes the shared
    activation through a ``GradientReversal`` layer. Gradients reaching
    ``shared`` via the branch are therefore negated and scaled, while the
    forward values of both paths stay equal.

    Args:
        use_grl: insert the gradient reversal layer on the branch path.
        grl_scale: scale handed to ``GradientReversal``. Defaults to 1.0,
            the value previously hard-coded.
    """

    def __init__(self, use_grl: bool = True, *, grl_scale: float = 1.0):
        super().__init__()
        self.use_grl = use_grl
        # Parameters are initialised from the global RNG in construction order;
        # keep this order so seeded runs reproduce the same weights.
        self.shared = nn.Linear(1, 1, bias=False)
        self.trunk_branch = nn.Linear(1, 1, bias=False)
        # Always constructed (even when unused) so the module layout does not
        # depend on use_grl.
        self.grl = GradientReversal(scale=grl_scale)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(y_trunk, y_branch)`` computed from the shared activation."""
        y_shared = self.shared(x)
        y_trunk = self.trunk_branch(y_shared)
        branch_input = self.grl(y_shared) if self.use_grl else y_shared
        y_branch = self.trunk_branch(branch_input)
        return y_trunk, y_branch
def _print_initial_weights() -> None:
    """Print the seed-0 initial weights, which both experiments below reuse."""
    print('===== BranchReversalSiameseNet Parameters ======')
    torch.manual_seed(0)
    model = BranchReversalSiameseNet()
    print(f'Shared Weight: {model.shared.weight.detach()}')
    print(f'Trunk/Branch Weight: {model.trunk_branch.weight.detach()}')
    print()


def _run_experiment(header: str, use_grl: bool) -> None:
    """Run one forward/backward pass and print inputs, outputs and shared grad.

    Re-seeding with 0 before building the model reproduces the weights printed
    by ``_print_initial_weights``. The input ``x`` is then the next draw from
    the same seeded RNG, so both experiments see identical data.
    """
    print(header)
    torch.manual_seed(0)
    model = BranchReversalSiameseNet(use_grl=use_grl)
    x = torch.randn(1)
    print(f'Input: x={x}')
    y_trunk, y_branch = model(x)
    y_out = y_trunk + y_branch
    print(f'Output: y_trunk={y_trunk.detach()}, y_branch={y_branch.detach()}, '
          f'y_out={y_out.detach()}')
    y_out.backward()
    # Without GRL each path contributes w_trunk_branch * x to the shared-weight
    # gradient (2 * w_trunk_branch * x in total). With the default GRL
    # (scale 1.0) the branch contribution is negated, so the two cancel to 0.
    print(f'Shared Gradient: {model.shared.weight.grad}')
    print()


_print_initial_weights()
_run_experiment('===== BranchReversalSiameseNet w/o GRL =========', use_grl=False)
_run_experiment('===== BranchReversalSiameseNet w/ GRL ==========', use_grl=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment