Skip to content

Instantly share code, notes, and snippets.

@ciela

ciela/grl.py

Last active Jan 28, 2020
Embed
What would you like to do?
Getting familiar with Gradient Reversal Layer.
from typing import Tuple
import torch
import torch.nn as nn
class GradientReversalFunction(torch.autograd.Function):
    """Identity in the forward pass; negates (and scales) gradients in backward.

    This is the Gradient Reversal Layer (GRL) operation used for
    adversarial / domain-adaptation training: y = x on the way forward,
    dL/dx = -scale * dL/dy on the way back.
    """

    @staticmethod
    def forward(ctx, input_forward: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        """Pass `input_forward` through unchanged, stashing `scale` for backward.

        `scale` must already be a tensor so it can go through
        `save_for_backward`.
        """
        ctx.save_for_backward(scale)
        # view_as returns a new tensor sharing the same storage; this is the
        # documented idiom for an identity custom Function so that autograd
        # records the op instead of seeing the unmodified input tensor.
        return input_forward.view_as(input_forward)

    @staticmethod
    def backward(ctx, grad_backward: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Return the reversed, scaled gradient for the input.

        The second element is None because `scale` is a constant and
        receives no gradient (annotation fixed accordingly).
        """
        scale, = ctx.saved_tensors
        return -scale * grad_backward, None
class GradientReversal(nn.Module):
    """Module wrapper around GradientReversalFunction.

    Acts as the identity in the forward pass and multiplies the gradient
    by -scale in the backward pass.
    """

    def __init__(self, scale: float):
        super(GradientReversal, self).__init__()
        # Register as a buffer (not a plain attribute) so the tensor follows
        # the module across .to()/.cuda() device and dtype moves; `self.scale`
        # remains accessible exactly as before.
        self.register_buffer('scale', torch.tensor(scale))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return x unchanged; gradient reversal happens in backward."""
        return GradientReversalFunction.apply(x, self.scale)
class BranchReversalSiameseNet(nn.Module):
    """Toy net for inspecting GRL behavior.

    One shared 1x1 linear layer feeds two heads that deliberately share a
    single weight (`trunk_branch`); the branch path optionally passes
    through a gradient reversal layer before that head.
    """

    def __init__(self, use_grl: bool = True):
        super(BranchReversalSiameseNet, self).__init__()
        self.use_grl = use_grl
        self.shared = nn.Linear(1, 1, bias=False)
        self.trunk_branch = nn.Linear(1, 1, bias=False)
        self.grl = GradientReversal(scale=1.0)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (trunk output, branch output) for input x."""
        hidden = self.shared(x)
        trunk_out = self.trunk_branch(hidden)
        # Only the branch path sees the GRL; the trunk path is untouched.
        branch_in = self.grl(hidden) if self.use_grl else hidden
        branch_out = self.trunk_branch(branch_in)
        return trunk_out, branch_out
def _show_parameters() -> None:
    """Print the freshly-initialized weights of a seeded net."""
    print('===== BranchReversalSiameseNet Parameters ======')
    torch.manual_seed(0)
    net = BranchReversalSiameseNet()
    # nn.Linear exposes the parameter directly as .weight — same tensor as
    # list(net.shared.parameters())[0].
    print(f'Shared Weight: {net.shared.weight.data}')
    print(f'Trunk/Branch Weight: {net.trunk_branch.weight.data}')
    print()


def _run_demo(title: str, use_grl: bool) -> None:
    """Seed, build a net, run one forward/backward pass and print the shared gradient."""
    print(title)
    torch.manual_seed(0)
    net = BranchReversalSiameseNet(use_grl=use_grl)
    sample = torch.randn(1)
    print(f'Input: x={sample.data}')
    trunk_out, branch_out = net(sample)
    total = trunk_out + branch_out
    print(f'Output: y_trunk={trunk_out.data}, y_branch={branch_out.data}, y_out={total.data}')
    total.backward()
    print(f'Shared Gradient: {net.shared.weight.grad}')
    print()


_show_parameters()
_run_demo('===== BranchReversalSiameseNet w/o GRL =========', use_grl=False)
_run_demo('===== BranchReversalSiameseNet w/ GRL ==========', use_grl=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.