@XinDongol
Created August 3, 2020 22:49
import torch.nn as nn
import torch.nn.functional as F
import torch


class DeepInversionFeatureHook():
    '''
    Implementation of the forward hook to track feature statistics and compute a loss on them.
    Will compute mean and variance, and will use l2 as a loss.
    '''
    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # hook to compute DeepInversion's feature distribution regularization
        nch = input[0].shape[1]
        mean = input[0].mean([0, 2, 3])
        var = input[0].permute(1, 0, 2, 3).contiguous().view(
            [nch, -1]).var(1, unbiased=False)
        # forcing mean and variance to match between the two distributions
        # other ways might work better, e.g. KL divergence
        r_feature = torch.norm(module.running_var.data.detach() - var, 2) + torch.norm(
            module.running_mean.data.detach() - mean, 2)
        self.r_feature = r_feature
        # must have no return value

    def close(self):
        self.hook.remove()


class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(64)
        # register a hook on every BatchNorm layer
        self.loss_r_feature_layers = []
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                self.loss_r_feature_layers.append(
                    DeepInversionFeatureHook(module))

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.conv3(x)
        x = self.bn3(x)
        # sum the per-layer hook losses inside forward and return them with the output
        loss_r_feature = sum([mod.r_feature
                              for mod in self.loss_r_feature_layers])
        return x, loss_r_feature


net = nn.DataParallel(MyNet().cuda())
# net = MyNet().cuda()
for i in range(10):
    print('========> %d iteration.' % i)
    output, extra_loss = net(torch.randn(512, 3, 32, 32).cuda())
    print('=> extra_loss:', extra_loss.size(), extra_loss.device, extra_loss)
    print('=> extra_loss sum:', extra_loss.sum().size(), extra_loss.sum().device, extra_loss.sum())
    # loss = F.mse_loss(output, torch.ones_like(output))
    # print('=> mse loss:', loss.size(), loss.device, loss)
    print('=> output:', output.size(), output.device, output.grad_fn)
    loss = extra_loss.sum()
    loss.backward()
    net.zero_grad()
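
For reference, each hook computes r_feature = ||running_var - batch_var||_2 + ||running_mean - batch_mean||_2 for its BatchNorm layer, i.e. it penalizes the distance between the batch statistics of the current input and the layer's stored running statistics. Below is a minimal sketch of how this loss is typically consumed in a DeepInversion-style loop, where a batch of synthetic images is optimized against a frozen model. It assumes the single-device setup from the commented-out line (net = MyNet().cuda()); the names inputs and optimizer and all hyperparameters are illustrative, not part of the gist.

# Sketch only: optimize synthetic images against the frozen, hook-instrumented model.
model = MyNet().cuda()
model.eval()                                   # BN uses running stats and does not update them
for p in model.parameters():
    p.requires_grad_(False)                    # only the images are optimized

inputs = torch.randn(64, 3, 32, 32, device='cuda', requires_grad=True)
optimizer = torch.optim.Adam([inputs], lr=0.05)

for step in range(200):
    optimizer.zero_grad()
    _, r_feature_loss = model(inputs)          # forward pass fires the BatchNorm hooks
    r_feature_loss.backward()                  # gradients flow back into the images
    optimizer.step()

When the model is wrapped in nn.DataParallel instead (as in the script above), a 0-dim tensor returned from forward is gathered into a 1-d tensor with one entry per GPU, which is presumably why the script inspects extra_loss.size() and calls .sum() before backward().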
@XinDongol (Author) commented:

I also tried the following variant, which registers the hooks after wrapping the model in nn.DataParallel and sums the per-layer losses outside of forward:

import torch.nn as nn
import torch.nn.functional as F
import torch
class DeepInversionFeatureHook():
    '''
    Implementation of the forward hook to track feature statistics and compute a loss on them.
    Will compute mean and variance, and will use l2 as a loss
    '''
    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)
    def hook_fn(self, module, input, output):
        # hook to compute DeepInversion's feature distribution regularization
        nch = input[0].shape[1]
        mean = input[0].mean([0, 2, 3])
        var = input[0].permute(1, 0, 2, 3).contiguous().view(
            [nch, -1]).var(1, unbiased=False)
        # forcing mean and variance to match between two distributions
        # other ways might work better, e.g. KL divergence
        r_feature = torch.norm(module.running_var.data.detach() - var, 2) + torch.norm(
            module.running_mean.data.detach() - mean, 2)
        self.r_feature = r_feature
        # must have no return value
    def close(self):
        self.hook.remove()
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(64)
        # hooks are registered outside of the module in this variant (see below)
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.conv3(x)
        x = self.bn3(x)
        return x
net = nn.DataParallel(MyNet().cuda())
loss_r_feature_layers = []
for module in net.modules():
    if isinstance(module, nn.BatchNorm2d):
        loss_r_feature_layers.append(
            DeepInversionFeatureHook(module))
# net = MyNet().cuda()
for i in range(10):
    print('========> %d iteration.' % i)
    output = net(torch.randn(512, 3, 32, 32).cuda())
    extra_loss = sum(mod.r_feature for mod in loss_r_feature_layers)
    print('=> extra_loss:', extra_loss.size(), extra_loss.device, extra_loss)
    print('=> extra_loss sum:', extra_loss.sum().size(), extra_loss.sum().device, extra_loss.sum())
    # loss = F.mse_loss(output, torch.ones_like(output))
    # print('=> mse loss:', loss.size(), loss.device, loss)
    print('=> output:', output.size(), output.device, output.grad_fn)
    loss = extra_loss.sum()
    loss.backward()
    net.zero_grad()
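
A related check for this second variant: to see whether the hook losses actually propagate back to the input under nn.DataParallel, the input can be made a leaf tensor and its gradient inspected after backward. A small sketch along those lines (inp is an illustrative name, not from the gist):

inp = torch.randn(512, 3, 32, 32, device='cuda', requires_grad=True)
out = net(inp)                                            # fires the BatchNorm hooks on the replicas
check_loss = sum(mod.r_feature for mod in loss_r_feature_layers)
check_loss.sum().backward()
print('input grad populated:', inp.grad is not None)
if inp.grad is not None:
    print('input grad norm:', inp.grad.norm().item())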
