Skip to content

Instantly share code, notes, and snippets.

@iiSeymour
Created June 27, 2019 14:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save iiSeymour/9c306e53adc0d0c3c92001263caffe5d to your computer and use it in GitHub Desktop.
Save iiSeymour/9c306e53adc0d0c3c92001263caffe5d to your computer and use it in GitHub Desktop.
Symbolic patching for exporting torch.{min,max} with ONNX
import functools

import torch
import torch.onnx.symbolic
def sym_patch(func, op1, op2):
    """Return a drop-in replacement for the ONNX symbolic function ``func``.

    The replacement calls ``func`` with the same arguments. If the result is
    a non-empty tuple whose first output was produced by an ``onnx::ATen``
    node, that result is discarded. It is rebuilt from two native ONNX ops
    applied to ``node``:

    * ``op1`` gives the reduced values over axis ``dim_or_y``.
    * ``op2`` gives the indices along the same axis.

    Any other result from ``func`` is returned unchanged.

    Args:
        func: The symbolic function being patched (e.g. the stock max/min).
        op1: Name of the ONNX reduction op, e.g. ``"ReduceMax"``.
        op2: Name of the ONNX arg-reduction op, e.g. ``"ArgMax"``.

    Returns:
        A function with signature ``(g, node, dim_or_y=None, keepdim=None)``
        that keeps ``func``'s name and docstring (via ``functools.wraps``).
    """

    @functools.wraps(func)
    def replace(g, node, dim_or_y=None, keepdim=None):
        ops = func(g, node, dim_or_y, keepdim)
        # NOTE(review): presumably only the dim= (values, indices) overload
        # yields a tuple here; the onnx::ATen check selects exactly the case
        # this patch exists to replace. The emptiness check avoids an
        # IndexError on ops[0].
        if isinstance(ops, tuple) and ops and ops[0].node().kind() == "onnx::ATen":
            # NOTE(review): _get_const is a private helper of the
            # torch.onnx.symbolic module this was written against (mid-2019);
            # verify it still exists on the PyTorch release in use.
            # dim/keepdim are extracted as constant ints because they become
            # integer attributes of the emitted ops.
            dim = torch.onnx.symbolic._get_const(dim_or_y, 'i', 'dim')
            keepdim = torch.onnx.symbolic._get_const(keepdim, 'i', 'keepdim')
            # The Reduce* ops take a list of axes; the Arg* ops take one axis.
            reduced = g.op(op1, node, axes_i=[dim], keepdims_i=keepdim)
            indices = g.op(op2, node, axis_i=dim, keepdims_i=keepdim)
            return reduced, indices
        return ops

    return replace
# Swap the stock min/max symbolics for patched versions. When the stock
# symbolic emits an onnx::ATen node, the patch emits native
# ReduceMin/ArgMin and ReduceMax/ArgMax ops instead.
torch.onnx.symbolic.min = sym_patch(
    torch.onnx.symbolic.min, op1="ReduceMin", op2="ArgMin"
)
torch.onnx.symbolic.max = sym_patch(
    torch.onnx.symbolic.max, op1="ReduceMax", op2="ArgMax"
)
class MaxModel(torch.nn.Module):
    """Toy module used to exercise export of torch.max / torch.min.

    ``forward(x)`` reduces ``x`` along dim 1 (per row for a 2-D input) and
    returns ``(torch.max(x, dim=1), torch.min(x, dim=1))``. Each element is
    the (values, indices) pair those calls produce.
    """

    def forward(self, x):
        # The tuple is built left to right, so max is computed before min,
        # matching the output order.
        return torch.max(x, dim=1), torch.min(x, dim=1)
# Instantiate the toy model and a random 4x4 example input, then hand both
# to torch.onnx.export, which writes the graph to model.onnx and (with
# verbose=True) prints it.
model = MaxModel()
x = torch.randn(4, 4)
torch.onnx.export(
    model,
    x,
    "model.onnx",
    verbose=True,
)
@iiSeymour
Copy link
Author

Resulting ONNX graph (printed by the verbose export):

graph(%0 : Float(4, 4)):
  %1 : Float(4) = onnx::ReduceMax[axes=[1], keepdims=0](%0), scope: MaxModel
  %2 : Long(4) = onnx::ArgMax[axis=1, keepdims=0](%0), scope: MaxModel
  %3 : Float(4) = onnx::ReduceMin[axes=[1], keepdims=0](%0), scope: MaxModel
  %4 : Long(4) = onnx::ArgMin[axis=1, keepdims=0](%0), scope: MaxModel
  return (%1, %2, %3, %4)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment