# Source: GitHub Gist by @jinyup100, last active August 5, 2020
# (gist id bc2fb2d25ac5ac1ac635c9f2b62853d7).
import numpy as np
import os
import onnx
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
# Class for the Building Blocks required for ResNet
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    The block outputs ``planes * expansion`` channels.

    Args:
        inplanes: number of input channels.
        planes: channel count of the two inner convolutions.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to build the
            shortcut; needed whenever the main path changes shape.
        dilation: requested dilation of the 3x3 convolution (halved when a
            ``downsample`` module is given and ``dilation > 1``).

    Raises:
        AssertionError: if both ``stride`` and the effective dilation are > 1.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1,
                 downsample=None, dilation=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # stride 1 -> padding 1 (size-preserving 3x3); stride 2 -> padding 0.
        padding = 2 - stride
        # The block carrying the downsample branch of a dilated stage uses
        # half of that stage's dilation.
        if downsample is not None and dilation > 1:
            dilation = dilation // 2
            padding = dilation
        assert stride == 1 or dilation == 1, \
            "stride and dilation must have one equal to one at least"
        if dilation > 1:
            padding = dilation
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=padding, bias=False, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        # Use the class constant rather than a literal 4 so the channel
        # expansion is defined in exactly one place.
        self.conv3 = nn.Conv2d(planes, planes * self.expansion,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out
# End of Building Blocks
# Class for ResNet - the Backbone neural network
class ResNet(nn.Module):
    """ResNet backbone returning the feature maps selected by ``used_layers``.

    Args:
        block: residual block class; must expose an ``expansion`` attribute.
        layers: number of blocks in each of the four stages.
        used_layers: indices into ``[stem output, layer1, layer2, layer3,
            layer4]`` choosing what ``forward`` returns. ``layer3`` and
            ``layer4`` are only built when 3 / 4 is listed; otherwise they
            are identity functions.
    """

    def __init__(self, block, layers, used_layers):
        super(ResNet, self).__init__()
        self.inplanes = 64
        # padding=0 is kept as in the original (a note there mentioned 3).
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)

        self.feature_size = 128 * block.expansion
        self.used_layers = used_layers

        if 3 in used_layers:
            # stride 1 + dilation instead of a strided stage (original
            # author's size notes: 15x15, 7x7).
            self.layer3 = self._make_layer(block, 256, layers[2],
                                           stride=1, dilation=2)
            self.feature_size = (256 + 128) * block.expansion
        else:
            self.layer3 = lambda x: x  # identity

        if 4 in used_layers:
            # original author's size notes: 7x7, 3x3
            self.layer4 = self._make_layer(block, 512, layers[3],
                                           stride=1, dilation=4)
            self.feature_size = 512 * block.expansion
        else:
            self.layer4 = lambda x: x  # identity

        # He-style normal init for convolutions; BatchNorm starts as identity.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, np.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and the shortcut
        ``downsample`` module. ``self.inplanes`` is updated to the stage's
        output channel count.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            if stride == 1 and dilation == 1:
                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, planes * block.expansion,
                              kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(planes * block.expansion),
                )
            else:
                # 3x3 shortcut. In dilated stages it uses half the dilation,
                # mirroring how Bottleneck halves its own dilation when it
                # is given a downsample module.
                if dilation > 1:
                    dd = dilation // 2
                    padding = dd
                else:
                    dd = 1
                    padding = 0
                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, planes * block.expansion,
                              kernel_size=3, stride=stride, bias=False,
                              padding=padding, dilation=dd),
                    nn.BatchNorm2d(planes * block.expansion),
                )

        layers = [block(self.inplanes, planes, stride,
                        downsample, dilation=dilation)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the selected feature map(s).

        If exactly one layer is selected, the tensor is returned directly.
        Otherwise a list is returned, in ``used_layers`` order.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x_ = self.relu(x)
        x = self.maxpool(x_)

        p1 = self.layer1(x)
        p2 = self.layer2(p1)
        p3 = self.layer3(p2)
        p4 = self.layer4(p3)
        out = [x_, p1, p2, p3, p4]
        out = [out[i] for i in self.used_layers]
        if len(out) == 1:
            return out[0]
        return out
# End of ResNet
# Class for Adjusting the layers of the neural net
class AdjustLayer_1(nn.Module):
    """1x1 conv + BatchNorm channel projection followed by a fixed crop.

    After projecting to ``out_channels``, rows and columns 4..10 (a 7x7
    window) are kept. In this script the module is applied to the template
    (target) features.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels after the 1x1 projection.
        center_size: stored on the module but not read by ``forward``. The
            crop bounds are fixed; their 7-wide window equals the default.
    """

    def __init__(self, in_channels, out_channels, center_size=7):
        super(AdjustLayer_1, self).__init__()
        projection = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        self.downsample = nn.Sequential(projection, norm)
        self.center_size = center_size

    def forward(self, x):
        projected = self.downsample(x)
        # NOTE(review): the fixed bounds presumably assume a 15x15 input map
        # ((15 - 7) // 2 == 4). Confirm before feeding other sizes.
        start, stop = 4, 11
        return projected[:, :, start:stop, start:stop]
class AdjustAllLayer_1(nn.Module):
    """Apply one :class:`AdjustLayer_1` per feature level.

    With a single level, the adjuster is registered as ``downsample`` and
    ``forward`` maps one tensor to one tensor. Otherwise level ``i`` gets a
    sub-module named ``downsample{i + 2}``, and ``forward`` maps an indexable
    sequence of feature maps to a list.
    """

    def __init__(self, in_channels, out_channels, center_size=7):
        super(AdjustAllLayer_1, self).__init__()
        self.num = len(out_channels)
        if self.num != 1:
            for idx, c_out in enumerate(out_channels):
                self.add_module('downsample{}'.format(idx + 2),
                                AdjustLayer_1(in_channels[idx], c_out, center_size))
        else:
            self.downsample = AdjustLayer_1(in_channels[0], out_channels[0],
                                            center_size)

    def forward(self, features):
        if self.num != 1:
            return [getattr(self, 'downsample{}'.format(idx + 2))(features[idx])
                    for idx in range(self.num)]
        return self.downsample(features)
class AdjustLayer_2(nn.Module):
    """1x1 conv + BatchNorm channel projection with no spatial crop.

    In this script it is applied to the search features.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels after the 1x1 projection.
        center_size: stored on the module but not read by ``forward``.
    """

    def __init__(self, in_channels, out_channels, center_size=7):
        super(AdjustLayer_2, self).__init__()
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        self.downsample = nn.Sequential(*stages)
        self.center_size = center_size

    def forward(self, x):
        return self.downsample(x)
class AdjustAllLayer_2(nn.Module):
    """Apply one :class:`AdjustLayer_2` per feature level.

    With a single level, the adjuster is registered as ``downsample`` and
    ``forward`` maps one tensor to one tensor. Otherwise level ``i`` gets a
    sub-module named ``downsample{i + 2}``, and ``forward`` maps an indexable
    sequence of feature maps to a list.
    """

    def __init__(self, in_channels, out_channels, center_size=7):
        super(AdjustAllLayer_2, self).__init__()
        self.num = len(out_channels)
        if self.num != 1:
            for idx, c_out in enumerate(out_channels):
                self.add_module('downsample{}'.format(idx + 2),
                                AdjustLayer_2(in_channels[idx], c_out, center_size))
        else:
            self.downsample = AdjustLayer_2(in_channels[0], out_channels[0],
                                            center_size)

    def forward(self, features):
        if self.num != 1:
            return [getattr(self, 'downsample{}'.format(idx + 2))(features[idx])
                    for idx in range(self.num)]
        return self.downsample(features)
# End of Class for Adjusting the layers of the neural net
# Class for Region Proposal Neural Network
class RPN(nn.Module):
    """Base class for Region Proposal Network heads.

    Subclasses must override ``forward(z_f, x_f)``.
    """

    def __init__(self):
        super(RPN, self).__init__()

    def forward(self, z_f, x_f):
        # Abstract: concrete heads provide the implementation.
        raise NotImplementedError
class DepthwiseXCorr(nn.Module):
    """Depthwise cross-correlation head.

    The template (``kernel``) and search features each pass through their
    own conv + BatchNorm + ReLU branch. They are then correlated channel by
    channel with ``xcorr_depthwise``, and a 1x1 head maps the result to
    ``out_channels``.

    Args:
        in_channels: channels of both input feature maps.
        hidden: channels produced by the branch convolutions.
        out_channels: channels of the head's output.
        kernel_size: kernel size of the two branch convolutions.
        hidden_kernel_size: accepted for interface compatibility; unused.
    """

    def __init__(self, in_channels, hidden, out_channels, kernel_size=3, hidden_kernel_size=5):
        super(DepthwiseXCorr, self).__init__()

        def conv_bn_relu():
            return nn.Sequential(
                nn.Conv2d(in_channels, hidden, kernel_size=kernel_size, bias=False),
                nn.BatchNorm2d(hidden),
                nn.ReLU(inplace=True),
            )

        # Build order (kernel branch, search branch, head) is kept so
        # parameter initialisation consumes the RNG in the same sequence.
        self.conv_kernel = conv_bn_relu()
        self.conv_search = conv_bn_relu()
        self.head = nn.Sequential(
            nn.Conv2d(hidden, hidden, kernel_size=1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, kernel_size=1),
        )

    def forward(self, kernel, search):
        template_feat = self.conv_kernel(kernel)
        search_feat = self.conv_search(search)
        return self.head(xcorr_depthwise(search_feat, template_feat))
class DepthwiseRPN(RPN):
    """RPN head made of two depthwise-correlation branches.

    ``cls`` outputs ``2 * anchor_num`` channels and ``loc`` outputs
    ``4 * anchor_num`` channels (presumably per-anchor scores and box
    offsets respectively; confirm against the post-processing code).
    """

    def __init__(self, anchor_num=5, in_channels=256, out_channels=256):
        super(DepthwiseRPN, self).__init__()
        self.cls = DepthwiseXCorr(in_channels, out_channels, anchor_num * 2)
        self.loc = DepthwiseXCorr(in_channels, out_channels, anchor_num * 4)

    def forward(self, z_f, x_f):
        # Tuple elements are evaluated left to right: cls first, then loc.
        return self.cls(z_f, x_f), self.loc(z_f, x_f)
class MultiRPN(RPN):
    """Weighted fusion of several :class:`DepthwiseRPN` heads.

    One head named ``rpn{i + 2}`` is registered for every entry of
    ``in_channels``. ``forward`` runs one head per fusion weight (three,
    given the constants below). It returns the weighted sums of the heads'
    ``cls`` outputs and of their ``loc`` outputs.

    Args:
        anchor_num: anchors per position, forwarded to every head.
        in_channels: channel count of each level's features.

    NOTE(review): each set of fusion weights is a fixed constant that sums to
    ~1, presumably softmax-normalised values taken from a trained model.
    Confirm before changing.
    """

    def __init__(self, anchor_num, in_channels):
        super(MultiRPN, self).__init__()
        for i, channels in enumerate(in_channels):
            self.add_module('rpn' + str(i + 2),
                            DepthwiseRPN(anchor_num, channels, channels))
        self.weight_cls = nn.Parameter(torch.Tensor([0.38156851768108546, 0.4364767608115956, 0.18195472150731892]))
        self.weight_loc = nn.Parameter(torch.Tensor([0.17644893463361863, 0.16564198028417967, 0.6579090850822015]))

    def forward(self, z_fs, x_fs):
        """Run each level's head and fuse the results.

        Args:
            z_fs: template features, indexable by level (``z_fs[i]``).
            x_fs: search features, indexable by level.

        Returns:
            ``(cls, loc)``: the weighted sums over levels of the heads' outputs.
        """
        cls = []
        loc = []
        # One head per fusion weight. Levels are read by plain indexing
        # (not iteration), as in the original, so a stacked tensor input is
        # traced the same way.
        for i in range(len(self.weight_cls)):
            rpn = getattr(self, 'rpn' + str(i + 2))
            cls_out, loc_out = rpn(z_fs[i], x_fs[i])
            cls.append(cls_out)
            loc.append(loc_out)

        def weighted_sum(outputs, weight):
            # Accumulate as 0 + o0*w0 + o1*w1 + ..., matching the original order.
            total = 0
            for i, output in enumerate(outputs):
                total += output * weight[i]
            return total

        return weighted_sum(cls, self.weight_cls), weighted_sum(loc, self.weight_loc)
# End of class for RPN
def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` whose padding equals ``dilation``.

    With ``stride=1`` this padding keeps the spatial size unchanged.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=False,
    )
def xcorr_depthwise(x, kernel):
    """Depthwise cross-correlation of ``x`` with a per-sample ``kernel``.

    Each (sample, channel) slice of ``kernel`` slides over the matching
    slice of ``x`` (stride 1, no padding). This is a grouped convolution
    with one group per (sample, channel) pair.

    Args:
        x: 4-D tensor ``(batch, channel, H, W)``; batch and channel must
            match ``kernel``'s.
        kernel: 4-D tensor ``(batch, channel, h, w)``.

    Returns:
        Tensor of shape ``(batch, channel, H - h + 1, W - w + 1)``, detached
        from the autograd graph.
    """
    n_batch = kernel.size(0)
    n_chan = kernel.size(1)
    groups = n_batch * n_chan
    # Fold the batch into the channel axis so one grouped conv covers every pair.
    x = x.view(1, groups, x.size(2), x.size(3))
    kernel = kernel.view(groups, 1, kernel.size(2), kernel.size(3))
    # NOTE(review): a fresh Conv2d is built on every call and its weight is
    # replaced by the kernel. This is presumably deliberate for the ONNX
    # export / downstream importer; F.conv2d would be the usual alternative.
    xcorr = nn.Conv2d(groups, groups, kernel_size=(kernel.size(2), kernel.size(3)), bias=False, groups=groups)
    xcorr.weight = nn.Parameter(kernel)
    result = xcorr(x)
    result = result.view(n_batch, n_chan, result.size(2), result.size(3))
    # NOTE(review): detach() stops gradients; harmless for export, but it
    # would block training through this op.
    return result.detach()
class ResNetBuilder(nn.Module):
    """Wrap the backbone so its parameters live under the ``backbone.`` prefix.

    The backbone is ``ResNet(Bottleneck, [3, 4, 6, 3], [2, 3, 4])`` (the
    ResNet-50 stage layout). It returns the layer2, layer3 and layer4
    feature maps.
    """

    def __init__(self):
        super(ResNetBuilder, self).__init__()
        self.backbone = ResNet(Bottleneck, [3, 4, 6, 3], [2, 3, 4])

    def forward(self, frame):
        """Return the backbone's feature maps for ``frame``."""
        return self.backbone(frame)
class AdjustedLayerBuilder_1(nn.Module):
    """Wrap ``AdjustAllLayer_1`` under the ``neck.`` parameter prefix.

    Projects three feature levels of 512, 1024 and 2048 channels to 256
    channels each. In this script it is applied to the template features.
    """

    def __init__(self):
        super(AdjustedLayerBuilder_1, self).__init__()
        self.neck = AdjustAllLayer_1([512, 1024, 2048], [256, 256, 256])

    def forward(self, inp):
        """Adjust ``inp``, the list of backbone feature maps."""
        return self.neck(inp)
class AdjustedLayerBuilder_2(nn.Module):
    """Wrap ``AdjustAllLayer_2`` under the ``neck.`` parameter prefix.

    Projects three feature levels of 512, 1024 and 2048 channels to 256
    channels each. In this script it is applied to the search features.
    """

    def __init__(self):
        super(AdjustedLayerBuilder_2, self).__init__()
        self.neck = AdjustAllLayer_2([512, 1024, 2048], [256, 256, 256])

    def forward(self, inp):
        """Adjust ``inp``, the list of backbone feature maps."""
        return self.neck(inp)
class RPNBuilder(nn.Module):
    """Wrap ``MultiRPN`` under the ``rpn_head.`` parameter prefix.

    Uses 5 anchors and three 256-channel feature levels.
    """

    def __init__(self):
        super(RPNBuilder, self).__init__()
        self.rpn_head = MultiRPN(anchor_num=5, in_channels=[256, 256, 256])

    def forward(self, zf, xf):
        """Return the fused ``(cls, loc)`` maps for template/search features."""
        cls_map, loc_map = self.rpn_head(zf, xf)
        return cls_map, loc_map


# The pre-trained checkpoint siamrpn_r50_l234_dwxcorr.pth is read from the
# current working directory; its download link is given in the gist description.
# Pre-trained weights for the tracker model.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
current_path = os.getcwd()
load_path = os.path.join(current_path, "siamrpn_r50_l234_dwxcorr.pth")
pretrained_dict = torch.load(load_path, map_location=torch.device('cpu'))
# Every sub-model below filters this same dict down to its own state-dict keys.
pretrained_dict_backbone = pretrained_dict
pretrained_dict_neck_1 = pretrained_dict
pretrained_dict_neck_2 = pretrained_dict
pretrained_dict_head = pretrained_dict

# Sample inputs for tracing the exports: pre-computed template (target) and
# search crops saved as NumPy arrays.
target = torch.Tensor(np.load('numpy_z_crop.npy'))
search = torch.Tensor(np.load('numpy_x_crop.npy'))
# Build the torch backbone model
backbone = ResNetBuilder()
backbone.eval()
backbone_dict = backbone.state_dict()
# Load the pre-trained weights: keep only checkpoint entries whose keys exist
# in the backbone, then overlay them on its current state dict.
pretrained_dict_backbone = {k: v for k, v in pretrained_dict_backbone.items() if k in backbone_dict}
backbone_dict.update(pretrained_dict_backbone)
backbone.load_state_dict(backbone_dict)

# Export the backbone to ONNX twice, tracing once with the target crop and
# once with the search crop (one ONNX file per input).
torch.onnx.export(backbone, target, "resnet_target.onnx", export_params=True, opset_version=11,
                  do_constant_folding=True, input_names=['input'],
                  output_names=['output_1', 'output_2', 'output_3'])
torch.onnx.export(backbone, search, "resnet_search.onnx", export_params=True, opset_version=11,
                  do_constant_folding=True, input_names=['input'],
                  output_names=['output_1', 'output_2', 'output_3'])

# Reload and validate each exported backbone model. check_model raises on an
# invalid model and returns None, so its result is not printed.
onnx_resnet_target = onnx.load("resnet_target.onnx")
onnx.checker.check_model(onnx_resnet_target)
print(onnx.helper.printable_graph(onnx_resnet_target.graph))

onnx_resnet_search = onnx.load("resnet_search.onnx")
onnx.checker.check_model(onnx_resnet_search)
print(onnx.helper.printable_graph(onnx_resnet_search.graph))
# Run the backbone on both crops. Its three feature maps per crop, converted
# to NumPy, become the inputs of the neck models.
zf = backbone(torch.Tensor(target))
xf = backbone(torch.Tensor(search))
zf_1, zf_2, zf_3 = [feat.detach().numpy() for feat in zf]
xf_1, xf_2, xf_3 = [feat.detach().numpy() for feat in xf]
# Build the torch neck_1 model
neck_1 = AdjustedLayerBuilder_1()
neck_1.eval()
neck_1_dict = neck_1.state_dict()
# Load the pre-trained weights that belong to neck_1
pretrained_dict_neck_1 = {k: v for k, v in pretrained_dict_neck_1.items() if k in neck_1_dict}
neck_1_dict.update(pretrained_dict_neck_1)
neck_1.load_state_dict(neck_1_dict)

# Export neck_1 to ONNX, traced with random inputs shaped like the template
# features. The inputs are passed as a *list*, so forward receives them as
# one argument (``inp``); a tuple would be unpacked into three arguments.
torch.onnx.export(neck_1,
                  [torch.Tensor(np.random.rand(*zf_1.shape)),
                   torch.Tensor(np.random.rand(*zf_2.shape)),
                   torch.Tensor(np.random.rand(*zf_3.shape))],
                  "neck_1.onnx", export_params=True, opset_version=11,
                  do_constant_folding=True, input_names=['input_1', 'input_2', 'input_3'],
                  output_names=['output_1', 'output_2', 'output_3'])

# Reload and validate the exported neck_1 model (check_model raises if invalid)
onnx_neck_1_model = onnx.load("neck_1.onnx")
onnx.checker.check_model(onnx_neck_1_model)
print(onnx.helper.printable_graph(onnx_neck_1_model.graph))
# Build the torch neck_2 model
neck_2 = AdjustedLayerBuilder_2()
neck_2.eval()
neck_2_dict = neck_2.state_dict()
# Load the pre-trained weights that belong to neck_2
pretrained_dict_neck_2 = {k: v for k, v in pretrained_dict_neck_2.items() if k in neck_2_dict}
neck_2_dict.update(pretrained_dict_neck_2)
neck_2.load_state_dict(neck_2_dict)

# Export neck_2 to ONNX, traced with random inputs shaped like the search
# features. The inputs are passed as a *list*, so forward receives them as
# one argument (``inp``); a tuple would be unpacked into three arguments.
torch.onnx.export(neck_2,
                  [torch.Tensor(np.random.rand(*xf_1.shape)),
                   torch.Tensor(np.random.rand(*xf_2.shape)),
                   torch.Tensor(np.random.rand(*xf_3.shape))],
                  "neck_2.onnx", export_params=True, opset_version=11,
                  do_constant_folding=True, input_names=['input_1', 'input_2', 'input_3'],
                  output_names=['output_1', 'output_2', 'output_3'])

# Reload and validate the exported neck_2 model (check_model raises if invalid)
onnx_neck_2_model = onnx.load("neck_2.onnx")
onnx.checker.check_model(onnx_neck_2_model)
print(onnx.helper.printable_graph(onnx_neck_2_model.graph))
# Run both necks on the real backbone features
zfs_1, zfs_2, zfs_3 = neck_1([torch.Tensor(zf_1), torch.Tensor(zf_2), torch.Tensor(zf_3)])
xfs_1, xfs_2, xfs_3 = neck_2([torch.Tensor(xf_1), torch.Tensor(xf_2), torch.Tensor(xf_3)])
# Stack the three levels along a new leading axis; MultiRPN reads the levels
# back with z_fs[0], z_fs[1], z_fs[2].
zfs = np.stack([level.detach().numpy() for level in (zfs_1, zfs_2, zfs_3)])
xfs = np.stack([level.detach().numpy() for level in (xfs_1, xfs_2, xfs_3)])
# Build the torch rpn_head model
rpn_head = RPNBuilder()
rpn_head.eval()  # load_state_dict below does not change the train/eval mode
rpn_head_dict = rpn_head.state_dict()
# Load the pre-trained weights that belong to rpn_head
pretrained_dict_head = {k: v for k, v in pretrained_dict_head.items() if k in rpn_head_dict}
rpn_head_dict.update(pretrained_dict_head)
rpn_head.load_state_dict(rpn_head_dict)

# Export rpn_head to ONNX. The tuple supplies two positional arguments
# (zf, xf), traced with random tensors shaped like the stacked neck outputs.
torch.onnx.export(rpn_head,
                  (torch.Tensor(np.random.rand(*zfs.shape)),
                   torch.Tensor(np.random.rand(*xfs.shape))),
                  "rpn_head.onnx", export_params=True, opset_version=11,
                  do_constant_folding=True, input_names=['input_1', 'input_2'],
                  output_names=['output_1', 'output_2'])

# Reload and validate the exported rpn_head model (check_model raises if invalid)
onnx_rpn_head_model = onnx.load("rpn_head.onnx")
onnx.checker.check_model(onnx_rpn_head_model)
print(onnx.helper.printable_graph(onnx_rpn_head_model.graph))
# End of gist.