@XinDongol
Created August 3, 2020 22:49
import torch.nn as nn
import torch.nn.functional as F
import torch

class DeepInversionFeatureHook():
    '''
    Implementation of the forward hook to track feature statistics and compute a loss on them.
    Will compute mean and variance, and will use l2 as a loss
    '''
    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # hook to compute deepinversion's feature distribution regularization
        nch = input[0].shape[1]
        mean = input[0].mean([0, 2, 3])
        var = input[0].permute(1, 0, 2, 3).contiguous().view(
            [nch, -1]).var(1, unbiased=False)
        # forcing mean and variance to match between two distributions
        # other ways might work better, e.g. KL divergence
        r_feature = torch.norm(module.running_var.data.detach() - var, 2) + torch.norm(
            module.running_mean.data.detach() - mean, 2)
        self.r_feature = r_feature
        # must have no output

    def close(self):
        self.hook.remove()

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(64)
        # register hooks
        self.loss_r_feature_layers = []
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                self.loss_r_feature_layers.append(
                    DeepInversionFeatureHook(module))

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.conv3(x)
        x = self.bn3(x)
        loss_r_feature = sum([mod.r_feature
                              for (idx, mod) in enumerate(self.loss_r_feature_layers)])
        return x, loss_r_feature

net = nn.DataParallel(MyNet().cuda())
# net = MyNet().cuda()

for i in range(10):
    print('========> %d iteration.' % i)
    output, extra_loss = net(torch.randn(512, 3, 32, 32).cuda())
    print('=> extra_loss:', extra_loss.size(), extra_loss.device, extra_loss)
    print('=> extra_loss sum:', extra_loss.sum().size(), extra_loss.sum().device, extra_loss.sum())
    # loss = F.mse_loss(output, torch.ones_like(output))
    # print('=> mse loss:', loss.size(), loss.device, loss)
    print('=> output:', output.size(), output.device, output.grad_fn)
    loss = extra_loss.sum()
    loss.backward()
    net.zero_grad()
@XinDongol (Author) commented:

========> 0 iteration.
=> extra_loss: torch.Size([8]) cuda:0 tensor([14.7952, 14.8032, 14.8081, 14.8076, 14.8126, 14.8142, 14.8115, 14.8145],
       device='cuda:0', grad_fn=<GatherBackward>)
=> extra_loss sum: torch.Size([]) cuda:0 tensor(118.4669, device='cuda:0', grad_fn=<SumBackward0>)
=> output: torch.Size([512, 64, 32, 32]) cuda:0 <torch.autograd.function.GatherBackward object at 0x7f48e2434ac8>
/opt/conda/lib/python3.6/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
  warnings.warn('Was asked to gather along dimension 0, but all '
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-1-50b975d2f86d> in <module>
     79     loss = extra_loss.sum()
     80 
---> 81     loss.backward()
     82     net.zero_grad()
/opt/conda/lib/python3.6/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    196                 products. Defaults to ``False``.
    197         """
--> 198         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    199 
    200     def register_hook(self, hook):
/opt/conda/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     98     Variable._execution_engine.run_backward(
     99         tensors, grad_tensors, retain_graph, create_graph,
--> 100         allow_unreachable=True)  # allow_unreachable flag
    101 
    102 
RuntimeError: Function AddBackward0 returned an invalid gradient at index 1 - expected device cuda:2 but got cuda:6 (validate_outputs at /pytorch/torch/csrc/autograd/engine.cpp:491)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x46 (0x7f48cc5c8536 in /opt/conda/lib/python3.6/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x2d8503b (0x7f48b3fe903b in /opt/conda/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #2: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&) + 0x548 (0x7f48b3fe9d58 in /opt/conda/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #3: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&, bool) + 0x3d2 (0x7f48b3febce2 in /opt/conda/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #4: torch::autograd::Engine::thread_init(int) + 0x39 (0x7f48b3fe4359 in /opt/conda/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #5: torch::autograd::python::PythonEngine::thread_init(int) + 0x38 (0x7f48cd1474d8 in /opt/conda/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0xc819d (0x7f48dfc8319d in /opt/conda/bin/../lib/libstdc++.so.6)
frame #7: <unknown function> + 0x76db (0x7f48e33b36db in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #8: clone + 0x3f (0x7f48e30dc88f in /lib/x86_64-linux-gnu/libc.so.6)
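The extra_loss of size [8] is one scalar per GPU replica, which DataParallel unsqueezes and gathers into a vector (hence the UserWarning). My guess, and it is only a guess, is that all replicas write into the same DeepInversionFeatureHook objects created in __init__, so a replica's forward can end up summing an r_feature tensor that was computed on a different GPU, which would match the mixed-device AddBackward0 error. A hypothetical, untested check against the listing above:

# Hypothetical diagnostic, assuming `net` from the listing above: after one
# forward pass, look at which device each shared hook object last wrote its
# r_feature on.
out, extra_loss = net(torch.randn(512, 3, 32, 32).cuda())
for idx, hook in enumerate(net.module.loss_r_feature_layers):
    print(idx, hook.r_feature.device)
# If these devices disagree (e.g. cuda:3, cuda:6, ...), the replicas are racing
# on the same hook objects, consistent with the device mismatch seen in backward.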

@XinDongol (Author) commented:

I also tried the following variant, registering the hooks after the model is wrapped in DataParallel:

import torch.nn as nn
import torch.nn.functional as F
import torch
class DeepInversionFeatureHook():
    '''
    Implementation of the forward hook to track feature statistics and compute a loss on them.
    Will compute mean and variance, and will use l2 as a loss
    '''
    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)
    def hook_fn(self, module, input, output):
        # hook to compute deepinversion's feature distribution regularization
        nch = input[0].shape[1]
        mean = input[0].mean([0, 2, 3])
        var = input[0].permute(1, 0, 2, 3).contiguous().view(
            [nch, -1]).var(1, unbiased=False)
        # forcing mean and variance to match between two distributions
        # other ways might work better, e.g. KL divergence
        r_feature = torch.norm(module.running_var.data.detach() - var, 2) + torch.norm(
            module.running_mean.data.detach() - mean, 2)
        self.r_feature = r_feature
        # must have no output
    def close(self):
        self.hook.remove()
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(64)
        # hooks are registered outside the module, after DataParallel wrapping (see below)
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.conv3(x)
        x = self.bn3(x)
        return x
net = nn.DataParallel(MyNet().cuda())
loss_r_feature_layers = []
for module in net.modules():
    if isinstance(module, nn.BatchNorm2d):
        loss_r_feature_layers.append(
            DeepInversionFeatureHook(module))
# net = MyNet().cuda()
for i in range(10):
    print('========> %d iteration.' % i)
    output = net(torch.randn(512, 3, 32, 32).cuda())
    extra_loss = sum([mod.r_feature
                    for (idx, mod) in enumerate(loss_r_feature_layers)])
    print('=> extra_loss:', extra_loss.size(), extra_loss.device, extra_loss)
    print('=> extra_loss sum:', extra_loss.sum().size(), extra_loss.sum().device, extra_loss.sum())
    # loss = F.mse_loss(output, torch.ones_like(output))
    # print('=> mse loss:', loss.size(), loss.device, loss)
    print('=> output:', output.size(), output.device, output.grad_fn)
    loss = extra_loss.sum()
    loss.backward()
    net.zero_grad()
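For what it's worth, a hook-free variant might sidestep the problem entirely: if the statistics loss is computed inline in forward, each replica only touches its own tensors and DataParallel gathers the per-replica losses as usual. This is only a sketch under that assumption, not a verified fix; MyNetInline and bn_stat_loss are names I made up.

import torch
import torch.nn as nn

def bn_stat_loss(feat, bn):
    # feat: the input that is about to go through BatchNorm layer `bn`, shape (N, C, H, W)
    nch = feat.shape[1]
    mean = feat.mean([0, 2, 3])
    var = feat.permute(1, 0, 2, 3).contiguous().view(nch, -1).var(1, unbiased=False)
    # same L2 matching of batch statistics to the BN running statistics as above
    return torch.norm(bn.running_var.detach() - var, 2) + \
           torch.norm(bn.running_mean.detach() - mean, 2)

class MyNetInline(nn.Module):
    def __init__(self):
        super(MyNetInline, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)

    def forward(self, x):
        x = self.conv1(x)
        r = bn_stat_loss(x, self.bn1)      # computed on this replica's own device
        x = self.bn1(x)
        x = self.conv2(x)
        r = r + bn_stat_loss(x, self.bn2)
        x = self.bn2(x)
        return x, r.unsqueeze(0)           # 1-d tensor, so gather does not need to unsqueeze

# net = nn.DataParallel(MyNetInline().cuda())
# output, extra_loss = net(torch.randn(512, 3, 32, 32).cuda())
# extra_loss.sum().backward()              # no cross-device tensors in the graph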
