Last active
January 5, 2024 07:29
-
-
Save daigo0927/8c8b3005cffb61983e80ceab6c1f2274 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
import onnx | |
import onnxruntime as ort | |
from torch.onnx import register_custom_op_symbolic | |
import torch.onnx.symbolic_helper as sym_help | |
# The symbolic function below maps aten::grid_sampler to the ONNX Runtime contrib operator GridSample.
# Adapted from https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/pytorch_export_contrib_ops.py
def grid_sampler(g, input, grid, mode, padding_mode, align_corners):
    """Emit a ``com.microsoft::GridSample`` node for ``aten::grid_sampler``.

    Args:
        g: Graph context handed in by the exporter; only ``g.op`` is used.
        input: Graph value of the tensor being sampled.
        grid: Graph value of the sampling grid.
        mode: Interpolation mode as an integer constant
            (0 = 'bilinear', 1 = 'nearest', 2 = 'bicubic').
        padding_mode: Padding mode as an integer constant
            (0 = 'zeros', 1 = 'border', 2 = 'reflection').
        align_corners: Boolean constant; forwarded as an integer attribute.

    Returns:
        The result of ``g.op`` for the ``com.microsoft::GridSample`` node.
    """
    # aten passes both modes as integer constants; the contrib operator
    # takes string attributes, so translate by position:
    #   mode          'bilinear'   : onnx::Constant[value={0}]
    #                 'nearest'    : onnx::Constant[value={1}]
    #                 'bicubic'    : onnx::Constant[value={2}]
    #   padding_mode  'zeros'      : onnx::Constant[value={0}]
    #                 'border'     : onnx::Constant[value={1}]
    #                 'reflection' : onnx::Constant[value={2}]
    interpolation_modes = ('bilinear', 'nearest', 'bicubic')
    padding_modes = ('zeros', 'border', 'reflection')

    mode = sym_help._maybe_get_const(mode, "i")
    padding_mode = sym_help._maybe_get_const(padding_mode, "i")
    mode_str = interpolation_modes[mode]
    padding_mode_str = padding_modes[padding_mode]
    align_corners = int(sym_help._maybe_get_const(align_corners, "b"))

    # From opset v13 onward, the output shape can be specified with
    # (N, C, H, W) (N, H_out, W_out, 2) => (N, C, H_out, W_out)
    # input_shape = input.type().sizes()
    # grid_shape = grid.type().sizes()
    # output_shape = input_shape[:2] + grid_shape[1:3]
    # g.op(...).setType(input.type().with_sizes(output_shape))
    return g.op(
        "com.microsoft::GridSample",
        input,
        grid,
        mode_s=mode_str,
        padding_mode_s=padding_mode_str,
        align_corners_i=align_corners,
    )
# Register the symbolic (opset version 1) so the exporter can translate
# aten::grid_sampler when it meets F.grid_sample in the traced module.
register_custom_op_symbolic('::grid_sampler', grid_sampler, 1)

# Reference result computed directly in PyTorch, kept as a NumPy array for
# the comparison against the ONNX Runtime output.
x = torch.randn(1, 4, 10, 10)
grid = 2*torch.rand(1, 8, 8, 2) - 1  # rescale [0, 1) samples to [-1, 1)
ref = F.grid_sample(x, grid, align_corners=False)
ref = ref.detach().numpy()
class Sampler(torch.nn.Module):
    """Minimal module whose forward pass is a single ``F.grid_sample`` call."""

    def forward(self, x, grid):
        """Sample ``x`` at the locations in ``grid`` with ``align_corners=False``.

        Interpolation and padding modes are left at the ``F.grid_sample``
        defaults.
        """
        return F.grid_sample(x, grid, align_corners=False)
# Export the module to ONNX; the symbolic registered for aten::grid_sampler
# is what lets the exporter emit the contrib GridSample node.
torch.onnx.export(
    Sampler(),
    (x, grid),
    'sampler.onnx',
    verbose=True,
    input_names=['input', 'grid'],
    output_names=['output'],
)

# Run the exported model with ONNX Runtime. The provider list is given
# explicitly because recent ORT releases require it on builds that ship
# non-CPU execution providers.
sess = ort.InferenceSession('sampler.onnx', providers=['CPUExecutionProvider'])
outputs = sess.run(None, {'input': x.numpy(), 'grid': grid.numpy()})
out_onnx = outputs[0]

# Show the name and shape ORT reports for each graph output.
for o in sess.get_outputs():
    print(o.name, o.shape)
# Compare the F.grid_sample output with the converted ORT output
def max_absolute_percentage_error(ref, out):
    """Return the largest element-wise relative error ``|ref - out| / |ref|``.

    Args:
        ref: Reference values (array-like).
        out: Values to compare against ``ref``; must broadcast with it.

    Returns:
        The maximum relative error as a NumPy scalar. Elements where both
        ``ref`` and ``out`` are zero count as zero error; an element where
        only ``ref`` is zero yields ``inf``.
    """
    ref = np.asarray(ref)
    out = np.asarray(out)
    abs_diff = np.abs(ref - out)
    abs_ref = np.abs(ref)
    # Division by a zero reference is handled explicitly below, so silence
    # the floating-point warnings it would otherwise raise.
    with np.errstate(divide='ignore', invalid='ignore'):
        rel_err = abs_diff / abs_ref
    # 0/0 would be NaN and poison the max; identical zeros are a perfect match.
    rel_err = np.where((abs_ref == 0) & (abs_diff == 0), 0.0, rel_err)
    return rel_err.max()
# Report the worst-case relative deviation of the ORT output from PyTorch.
print(max_absolute_percentage_error(ref, out_onnx))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment