
@jinyup100
Last active February 28, 2021 14:23

Overview

Mentors : Liubov Batanina @l-bat, Stefano Fabri @bhack, Ilya Elizarov @ieliz
Student : Jin Yeob Chung @jinyup100
Mentors' Project Proposal : https://summerofcode.withgoogle.com/projects/#4979746967912448
Link to Pull Request : opencv/opencv#17647
Link to video summarising the experience : https://www.youtube.com/watch?v=D9G1vHqJCrc

Introduction

Recent interest in computer vision has led to great advances in the development of visual trackers. In particular, applications of the Kernelized Correlation Filter (KCF) and of deep learning have led to numerous implementations of single-object trackers using publicly available libraries. Lately, there has been increased focus on the role of convolutional features in visual trackers. In this project, I focus on implementing visual trackers based on deep learning, using the resources available in OpenCV.

Objectives

The main contributions of the project are threefold:

  • Export of the PyTorch implementation of the SiamRPN++ visual tracker to ONNX.
  • Addition of a SiamRPN++ sample to the opencv/samples/dnn directory,
    • which imports the ONNX models of the tracker using the OpenCV library.
  • Testing the performance of state-of-the-art visual trackers against classical trackers.

Initial Proposal

One of the planned contributions was to implement state-of-the-art visual trackers and make the designed classes readily available for public use. In particular, I was intrigued by the visual tracker proposed in “ATOM: Accurate Tracking by Overlap Maximization” by Danelljan et al. The authors decompose the tracking problem into a classification task and an estimation task. The classification task separates the foreground from the background of an image frame by locating the region of the target. The estimation task then estimates the state of the target in the form of a bounding box.

There were two problems with this tracker. First, its license made it inappropriate to add the tracker to the OpenCV library. Second, even if the licensing issues were resolved, the tracker required custom layers implemented specifically for CUDA.

Adjusted Proposal

After considering these issues, the team agreed to switch the focus to implementing SiamRPN++, which appeared to have no licensing issues. Over a few days, I read into this tracker, proposed by Li et al., and revised the original proposal accordingly.

Link to the Original Paper : https://arxiv.org/abs/1812.11703
Link to the Original Repository : https://github.com/STVIR/pysot

Details of visual trackers with Siamese Architecture

Siamese networks have been widely successful in similarity learning, as demonstrated by “DeepFace: Closing the gap to human-level performance in face verification”. The paper “Learning a Similarity Metric Discriminatively, with Application to Face Verification” by Chopra, Hadsell and LeCun applied this architecture to face verification, and Siamese architectures have since been increasingly used in the implementation of visual trackers.

A Siamese network consists of identical neural networks that share weights. In the fully-convolutional Siamese formulation used by SiamFC, given a candidate image x and an exemplar image z, the aim is to learn a function f(x, z) that compares x and z and returns a high score if the two images depict the same object and a low score otherwise. To locate the object in a new image, every possible location of the object in that image is evaluated and each candidate is compared with the exemplar. The similarity is computed by a function g that takes the embeddings φ(x) and φ(z) produced by the network and cross-correlates them. That is, while the network φ provides an embedding for each candidate image x and exemplar image z, the scoring function f(x, z) = g(φ(x), φ(z)) combines these embeddings into a score map giving the similarity of each corresponding region. Specifically, for a ConvNet φ with parameters θ, the function is defined as fθ(x, z) = φθ(x) ⋆ φθ(z) + b, where ⋆ denotes cross-correlation and b is a bias. The approach is discriminative: for each position of the score map, the logistic loss ℓ(y, v) = log(1 + exp(-y·v)) is computed from the score v and the ground-truth label y ∈ {+1, -1}, the training loss is the mean of these individual losses, and the resulting optimization problem is solved using Stochastic Gradient Descent (SGD).

The novelty of SiamRPN++ lies in its ability to incorporate a very deep network such as ResNet into the Siamese architecture. SiamRPN++ is essentially a SiamRPN-based tracker that uses ResNet as its backbone network, obtaining significant performance improvements over conventional trackers.
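
To make the scoring function f(x, z) and the logistic loss concrete, the NumPy sketch below computes the score map φ(x) ⋆ φ(z) + b for embeddings that are assumed to have already been produced by φ, and averages the logistic loss over that map. The function names and the toy inputs are illustrative assumptions, not code from SiamFC or pysot.

```python
import numpy as np


def xcorr_score_map(search_emb, exemplar_emb, bias=0.0):
    """Score map f(x, z) = phi(x) * phi(z) + b.

    search_emb:   phi(x), shape (C, H, W)
    exemplar_emb: phi(z), shape (C, h, w)
    Returns an (H - h + 1, W - w + 1) map: at each offset the exemplar
    embedding is laid over the search embedding and the element-wise
    products are summed over all channels (cross-correlation, no flip).
    """
    _, H, W = search_emb.shape
    _, h, w = exemplar_emb.shape
    scores = np.empty((H - h + 1, W - w + 1))
    for i in range(scores.shape[0]):
        for j in range(scores.shape[1]):
            window = search_emb[:, i:i + h, j:j + w]
            scores[i, j] = np.sum(window * exemplar_emb)
    return scores + bias


def mean_logistic_loss(scores, labels):
    """Mean over all positions of log(1 + exp(-y * v)), with y in {+1, -1}."""
    # logaddexp(0, t) = log(1 + exp(t)), evaluated without overflow
    return float(np.mean(np.logaddexp(0.0, -labels * scores)))


# Toy example: the exemplar embedding is pasted into an otherwise empty
# search embedding at row 4, column 2.
exemplar = np.ones((2, 3, 3))
search = np.zeros((2, 9, 9))
search[:, 4:7, 2:5] = 1.0

scores = xcorr_score_map(search, exemplar)
peak = tuple(int(i) for i in np.unravel_index(np.argmax(scores), scores.shape))
print(scores.shape, peak, float(scores[peak]))  # (7, 7) (4, 2) 18.0

# Label the peak as the positive position and everything else as negative
labels = -np.ones_like(scores)
labels[peak] = 1.0
print(mean_logistic_loss(scores, labels))
```

The score peaks exactly where the exemplar sits in the search embedding, which is how the position of the object is read off the score map.
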
As noted in previous research on ResNet, its architecture makes it possible to obtain feature maps from different levels of layers, which can be aggregated to identify regions with a high level of cross-correlation.
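
Aggregating the outputs obtained from different levels can be sketched as a weighted sum of the per-level maps. The MultiRPN class in the export code hard-codes weight_cls and weight_loc, and each triple sums to one, which is what a softmax over three learned weights would produce; treating them as softmax-normalized is our assumption (to our understanding, pysot's weighted MultiRPN applies a softmax to its learned weights). The helpers softmax and fuse_levels are hypothetical names used only for this illustration.

```python
import numpy as np

# Hard-coded fusion weights copied from MultiRPN in the export code
W_CLS = np.array([0.38156851768108546, 0.4364767608115956, 0.18195472150731892])
W_LOC = np.array([0.17644893463361863, 0.16564198028417967, 0.6579090850822015])


def softmax(logits):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(logits - np.max(logits))
    return e / e.sum()


def fuse_levels(maps, weights):
    """Weighted sum of same-shaped per-level maps (one weight per level)."""
    assert len(maps) == len(weights)
    fused = np.zeros_like(maps[0], dtype=float)
    for m, wt in zip(maps, weights):
        fused += wt * m
    return fused


# Each hard-coded triple sums to one ...
print(W_CLS.sum(), W_LOC.sum())
# ... and is reproduced by a softmax over the logits log(w)
print(np.allclose(softmax(np.log(W_CLS)), W_CLS))  # True

# Fuse three toy classification maps with 2 * anchor_num = 10 channels
# on a 25x25 grid, filled with the constants 1, 2 and 3.
levels = [np.full((10, 25, 25), v) for v in (1.0, 2.0, 3.0)]
fused_cls = fuse_levels(levels, W_CLS)
print(fused_cls.shape)  # (10, 25, 25)
```

Because the weights sum to one, the fused map stays on the same scale as the individual per-level maps.
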

Details of the Implementation

Export of the Torch model to ONNX Format

Figure 1 illustrates the proposed framework. Given the template and the search region of a frame of a sample video, the network outputs a dense prediction by fusing the outputs of multiple Siamese Region Proposal Network (SiamRPN) blocks. The structure of each SiamRPN block is shown in the diagram on the right-hand side of Figure 1.

Figure 1

With reference to Figure 1, the SiamRPN++ visual tracker was divided into three parts for the actual implementation: the ResNet-50 backbone network, the neck layers named the Adjust layers, and the head network consisting of multiple region proposal networks, which we called the Multi-RPN. Implementing each part of the model in PyTorch proved relatively easy, given an understanding of the general architecture of the model and the availability of the original SiamRPN++ repository. However, actually exporting the model to the ONNX format raised multiple problems.

First, not all of the layers implemented in PyTorch could be exported directly to the ONNX format, because some torch operators were not supported by ONNX. Some of these operators therefore had to be redefined in terms of more elementary functions that ONNX did support. Second, there were problems with conditionals and loops. Because the exporter traces the model with example inputs and records only the branch that is actually taken, a separate ONNX model had to be generated for each branch of a conditional. It was therefore important to remove any unnecessary conditionals and, in certain cases, to hard-code values. As for loops, the export did not allow iterating over the size or shape of a given tensor, and more generally there were cases where the loop itself prevented conversion to ONNX. To resolve these issues, some variables derived from tensors had to be converted to int, and in a few cases the loops had to be written out in full so that the conversion succeeded. This was especially true for the Multi-RPN, which incorporates three region proposal networks.
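
The crop indices l = 4 and r = 11 hard-coded in AdjustLayer_1 of the export code are one example of such values, and they follow from the feature-map sizes. The pure-Python sketch below, using hypothetical helper names, traces the spatial size through the layers defined in the export code for its 127x127 template and 255x255 search inputs. It assumes the layer settings shown there: conv1 with padding 0, a padded 3x3 max pool, a stride-2 layer2, and stride-1 dilated layer3 and layer4 that keep the size.

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output size of a conv/pool layer along one spatial axis (floor mode)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def backbone_size(size):
    """Spatial size after the ResNet backbone defined in the export code."""
    size = conv_out(size, kernel=7, stride=2, padding=0)  # conv1
    size = conv_out(size, kernel=3, stride=2, padding=1)  # maxpool
    # layer1 has stride 1 and padding 1, so the size is unchanged
    size = conv_out(size, kernel=3, stride=2, padding=0)  # layer2 (padding = 2 - stride)
    # layer3/layer4 use stride 1 with padding equal to the dilation: size unchanged
    return size


def center_crop_bounds(size, center_size=7):
    """Slice bounds [l, r) of a centred crop of width center_size."""
    l = (size - center_size) // 2
    return l, l + center_size


template_feat = backbone_size(127)
search_feat = backbone_size(255)
l, r = center_crop_bounds(template_feat)
print(template_feat, search_feat, (l, r))  # 15 31 (4, 11)

# DepthwiseXCorr: an unpadded 3x3 conv on each branch, then a valid
# cross-correlation of the search features with the template features.
kernel_size = conv_out(r - l, kernel=3)        # 5
search_size = conv_out(search_feat, kernel=3)  # 29
score_size = search_size - kernel_size + 1
print(score_size)  # 25
```

Only the 15x15 template features are cropped to 7x7, while the 31x31 search features are not; splitting the neck into AdjustLayer_1 (with the crop) and AdjustLayer_2 (without) presumably avoids the size-dependent branch that would otherwise be needed.
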

The code below exports SiamRPN++ to the ONNX format. It divides SiamRPN++ into a template (target) branch, a search branch, and an RPN head, and the overall tracker is later put together using the OpenCV library.

Code to generate ONNX Models

The plans for the code that I had written are available here: https://docs.google.com/document/d/1Yu_LndrSxykG_qNl-Z1QEjWqwFss5mjfmNaig-tmEw4/edit

The code below for generating the ONNX models of SiamRPN++ is also available at:
https://gist.github.com/jinyup100/7aa748686c5e234ed6780154141b4685

The final version of the pre-trained weights and the ONNX models successfully converted using this code are available at:

Pre-Trained Weights in pth Format
https://drive.google.com/file/d/11bwgPFVkps9AH2NOD1zBDdpF_tQghAB-/view?usp=sharing

Target Net : Import ✔️ Export ✔️
https://drive.google.com/file/d/1dw_Ne3UMcCnFsaD6xkZepwE4GEpqq7U_/view?usp=sharing

Search Net : Import ✔️ Export ✔️
https://drive.google.com/file/d/1Lt4oE43ZSucJvze3Y-Z87CVDreO-Afwl/view?usp=sharing

RPN_head : Import ✔️ Export ✔️
https://drive.google.com/file/d/1zT1yu12mtj3JQEkkfKFJWiZ71fJ-dQTi/view?usp=sharing

import os

import numpy as np
import onnx
import torch
import torch.nn as nn

# Class for the Building Blocks required for ResNet
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1,
                 downsample=None, dilation=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        padding = 2 - stride
        if downsample is not None and dilation > 1:
            dilation = dilation // 2
            padding = dilation

        assert stride == 1 or dilation == 1, \
            "at least one of stride and dilation must equal 1"

        if dilation > 1:
            padding = dilation
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=padding, bias=False, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual

        out = self.relu(out)

        return out
    
# End of Building Blocks

# Class for ResNet - the Backbone neural network

class ResNet(nn.Module):
    "ResNET"
    def __init__(self, block, layers, used_layers):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0,  # 3
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)

        self.feature_size = 128 * block.expansion
        self.used_layers = used_layers
        layer3 = True if 3 in used_layers else False
        layer4 = True if 4 in used_layers else False

        if layer3:
            self.layer3 = self._make_layer(block, 256, layers[2],
                                           stride=1, dilation=2)  # 15x15, 7x7
            self.feature_size = (256 + 128) * block.expansion
        else:
            self.layer3 = lambda x: x  # identity

        if layer4:
            self.layer4 = self._make_layer(block, 512, layers[3],
                                           stride=1, dilation=4)  # 7x7, 3x3
            self.feature_size = 512 * block.expansion
        else:
            self.layer4 = lambda x: x  # identity

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, np.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        dd = dilation
        if stride != 1 or self.inplanes != planes * block.expansion:
            if stride == 1 and dilation == 1:
                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, planes * block.expansion,
                              kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(planes * block.expansion),
                )
            else:
                if dilation > 1:
                    dd = dilation // 2
                    padding = dd
                else:
                    dd = 1
                    padding = 0
                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, planes * block.expansion,
                              kernel_size=3, stride=stride, bias=False,
                              padding=padding, dilation=dd),
                    nn.BatchNorm2d(planes * block.expansion),
                )

        layers = []
        layers.append(block(self.inplanes, planes, stride,
                            downsample, dilation=dilation))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x_ = self.relu(x)
        x = self.maxpool(x_)

        p1 = self.layer1(x)
        p2 = self.layer2(p1)
        p3 = self.layer3(p2)
        p4 = self.layer4(p3)
        out = [x_, p1, p2, p3, p4]
        out = [out[i] for i in self.used_layers]
        if len(out) == 1:
            return out[0]
        else:
            return out
        
# End of ResNet

# Class for Adjusting the layers of the neural net

class AdjustLayer_1(nn.Module):
    def __init__(self, in_channels, out_channels, center_size=7):
        super(AdjustLayer_1, self).__init__()
        self.downsample = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            )
        self.center_size = center_size

    def forward(self, x):
        x = self.downsample(x)
        l = 4
        r = 11
        x = x[:, :, l:r, l:r]
        return x

class AdjustAllLayer_1(nn.Module):
    def __init__(self, in_channels, out_channels, center_size=7):
        super(AdjustAllLayer_1, self).__init__()
        self.num = len(out_channels)
        if self.num == 1:
            self.downsample = AdjustLayer_1(in_channels[0],
                                          out_channels[0],
                                          center_size)
        else:
            for i in range(self.num):
                self.add_module('downsample'+str(i+2),
                                AdjustLayer_1(in_channels[i],
                                            out_channels[i],
                                            center_size))

    def forward(self, features):
        if self.num == 1:
            return self.downsample(features)
        else:
            out = []
            for i in range(self.num):
                adj_layer = getattr(self, 'downsample'+str(i+2))
                out.append(adj_layer(features[i]))
            return out
        
class AdjustLayer_2(nn.Module):
    def __init__(self, in_channels, out_channels, center_size=7):
        super(AdjustLayer_2, self).__init__()
        self.downsample = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            )
        self.center_size = center_size

    def forward(self, x):
        x = self.downsample(x)
        return x

class AdjustAllLayer_2(nn.Module):
    def __init__(self, in_channels, out_channels, center_size=7):
        super(AdjustAllLayer_2, self).__init__()
        self.num = len(out_channels)
        if self.num == 1:
            self.downsample = AdjustLayer_2(in_channels[0],
                                          out_channels[0],
                                          center_size)
        else:
            for i in range(self.num):
                self.add_module('downsample'+str(i+2),
                                AdjustLayer_2(in_channels[i],
                                            out_channels[i],
                                            center_size))

    def forward(self, features):
        if self.num == 1:
            return self.downsample(features)
        else:
            out = []
            for i in range(self.num):
                adj_layer = getattr(self, 'downsample'+str(i+2))
                out.append(adj_layer(features[i]))
            return out
        
# End of Class for Adjusting the layers of the neural net

# Class for Region Proposal Neural Network

class RPN(nn.Module):
    "Region Proposal Network"
    def __init__(self):
        super(RPN, self).__init__()

    def forward(self, z_f, x_f):
        raise NotImplementedError
        
class DepthwiseXCorr(nn.Module):
    "Depthwise Correlation Layer"
    def __init__(self, in_channels, hidden, out_channels, kernel_size=3, hidden_kernel_size=5):
        super(DepthwiseXCorr, self).__init__()
        self.conv_kernel = nn.Sequential(
                nn.Conv2d(in_channels, hidden, kernel_size=kernel_size, bias=False),
                nn.BatchNorm2d(hidden),
                nn.ReLU(inplace=True),
                )
        self.conv_search = nn.Sequential(
                nn.Conv2d(in_channels, hidden, kernel_size=kernel_size, bias=False),
                nn.BatchNorm2d(hidden),
                nn.ReLU(inplace=True),
                )
        self.head = nn.Sequential(
                nn.Conv2d(hidden, hidden, kernel_size=1, bias=False),
                nn.BatchNorm2d(hidden),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden, out_channels, kernel_size=1)
                )
        
    def forward(self, kernel, search):    
        kernel = self.conv_kernel(kernel)
        search = self.conv_search(search)
        
        feature = xcorr_depthwise(search, kernel)
        
        out = self.head(feature)
        
        return out

class DepthwiseRPN(RPN):
    def __init__(self, anchor_num=5, in_channels=256, out_channels=256):
        super(DepthwiseRPN, self).__init__()
        self.cls = DepthwiseXCorr(in_channels, out_channels, 2 * anchor_num)
        self.loc = DepthwiseXCorr(in_channels, out_channels, 4 * anchor_num)

    def forward(self, z_f, x_f):
        cls = self.cls(z_f, x_f)
        loc = self.loc(z_f, x_f)
        
        return cls, loc

class MultiRPN(RPN):
    def __init__(self, anchor_num, in_channels):
        super(MultiRPN, self).__init__()
        for i in range(len(in_channels)):
            self.add_module('rpn'+str(i+2),
                    DepthwiseRPN(anchor_num, in_channels[i], in_channels[i]))
        self.weight_cls = nn.Parameter(torch.Tensor([0.38156851768108546, 0.4364767608115956,  0.18195472150731892]))
        self.weight_loc = nn.Parameter(torch.Tensor([0.17644893463361863, 0.16564198028417967, 0.6579090850822015]))

    def forward(self, z_fs, x_fs):
        cls = []
        loc = []
        
        rpn2 = self.rpn2
        z_f2 = z_fs[0]
        x_f2 = x_fs[0]
        c2,l2 = rpn2(z_f2, x_f2)
        
        cls.append(c2)
        loc.append(l2)
        
        rpn3 = self.rpn3
        z_f3 = z_fs[1]
        x_f3 = x_fs[1]
        c3,l3 = rpn3(z_f3, x_f3)
        
        cls.append(c3)
        loc.append(l3)
        
        rpn4 = self.rpn4
        z_f4 = z_fs[2]
        x_f4 = x_fs[2]
        c4,l4 = rpn4(z_f4, x_f4)
        
        cls.append(c4)
        loc.append(l4)
        
        def weighted_avg(lst, weight):
            # Iterate over a fixed count of three (one per RPN block)
            # instead of len(lst), keeping the loop simple to export.
            s = 0
            for i in range(3):
                s += lst[i] * weight[i]
            return s

        return weighted_avg(cls, self.weight_cls), weighted_avg(loc, self.weight_loc)
        
# End of class for RPN

def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, bias=False, dilation=dilation)

def xcorr_depthwise(x, kernel):
    """
    Depthwise cross-correlation for an input and a kernel with different shapes
    """
    batch = kernel.size(0)
    channel = kernel.size(1)
    x = x.view(1, batch*channel, x.size(2), x.size(3))
    kernel = kernel.view(batch*channel, 1, kernel.size(2), kernel.size(3))
    conv = nn.Conv2d(batch*channel, batch*channel, kernel_size=(kernel.size(2), kernel.size(3)), bias=False, groups=batch*channel)
    conv.weight = nn.Parameter(kernel)
    out = conv(x) 
    out = out.view(batch, channel, out.size(2), out.size(3))
    out = out.detach()
    return out

class TargetNetBuilder(nn.Module):
    def __init__(self):
        super(TargetNetBuilder, self).__init__()
        # Build Backbone Model
        self.backbone = ResNet(Bottleneck, [3,4,6,3], [2,3,4])
        # Build Neck Model
        self.neck = AdjustAllLayer_1([512,1024,2048], [256,256,256])
    
    def forward(self, frame):
        features = self.backbone(frame)
        output = self.neck(features)
        return output

class SearchNetBuilder(nn.Module):
    def __init__(self):
        super(SearchNetBuilder, self).__init__()
        # Build Backbone Model
        self.backbone = ResNet(Bottleneck, [3,4,6,3], [2,3,4])
        # Build Neck Model
        self.neck = AdjustAllLayer_2([512,1024,2048], [256,256,256])
        
    def forward(self, frame):
        features = self.backbone(frame)
        output = self.neck(features)
        return output
 
class RPNBuilder(nn.Module):
    def __init__(self):
        super(RPNBuilder, self).__init__()

        # Build Adjusted Layer Builder
        self.rpn_head = MultiRPN(anchor_num=5,in_channels=[256, 256, 256])

    def forward(self, zf, xf):
        # Get Feature
        cls, loc = self.rpn_head(zf, xf)

        return cls, loc
    
# load_path should point to the pre-trained siamrpn_r50_l234_dwxcorr.pth file.
# The download link to siamrpn_r50_l234_dwxcorr.pth is given in the description.
current_path = os.getcwd()
load_path = os.path.join(current_path, "siamrpn_r50_l234_dwxcorr.pth")
pretrained_dict = torch.load(load_path, map_location=torch.device('cpu'))
pretrained_dict_head = pretrained_dict
pretrained_dict_target = pretrained_dict
pretrained_dict_search = pretrained_dict

# The shape of the inputs to the Target Network and the Search Network
target = torch.Tensor(np.random.rand(1,3,127,127))
search = torch.Tensor(np.random.rand(1,3,255,255))

# Build the torch target net model
target_net = TargetNetBuilder()
target_net.eval()
target_net_dict = target_net.state_dict()

# Load the pre-trained weight to the torch target net model
pretrained_dict_target = {k: v for k, v in pretrained_dict_target.items() if k in target_net_dict}
target_net_dict.update(pretrained_dict_target)
target_net.load_state_dict(target_net_dict)

# Export the torch target net model to ONNX model
torch.onnx.export(target_net, target, "target_net.onnx", export_params=True, opset_version=11,
                  do_constant_folding=True, input_names=['input'], output_names=['output_1', 'output_2', 'output_3'])

# Load the saved torch target net model using ONNX
onnx_target = onnx.load("target_net.onnx")

# Check whether the ONNX target net model is valid (check_model raises an exception if not)
onnx.checker.check_model(onnx_target)
print(onnx.helper.printable_graph(onnx_target.graph))

# Build the torch search net model
search_net = SearchNetBuilder()
search_net.eval()
search_net_dict = search_net.state_dict()

# Load the pre-trained weights to the torch search net model
pretrained_dict_search = {k: v for k, v in pretrained_dict_search.items() if k in search_net_dict}
search_net_dict.update(pretrained_dict_search)
search_net.load_state_dict(search_net_dict)

# Export the torch search net model to ONNX model
torch.onnx.export(search_net, search, "search_net.onnx", export_params=True, opset_version=11,
                  do_constant_folding=True, input_names=['input'], output_names=['output_1', 'output_2', 'output_3'])

# Load the saved search net model using ONNX
onnx_search = onnx.load("search_net.onnx")

# Check whether the ONNX search net model is valid (check_model raises an exception if not)
onnx.checker.check_model(onnx_search)
print(onnx.helper.printable_graph(onnx_search.graph))

# Outputs from the Target Net and Search Net
zfs_1, zfs_2, zfs_3 = target_net(torch.Tensor(target))
xfs_1, xfs_2, xfs_3 = search_net(torch.Tensor(search))

# Adjustments to the outputs from each of the neck models to match to input shape of the torch rpn_head model
zfs = np.stack([zfs_1.detach().numpy(), zfs_2.detach().numpy(), zfs_3.detach().numpy()])
xfs = np.stack([xfs_1.detach().numpy(), xfs_2.detach().numpy(), xfs_3.detach().numpy()])

# Build the torch rpn_head model
rpn_head = RPNBuilder()
rpn_head.eval()
rpn_head.state_dict().keys()
rpn_head_dict = rpn_head.state_dict()

# Load the pre-trained weights to the rpn_head model
pretrained_dict_head = {k: v for k, v in pretrained_dict_head.items() if k in rpn_head_dict}
pretrained_dict_head.keys()
rpn_head_dict.update(pretrained_dict_head)
rpn_head.load_state_dict(rpn_head_dict)
rpn_head.eval()

# Export the torch rpn_head model to ONNX model
torch.onnx.export(rpn_head, (torch.Tensor(np.random.rand(*zfs.shape)), torch.Tensor(np.random.rand(*xfs.shape))), "rpn_head.onnx", export_params=True, opset_version=11,
                  do_constant_folding=True, input_names = ['input_1', 'input_2'], output_names = ['output_1', 'output_2'])

# Load the saved rpn_head model using ONNX
onnx_rpn_head_model = onnx.load("rpn_head.onnx")

# Check whether the rpn_head model is valid (check_model raises an exception if not)
onnx.checker.check_model(onnx_rpn_head_model)
print(onnx.helper.printable_graph(onnx_rpn_head_model.graph))

Import of the ONNX model using OpenCV library

Having successfully exported each part of the visual tracker model to ONNX, the team attempted to import the ONNX models using the OpenCV library. OpenCV's DNN module contains an ONNX importer, which loads ONNX files so that they can be used by the library. The importer first calls the ONNX graph simplifier, which simplifies the graph of the ONNX file so that it is compatible with the OpenCV DNN module. When importing the converted ONNX files of the visual tracker model, the team faced two major problems. First, when importing the Adjust layers, the ONNX importer did not support a Gather layer with multiple variable inputs. Second, when importing the Multi-RPN model, the ONNX importer at that stage did not support convolution with non-constant weights; that is, the kernel and bias of a convolution could not be variable inputs.

Figure 2

To elaborate on the latter problem, SiamRPN++ requires support for convolution with non-constant weights because of its Depth-wise Cross Correlation (DW-XCorr) layer, a feature specific to this type of visual tracker. According to the paper, cross correlation is the operation that embeds the information coming from the two branches, the search branch and the target branch. SiamFC, by comparison, employs a Cross Correlation layer to obtain a single-channel response map, as shown in Figure 2. SiamRPN++ performs DW-XCorr on the adjusted feature maps obtained from the two branches: the feature map from the search branch and the feature map from the target branch, which have the same number of channels, are correlated channel by channel. Because the target-branch features serve as the convolution kernel, this step requires a convolution whose weights vary with the input. This was the problem I found most difficult to solve. I am grateful for the work of one of the mentors, Liubov Batanina: only after her pull request was merged was I able to import the visual tracker model using the OpenCV library.
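
The channel-by-channel correlation can be illustrated with a NumPy sketch that mirrors the reshaping done by xcorr_depthwise in the export code: the batch is folded into the channel axis so that every (batch, channel) slice of the template features acts as its own kernel, i.e. a grouped convolution whose weights come from the input rather than from constants. The function name xcorr_depthwise_np and the random inputs are illustrative assumptions; the sketch uses plain loops instead of a convolution call. Summing its output over channels gives a single-channel response map like SiamFC's.

```python
import numpy as np


def xcorr_depthwise_np(search, kernel):
    """Depth-wise cross-correlation.

    search: (B, C, H, W) search-branch features
    kernel: (B, C, h, w) template-branch features, used as per-channel kernels
    Returns (B, C, H - h + 1, W - w + 1): one response map per channel.
    """
    B, C, H, W = search.shape
    _, _, h, w = kernel.shape
    # Fold the batch into the channel axis, as xcorr_depthwise does with
    # x.view(1, B*C, H, W), kernel.view(B*C, 1, h, w) and groups=B*C.
    x = search.reshape(B * C, H, W)
    k = kernel.reshape(B * C, h, w)
    out = np.empty((B * C, H - h + 1, W - w + 1))
    for i in range(H - h + 1):
        for j in range(W - w + 1):
            # Each folded channel is correlated only with its own kernel slice
            out[:, i, j] = np.sum(x[:, i:i + h, j:j + w] * k, axis=(1, 2))
    return out.reshape(B, C, H - h + 1, W - w + 1)


rng = np.random.default_rng(0)
# Shapes assumed from the 3x3 convs of DepthwiseXCorr in the export code:
# 256 channels, 5x5 template features and 29x29 search features.
zf = rng.standard_normal((1, 256, 5, 5))
xf = rng.standard_normal((1, 256, 29, 29))

dw = xcorr_depthwise_np(xf, zf)
print(dw.shape)  # (1, 256, 25, 25)

# Summing over channels collapses DW-XCorr into a single-channel response map
single = dw.sum(axis=1)
print(single.shape)  # (1, 25, 25)
```

Keeping one response map per channel is what lets the head convolutions that follow see 256 correlation channels instead of a single score map.
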

The successful export and import of the model finally meant that I was able to complete my pull request, which was essentially about adding the SiamRPN++ tracker sample to the OpenCV repository. With the mentors' revisions to the code, I was able to contribute to the library by adding the siamrpnpp.py file, the details of which can be found in this pull request. Videos of visual tracking with SiamRPN++ are shown in Figure 4. For more sample videos, please refer to this link.

@JosonChan1998

Is the input of search network (3,255,255)?

@jinyup100
Author

jinyup100 commented Jan 3, 2021

Is the input of search network (3,255,255)?

@JosonChan1998
Yes, it should be: search = torch.Tensor(np.random.rand(1,3,255,255))
I have made the fix.

@JosonChan1998

Ok, thanks for your great work.
Any plans to support the MobileNetV2 backbone?
