Created
July 28, 2020 11:27
-
-
Save jinyup100/cca09da59fb90f98a757f61e59862c7c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cv2 | |
import math | |
import numpy as np | |
import os | |
import onnx | |
import torch | |
from torch.autograd import Function | |
from torch.autograd import Variable | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.autograd import Function | |
from torch.onnx import register_custom_op_symbolic | |
from torch.onnx.symbolic_helper import parse_args | |
#@staticmethod | |
#def symbolic(g, kernel, search): | |
# return g.op("DepthwiseCorrelation", kernel, search) | |
#@parse_args('v', 'v', 'f', 'i') | |
#def symbolic_depthwise_forward(kernel, search): | |
# batch = kernel.size(0) | |
# channel = kernel.size(1) | |
# search = search.view(1, batch*channel, search.size(2), search.size(3)) | |
# kernel = kernel.view(batch*channel, 1, kernel.size(2), kernel.size(3)) | |
# out = F.conv2d(search, kernel, groups=batch*channel) | |
# out = out.view(batch, channel, out.size(2), out.size(3)) | |
# return g.op("Depthwise", out) | |
#from torch.onnx import register_custom_op_symbolic | |
#register_custom_op_symbolic('custom_ops::depthwise_forward', symbolic_depthwise_forward, 11) | |
## Create custom symbolic function | |
#from torch.onnx.symbolic_helper import parse_args | |
#@parse_args('v', 'v', 'f', 'i') | |
#def symbolic_foo_forward(g, input1, input2, attr1, attr2): | |
# return g.op("Foo", input1, input2, attr1_f=attr1, attr2_i=attr2) | |
#Register custom symbolic function | |
from torch.onnx import register_custom_op_symbolic


