from itertools import product
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd
from torch.autograd.function import once_differentiable
class my_func(autograd.Function):
    @staticmethod
    def forward(ctx, x, scale=2.0):
        ctx.scale = scale
        y = x * scale
        return y

    @staticmethod
    @once_differentiable
    def backward(ctx, dy):
        print('normal backward')
        dx = dy * ctx.scale
        return dx, None
class my_func_inplace(autograd.Function):
    @staticmethod
    def forward(ctx, x, scale=2.0):
        ctx.scale = scale
        x.mul_(scale)
        ctx.mark_dirty(x)
        ctx.save_for_backward(x)
        return x

    @staticmethod
    @once_differentiable
    def backward(ctx, dy):
        y, = ctx.saved_tensors
        y.zero_()
        # dy = dy.clone()  # uncommenting this removes the gradient error
        print('inplace backward')
        dy.mul_(ctx.scale)
        # dy.zero_()
        return dy, None
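
# Note (my addition, not from the gist): the in-place backward above both zeroes
# the tensor it saved in forward and scales the incoming gradient `dy` in place
# before returning it. Autograd can hand the same grad_output tensor to several
# consumers (for an elementwise add, both inputs typically receive the same
# buffer), so mutating `dy` here may also change the gradient another branch
# sees; cloning `dy` first (the commented-out line) keeps the buffer private.
# This is my reading of the experiment, offered as a hypothesis.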
class SimpleNet(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.myfunc = my_func.apply  # the custom function under test

    def forward(self, x):
        out = self.conv1(x)
        out.register_hook(lambda grad: print('conv1 out grad', grad))
        out = self.myfunc(out, 2)
        out.register_hook(lambda grad: print('myfunc1 out grad', grad))
        out = self.conv2(out)
        out.register_hook(lambda grad: print('conv2 out grad', grad))
        out = self.myfunc(out, 2)
        out.register_hook(lambda grad: print('myfunc2 out grad', grad))
        out = F.relu(out, inplace=False)
        out.register_hook(lambda grad: print('relu out grad', grad))
        out = F.avg_pool2d(out, 10)
        out.register_hook(lambda grad: print('avg_pool2d out grad', grad))
        return out
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )
            self.bn2 = nn.BatchNorm2d(planes)
        else:
            # self.n2 = nn.BatchNorm2d(planes)
            # Using the name `n2` keeps this an `nn.BatchNorm2d` when `len(self.shortcut) == 0`,
            # because `patch_net` only swaps modules whose name contains 'bn'. That avoids the
            # gradient error, but why?
            # With the name `bn2`, this module is replaced with `my_func*` when
            # `len(self.shortcut) == 0`, which causes the gradient error.
            self.bn2 = nn.BatchNorm2d(planes)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.conv2(out)
        if len(self.shortcut):
            out = self.bn2(out)
        else:
            out = self.bn2(out)
            # out = self.n2(out)
        out = out + self.shortcut(x)
        out = F.relu(out, inplace=True)
        return out
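
# Note (my addition, hedged): the residual add `out + self.shortcut(x)` is the
# spot this experiment exercises. During backward, the gradient arriving at the
# add is routed to both the main branch (which may end in the patched in-place
# function) and the shortcut branch, so an in-place `dy.mul_()` on one side can
# presumably be seen by the other side as a corrupted gradient.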
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out
def module_hook_fn(m, i, o):
    print(m)


def patch_net(net, target_func, prefix=''):
    for name, layer in net._modules.items():
        if isinstance(layer, nn.Sequential):
            patch_net(layer, target_func, prefix=prefix + '.' + name)
        elif isinstance(layer, BasicBlock):
            patch_net(layer, target_func, prefix=prefix + '.' + name)
        elif 'bn' in name:
            net._modules[name] = target_func
            print('{} changed to {}'.format(prefix + '.' + name, net._modules[name]))


def get_all_layers(net):
    for name, layer in net._modules.items():
        if isinstance(layer, nn.Sequential):
            get_all_layers(layer)
        elif isinstance(layer, BasicBlock):
            get_all_layers(layer)
        else:
            print(type(layer))
            layer.register_forward_hook(module_hook_fn)
            layer.register_backward_hook(module_hook_fn)


def ResNet18():
    net = ResNet(BasicBlock, [2, 2, 2, 2])
    # get_all_layers(net)
    return net
net = ResNet18()
print('start no inplace fwd&bwd')
patch_net(net, my_func.apply)
print(net._modules)
x1 = torch.ones((2, 3, 32, 32), requires_grad=True)
net_out1 = net(x1)
net_out1.backward(torch.ones_like(net_out1))
print('start inplace fwd&bwd')
patch_net(net, my_func_inplace.apply)
print(net._modules)
x2 = torch.ones((2, 3, 32, 32), requires_grad=True)
net_out2 = net(x2)
net_out2.backward(torch.ones_like(net_out2))
print('grad x error max:', (x1.grad - x2.grad).abs_().max())
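
# A smaller probe in the same spirit (my addition, not part of the original
# gist): run a residual-style add where one branch goes through each custom
# function, then compare the input gradients. If the two gradients differ, the
# grad_output buffer at the add was shared between the branches and the
# in-place `mul_` leaked into the shortcut branch's gradient. The helper name
# `probe` and the toy shapes are made up for illustration.
def probe(func):
    x = torch.ones(4, requires_grad=True)
    main = x * 3.0        # non-leaf input, so mark_dirty() in the in-place forward is legal
    shortcut = x * 1.0    # "shortcut" branch
    out = func(main, 2.0) + shortcut  # both branches receive grad_output from the same add
    out.backward(torch.ones_like(out))
    return x.grad.clone()

print('probe grad diff max:', (probe(my_func.apply) - probe(my_func_inplace.apply)).abs().max())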