from itertools import product
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd
from torch.autograd.function import once_differentiable

class my_func(autograd.Function):
    @staticmethod
    def forward(ctx, x, scale=2.0):
        ctx.scale = scale
        y = x * scale
        return y

    @staticmethod
    @once_differentiable
    def backward(ctx, dy):
        print('normal backward')
        dx = dy * ctx.scale
        return dx, None
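
# --- Added sketch (not in the original gist): a quick check that my_func's
# backward matches the analytic gradient. Since y = x * scale, the gradient of
# sum(y) w.r.t. x should be `scale` everywhere. Call check_my_func() manually;
# it is not invoked below, so the original output stays unchanged.
def check_my_func(scale=3.0):
    x = torch.randn(4, requires_grad=True)
    y = my_func.apply(x, scale)
    y.backward(torch.ones_like(y))
    assert torch.allclose(x.grad, torch.full_like(x, scale))
    print('my_func gradient check passed')
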
class my_func_inplace(autograd.Function):
    @staticmethod
    def forward(ctx, x, scale=2.0):
        ctx.scale = scale
        x.mul_(scale)
        ctx.mark_dirty(x)
        ctx.save_for_backward(x)
        return x

    @staticmethod
    @once_differentiable
    def backward(ctx, dy):
        y, = ctx.saved_tensors
        y.zero_()
        #dy = dy.clone()  # uncommenting this avoids the gradient error
        print('inplace backward')
        dy.mul_(ctx.scale)
        #dy.zero_()
        return dy, None
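
# --- Added sketch (not in the original gist): the variant hinted at by the
# commented-out `dy = dy.clone()` line above. Working on a copy of the incoming
# gradient instead of mutating `dy` in place avoids corrupting a gradient
# buffer that other nodes in the graph may still read, which appears to be the
# source of the mismatch reported at the bottom of this file. The class name
# `my_func_inplace_cloned` is made up for this sketch.
class my_func_inplace_cloned(autograd.Function):
    @staticmethod
    def forward(ctx, x, scale=2.0):
        ctx.scale = scale
        x.mul_(scale)
        ctx.mark_dirty(x)
        return x

    @staticmethod
    @once_differentiable
    def backward(ctx, dy):
        dx = dy.clone()   # copy first, then scale the copy
        dx.mul_(ctx.scale)
        return dx, None
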
class SimpleNet(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        # assumed fix: forward() references self.myfunc but it was never defined,
        # so bind the plain (out-of-place) custom function here
        self.myfunc = my_func.apply

    def forward(self, x):
        out = self.conv1(x)
        out.register_hook(lambda grad: print('conv1 out grad', grad))
        out = self.myfunc(out, 2)
        out.register_hook(lambda grad: print('myfunc1 out grad', grad))
        out = self.conv2(out)
        out.register_hook(lambda grad: print('conv2 out grad', grad))
        out = self.myfunc(out, 2)
        out.register_hook(lambda grad: print('myfunc2 out grad', grad))
        out = F.relu(out, inplace=False)
        out.register_hook(lambda grad: print('relu out grad', grad))
        out = F.avg_pool2d(out, 10)
        out.register_hook(lambda grad: print('avg_pool2d out grad', grad))
        return out
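
# --- Added sketch (not in the original gist): SimpleNet is never instantiated
# below; run_simplenet() shows how its per-tensor hooks print every intermediate
# gradient. The 10x10 input matches the avg_pool2d(out, 10) at the end of
# forward(). Not invoked by default.
def run_simplenet():
    toy = SimpleNet(3, 8)
    toy_out = toy(torch.randn(1, 3, 10, 10))
    toy_out.backward(torch.ones_like(toy_out))
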
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )
            self.bn2 = nn.BatchNorm2d(planes)
        else:
            #self.n2 = nn.BatchNorm2d(planes)  # using the name `n2` keeps it an `nn.BatchNorm2d` when
            # `len(self.shortcut) == 0` and avoids the gradient error, but why?
            self.bn2 = nn.BatchNorm2d(planes)  # with the name `bn2`, this module is replaced by `my_func*`
            # when `len(self.shortcut) == 0`, which is what triggers the gradient error

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.conv2(out)
        if len(self.shortcut):
            out = self.bn2(out)
        else:
            out = self.bn2(out)
            #out = self.n2(out)
        out = out + self.shortcut(x)
        out = F.relu(out, inplace=True)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

def module_hook_fn(m, i, o):
    print(m)

def patch_net(net, target_func, prefix=''):
    for name, layer in net._modules.items():
        if isinstance(layer, nn.Sequential):
            patch_net(layer, target_func, prefix=prefix + '.' + name)
        elif isinstance(layer, BasicBlock):
            patch_net(layer, target_func, prefix=prefix + '.' + name)
        elif 'bn' in name:
            net._modules[name] = target_func
            print('{} changed to {}'.format(prefix + '.' + name, net._modules[name]))
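
# --- Added sketch (not in the original gist): what patch_net does on a single
# BasicBlock. Any direct child whose name contains 'bn' is swapped for the raw
# function, while modules nested in the downsampling shortcut Sequential keep
# their numeric names ('0', '1') and are left untouched. Not invoked by default.
def show_patch_effect():
    blk = BasicBlock(8, 8, stride=2)
    patch_net(blk, my_func.apply, prefix='blk')
    print(type(blk._modules['bn1']), type(blk._modules['bn2']))  # plain functions now
    print(type(blk.shortcut[1]))  # still nn.BatchNorm2d
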
def get_all_layers(net):
    for name, layer in net._modules.items():
        if isinstance(layer, nn.Sequential):
            get_all_layers(layer)
        elif isinstance(layer, BasicBlock):
            get_all_layers(layer)
        else:
            print(type(layer))
            layer.register_forward_hook(module_hook_fn)
            layer.register_backward_hook(module_hook_fn)

def ResNet18():
    net = ResNet(BasicBlock, [2,2,2,2])
    #get_all_layers(net)
    return net

net = ResNet18()

print('start no inplace fwd&bwd')
patch_net(net, my_func.apply)
print(net._modules)
x1 = torch.ones((2, 3, 32, 32), requires_grad=True)
net_out1 = net(x1)
net_out1.backward(torch.ones_like(net_out1))

print('start inplace fwd&bwd')
patch_net(net, my_func_inplace.apply)
print(net._modules)
x2 = torch.ones((2, 3, 32, 32), requires_grad=True)
net_out2 = net(x2)
net_out2.backward(torch.ones_like(net_out2))

print('grad x error max:', (x1.grad - x2.grad).abs_().max())
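
# --- Added sketch (not in the original gist): the same effect reproduced on a
# tiny residual-style graph, without the ResNet. The upstream gradient of the
# `+` is handed to both branches; if the custom backward mutates it in place
# (dy.mul_), the tensor still sitting in the other branch's accumulation buffer
# is corrupted, so x.grad can deviate from the analytic value of 3. The printed
# values are whatever the current PyTorch build produces; the point is only
# whether the two runs match.
def branch_grad(func):
    x = torch.ones(4, requires_grad=True)
    y = x * 1.0                      # non-leaf tensor consumed by two branches
    out = func(y * 1.0, 2.0) + y     # residual-style add
    out.backward(torch.ones_like(out))
    return x.grad

print('minimal repro, out-of-place grad:', branch_grad(my_func.apply))
print('minimal repro, in-place grad:', branch_grad(my_func_inplace.apply))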