# The registration below needs ``symbolic_foo_forward`` to exist at module
# level. The earlier definition in this file is commented out, which made the
# registration fail with NameError at import time, so the definition lives here.
@parse_args('v', 'v', 'f', 'i')
def symbolic_foo_forward(g, input1, input2, attr1, attr2):
    """Emit a ``Foo`` node for the ``custom_ops::foo_forward`` op during ONNX export.

    Args:
        g: The graph passed in by the ONNX exporter.
        input1: First op input, passed through as a graph value ('v').
        input2: Second op input, passed through as a graph value ('v').
        attr1: Parsed as a float ('f') and emitted as the float attribute ``attr1``.
        attr2: Parsed as an int ('i') and emitted as the int attribute ``attr2``.

    Returns:
        The value produced by the emitted ``Foo`` node.
    """
    return g.op("Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)


register_custom_op_symbolic('custom_ops::foo_forward', symbolic_foo_forward, 11)
@torch.jit.script
def symbolic_depthwise_forward(kernel, search):
    """Compute a depthwise cross-correlation of ``search`` with ``kernel`` (TorchScript).

    Despite the ``symbolic_`` prefix, this function computes tensors eagerly;
    it does not build ONNX graph nodes. Each (batch, channel) slice of
    ``kernel`` is correlated with the matching slice of ``search``. This is a
    grouped ``F.conv2d`` with one group per batch*channel pair, stride 1 and
    no padding.

    Args:
        kernel: 4-D tensor of shape (batch, channel, kH, kW).
        search: 4-D tensor of shape (batch, channel, H, W), with the same
            batch and channel sizes as ``kernel``.

    Returns:
        Tensor of shape (batch, channel, H - kH + 1, W - kW + 1).

    Raises:
        ValueError: if the batch or channel sizes of the inputs differ.
            Previously such inputs could be silently misaligned whenever the
            products batch*channel happened to agree.
    """
    batch = kernel.size(0)
    channel = kernel.size(1)
    if search.size(0) != batch or search.size(1) != channel:
        raise ValueError("kernel and search must have the same batch and channel sizes")
    groups = batch * channel
    # reshape (not view) so that non-contiguous inputs are accepted as well.
    search = search.reshape([1, groups, search.size(2), search.size(3)])
    kernel = kernel.reshape([groups, 1, kernel.size(2), kernel.size(3)])
    out = F.conv2d(search, kernel, groups=groups)
    return out.reshape([batch, channel, out.size(2), out.size(3)])
from torch.autograd import Function | |
class DepthwiseCorrelationLayer(Function):
    """Autograd ``Function`` wrapping a depthwise cross-correlation.

    Each (batch, channel) slice of ``kernel`` is correlated with the matching
    slice of ``search`` via a grouped ``F.conv2d`` (one group per
    batch*channel pair, stride 1, no padding). When a model calling
    ``DepthwiseCorrelationLayer.apply(kernel, search)`` is exported to ONNX,
    the exporter uses :meth:`symbolic` to emit a single node in place of the
    computation.

    Only ``forward`` is implemented (no ``backward``). The layer therefore
    works for inference and export, but gradients cannot be propagated
    through it.
    """

    @staticmethod
    def _correlate(kernel, search):
        """Return the depthwise cross-correlation of ``search`` with ``kernel``.

        Args:
            kernel: 4-D tensor (batch, channel, kH, kW).
            search: 4-D tensor (batch, channel, H, W) with the same batch and
                channel sizes as ``kernel``.

        Returns:
            Tensor of shape (batch, channel, H - kH + 1, W - kW + 1).

        Raises:
            ValueError: if the batch or channel sizes of the inputs differ.
        """
        batch = kernel.size(0)
        channel = kernel.size(1)
        if search.size(0) != batch or search.size(1) != channel:
            raise ValueError(
                "kernel and search must have the same batch and channel sizes, "
                f"got {tuple(kernel.shape)} and {tuple(search.shape)}"
            )
        groups = batch * channel
        # reshape (not view) so that non-contiguous inputs are accepted too.
        search = search.reshape(1, groups, search.size(2), search.size(3))
        kernel = kernel.reshape(groups, 1, kernel.size(2), kernel.size(3))
        out = F.conv2d(search, kernel, groups=groups)
        return out.reshape(batch, channel, out.size(2), out.size(3))

    @staticmethod
    def forward(ctx, kernel, search):
        """Compute the correlation; required so that ``.apply`` works."""
        return DepthwiseCorrelationLayer._correlate(kernel, search)

    @staticmethod
    def symbolic(g, kernel, search):
        """Emit a single ``DepthwiseCorrelation`` node during ONNX export."""
        # NOTE(review): the op name has no custom domain prefix (e.g.
        # "custom::DepthwiseCorrelation"), so it is emitted in the default
        # namespace; confirm the downstream runtime/checker accepts that.
        return g.op("DepthwiseCorrelation", kernel, search)

    @staticmethod
    def symbolic_depthwise_forward(self, kernel, search):
        """Eagerly compute the depthwise correlation of ``search`` with ``kernel``.

        The first argument is ignored. This is a staticmethod, so ``self`` is
        just a positional placeholder kept for backward compatibility.

        This function performs tensor math, so it is not suitable as an ONNX
        symbolic function. Symbolic functions receive graph values, not
        tensors; use :meth:`symbolic` for export.
        """
        return DepthwiseCorrelationLayer._correlate(kernel, search)
from torch.onnx import register_custom_op_symbolic

# Register the graph-building ``symbolic`` (signature ``(g, kernel, search)``)
# instead of ``symbolic_depthwise_forward``. The ONNX exporter passes graph
# values rather than tensors to symbolic functions, so the eager tensor math in
# ``symbolic_depthwise_forward`` cannot run at export time.
register_custom_op_symbolic(
    'custom_ops::depthwise_forward', DepthwiseCorrelationLayer.symbolic, 11
)

# NOTE(review): the former call ``torch.ops.custom_ops::depthwise_forward()``
# was a Python syntax error (``::`` is C++ syntax) and passed no arguments.
# This file also contains no TORCH_LIBRARY / torch.library registration of
# ``custom_ops::depthwise_forward``. Until such a registration exists, the op
# cannot be invoked as ``torch.ops.custom_ops.depthwise_forward(kernel, search)``.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment