Skip to content

Instantly share code, notes, and snippets.

@daigo0927
Last active January 5, 2024 07:29
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save daigo0927/8c8b3005cffb61983e80ceab6c1f2274 to your computer and use it in GitHub Desktop.
Save daigo0927/8c8b3005cffb61983e80ceab6c1f2274 to your computer and use it in GitHub Desktop.
import numpy as np
import torch
import torch.nn.functional as F
import onnx
import onnxruntime as ort
from torch.onnx import register_custom_op_symbolic
import torch.onnx.symbolic_helper as sym_help
# Symbolic function mapping aten::grid_sampler to the ONNX Runtime contrib
# operator com.microsoft::GridSample.
# From https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/pytorch_export_contrib_ops.py
def grid_sampler(g, input, grid, mode, padding_mode, align_corners):
    """Emit a ``com.microsoft::GridSample`` node in place of ``aten::grid_sampler``.

    Args:
        g: Graph context supplied by the ONNX exporter; only ``g.op`` is used.
        input: Graph value of the tensor to sample, laid out (N, C, H, W).
        grid: Graph value of the sampling grid, laid out (N, H_out, W_out, 2).
        mode: Integer interpolation code, or a constant graph value holding
            one: 0 = 'bilinear', 1 = 'nearest', 2 = 'bicubic'.
        padding_mode: Integer padding code, or a constant graph value holding
            one: 0 = 'zeros', 1 = 'border', 2 = 'reflection'.
        align_corners: Boolean, or a constant graph value holding one; it is
            emitted as the integer attribute ``align_corners_i``.

    Returns:
        Whatever ``g.op`` returns for the new GridSample node.

    Raises:
        ValueError: If ``mode`` or ``padding_mode`` is not one of the codes
            listed above (this includes a non-constant graph value).
    """
    mode_names = ('bilinear', 'nearest', 'bicubic')
    padding_mode_names = ('zeros', 'border', 'reflection')

    mode = sym_help._maybe_get_const(mode, "i")
    padding_mode = sym_help._maybe_get_const(padding_mode, "i")
    # Check ranges explicitly: bare tuple indexing would silently map negative
    # codes to the wrong name (-1 -> 'bicubic') and otherwise fail with an
    # uninformative IndexError/TypeError.
    if mode not in range(len(mode_names)):
        raise ValueError(f"Unsupported grid_sampler mode code: {mode!r}")
    if padding_mode not in range(len(padding_mode_names)):
        raise ValueError(
            f"Unsupported grid_sampler padding_mode code: {padding_mode!r}")
    mode_str = mode_names[mode]
    padding_mode_str = padding_mode_names[padding_mode]
    align_corners = int(sym_help._maybe_get_const(align_corners, "b"))

    # NOTE: the output type/shape of the GridSample node is not set here.
    # The expected mapping is (N, C, H, W) x (N, H_out, W_out, 2) ->
    # (N, C, H_out, W_out); the upstream comment notes the output shape can be
    # specified (e.g. via setType on the node) from opset v13 onward.
    return g.op("com.microsoft::GridSample", input, grid,
                mode_s=mode_str,
                padding_mode_s=padding_mode_str,
                align_corners_i=align_corners)
# Register grid_sampler as the export symbolic for aten::grid_sampler,
# with opset_version=1.
register_custom_op_symbolic('::grid_sampler', grid_sampler, 1)

# Seed the RNG so the PyTorch-vs-ORT comparison is reproducible across runs.
torch.manual_seed(0)

# Input feature map of shape (N, C, H, W) = (1, 4, 10, 10).
x = torch.randn(1, 4, 10, 10)
# Sampling grid of shape (N, H_out, W_out, 2) = (1, 8, 8, 2); torch.rand is
# in [0, 1), so this rescales coordinates into [-1, 1).
grid = 2 * torch.rand(1, 8, 8, 2) - 1

# Reference output computed directly by PyTorch, as a NumPy array.
ref = F.grid_sample(x, grid, align_corners=False)
ref = ref.detach().numpy()
class Sampler(torch.nn.Module):
    """Minimal ``torch.nn.Module`` wrapper around ``F.grid_sample`` for export.

    With the default arguments, ``forward`` is equivalent to
    ``F.grid_sample(x, grid, align_corners=False)``.

    Args:
        mode: Interpolation mode forwarded to ``F.grid_sample``.
        padding_mode: Padding mode forwarded to ``F.grid_sample``.
        align_corners: ``align_corners`` flag forwarded to ``F.grid_sample``.
    """

    def __init__(self, mode='bilinear', padding_mode='zeros', align_corners=False):
        super().__init__()
        self.mode = mode
        self.padding_mode = padding_mode
        self.align_corners = align_corners

    def forward(self, x, grid):
        """Sample ``x`` at the normalized locations given by ``grid``."""
        return F.grid_sample(
            x,
            grid,
            mode=self.mode,
            padding_mode=self.padding_mode,
            align_corners=self.align_corners,
        )
# Export the wrapper module to ONNX; aten::grid_sampler is converted by the
# registered grid_sampler symbolic.
torch.onnx.export(
    Sampler(),
    (x, grid),
    'sampler.onnx',
    verbose=True,
    input_names=['input', 'grid'],
    output_names=['output'],
)

# Pass providers explicitly: per the onnxruntime docs, InferenceSession
# raises if it is omitted on builds that ship more than one execution
# provider (e.g. GPU builds).
sess = ort.InferenceSession('sampler.onnx', providers=['CPUExecutionProvider'])
outputs = sess.run(None, {'input': x.numpy(), 'grid': grid.numpy()})
out_onnx = outputs[0]

# Show the name and shape of every graph output reported by the session.
for o in sess.get_outputs():
    print(o.name, o.shape)
# Compare F.grid_sample output with converted ORT output
def max_absolute_percentage_error(ref, out, eps=1e-12):
    """Return the largest element-wise relative error ``|ref - out| / |ref|``.

    Despite the name, the result is a fraction and is not multiplied by 100.

    Args:
        ref: Reference values (array-like).
        out: Values to compare against ``ref``; must broadcast with it.
        eps: Floor applied to ``|ref|`` in the denominator, so entries where
            ``ref`` is zero yield a finite error instead of ``inf`` or ``nan``
            (a single ``nan`` would otherwise make the maximum ``nan``).
            Results are unchanged wherever ``|ref| >= eps``.

    Returns:
        The maximum relative error as a NumPy scalar.
    """
    ref = np.asarray(ref)
    out = np.asarray(out)
    denom = np.maximum(np.abs(ref), eps)
    return (np.abs(ref - out) / denom).max()
# Report the worst-case relative discrepancy between the PyTorch and ORT outputs.
mape = max_absolute_percentage_error(ref, out_onnx)
print(mape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment