# Source: GitHub Gist cca09da59fb90f98a757f61e59862c7c by @jinyup100,
# created July 28, 2020 11:27.
import cv2
import math
import numpy as np
import os
import onnx
import torch
from torch.autograd import Function
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args
#@staticmethod
#def symbolic(g, kernel, search):
# return g.op("DepthwiseCorrelation", kernel, search)
#@parse_args('v', 'v', 'f', 'i')
#def symbolic_depthwise_forward(kernel, search):
# batch = kernel.size(0)
# channel = kernel.size(1)
# search = search.view(1, batch*channel, search.size(2), search.size(3))
# kernel = kernel.view(batch*channel, 1, kernel.size(2), kernel.size(3))
# out = F.conv2d(search, kernel, groups=batch*channel)
# out = out.view(batch, channel, out.size(2), out.size(3))
# return g.op("Depthwise", out)
#from torch.onnx import register_custom_op_symbolic
#register_custom_op_symbolic('custom_ops::depthwise_forward', symbolic_depthwise_forward, 11)
## Create custom symbolic function
#from torch.onnx.symbolic_helper import parse_args
#@parse_args('v', 'v', 'f', 'i')
#def symbolic_foo_forward(g, input1, input2, attr1, attr2):
# return g.op("Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)
#Register custom symbolic function
from torch.onnx import register_custom_op_symbolic


# The symbolic must exist before the registration call below; without this
# definition the module fails at import time with a NameError.
@parse_args('v', 'v', 'f', 'i')
def symbolic_foo_forward(g, input1, input2, attr1, attr2):
    """ONNX symbolic for the TorchScript op ``custom_ops::foo_forward``.

    Emits a single ``Foo`` node with two tensor inputs, a float attribute
    ``attr1_f`` and an int attribute ``attr2_i``.

    Args:
        g: graph context supplied by the ONNX exporter.
        input1: graph value for the first tensor input (``'v'``: passed through).
        input2: graph value for the second tensor input (``'v'``: passed through).
        attr1: converted by ``parse_args`` to a Python float (``'f'``).
        attr2: converted by ``parse_args`` to a Python int (``'i'``).

    Returns:
        The output value of the emitted ``Foo`` node.
    """
    return g.op("Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)


# Register custom symbolic function (ONNX opset 11).
register_custom_op_symbolic('custom_ops::foo_forward', symbolic_foo_forward, 11)
@torch.jit.script
def symbolic_depthwise_forward(kernel, search):
    """Depthwise cross-correlation of ``search`` with ``kernel``, per plane.

    Each (batch, channel) plane of ``kernel`` is slid over the matching plane
    of ``search`` using one grouped ``conv2d`` with one group per plane
    (stride 1, no padding).

    Args:
        kernel: tensor of shape (B, C, kH, kW).
        search: tensor of shape (B, C, H, W) with the same B and C.

    Returns:
        Tensor of shape (B, C, H - kH + 1, W - kW + 1).
    """
    batch = kernel.size(0)
    channel = kernel.size(1)
    # Fold batch into channels so a single grouped conv handles every plane.
    # reshape (not view) so inputs whose batch/channel dims cannot be merged
    # in place (e.g. after a permute) are copied instead of raising.
    search = search.reshape(1, batch * channel, search.size(2), search.size(3))
    kernel = kernel.reshape(batch * channel, 1, kernel.size(2), kernel.size(3))
    out = F.conv2d(search, kernel, groups=batch * channel)
    return out.reshape(batch, channel, out.size(2), out.size(3))
from torch.autograd import Function


class DepthwiseCorrelationLayer(Function):
    """Autograd Function for depthwise cross-correlation with an ONNX symbolic.

    ``DepthwiseCorrelationLayer.apply(kernel, search)`` runs the eager
    computation in ``forward``. During ONNX export, the ``symbolic``
    staticmethod supplies the graph node instead. No ``backward`` is defined,
    so gradients cannot be propagated through this Function.
    """

    @staticmethod
    def symbolic(g, kernel, search):
        """Emit one ``DepthwiseCorrelation`` node into the ONNX graph ``g``."""
        # NOTE(review): the op name has no domain prefix, so it is emitted in
        # the default ONNX domain, where it is not a standard op. Confirm the
        # consuming runtime expects this rather than e.g. "custom_ops::...".
        return g.op("DepthwiseCorrelation", kernel, search)

    @staticmethod
    def forward(ctx, kernel, search):
        """Eager computation used by ``DepthwiseCorrelationLayer.apply``."""
        return DepthwiseCorrelationLayer.symbolic_depthwise_forward(ctx, kernel, search)

    @staticmethod
    def symbolic_depthwise_forward(self, kernel, search):
        """Depthwise cross-correlation of ``search`` with ``kernel``.

        Despite its name this is the eager computation, not an ONNX symbolic.
        Because this is a staticmethod, the first parameter is ignored. It is
        kept (still named ``self``) so existing positional callers keep working.

        Args:
            self: unused placeholder.
            kernel: tensor of shape (B, C, kH, kW).
            search: tensor of shape (B, C, H, W) with the same B and C.

        Returns:
            Tensor of shape (B, C, H - kH + 1, W - kW + 1).
        """
        batch = kernel.size(0)
        channel = kernel.size(1)
        planes = batch * channel
        # Fold batch into channels so one grouped conv handles every plane.
        # reshape (unlike view) also accepts inputs whose batch/channel dims
        # cannot be merged without a copy.
        search = search.reshape(1, planes, search.size(2), search.size(3))
        kernel = kernel.reshape(planes, 1, kernel.size(2), kernel.size(3))
        out = F.conv2d(search, kernel, groups=planes)
        return out.reshape(batch, channel, out.size(2), out.size(3))
from torch.onnx import register_custom_op_symbolic

# Map the TorchScript op ``custom_ops::depthwise_forward`` to an ONNX symbolic
# for opset 11. A symbolic receives exporter graph values and must emit nodes
# via ``g.op``, which ``DepthwiseCorrelationLayer.symbolic`` does. The eager
# ``symbolic_depthwise_forward`` calls tensor methods (``.size``, ``conv2d``)
# and cannot run on graph values.
register_custom_op_symbolic(
    'custom_ops::depthwise_forward', DepthwiseCorrelationLayer.symbolic, 11
)

# NOTE: ``torch.ops.custom_ops.depthwise_forward`` only resolves once an
# operator of that name is registered with the TorchScript operator registry
# (e.g. a compiled extension loaded via ``torch.ops.load_library``). Nothing
# visible here does that, so the op is not invoked at import time. Once it
# exists, call it as ``torch.ops.custom_ops.depthwise_forward(kernel, search)``.
# (GitHub page footer) Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment