Skip to content

Instantly share code, notes, and snippets.

@Icemole
Created June 29, 2023 09:23
Show Gist options
  • Save Icemole/47dc9757678ed5bff36cc222627525de to your computer and use it in GitHub Desktop.
Save Icemole/47dc9757678ed5bff36cc222627525de to your computer and use it in GitHub Desktop.
import torch
class DummyModel(torch.nn.Module):
    """Minimal module that builds a tensor filled with 2 from three scalar dims.

    It exists to compare ways of turning integer scalar tensors into a shape
    for ``torch.full`` when the model is exported to ONNX.
    """

    def forward(self, d1, d2, d3):
        """Return a tensor of shape ``(d1, d2, d3)`` whose elements are all 2.

        Args:
            d1, d2, d3: the output dimensions. Each is converted with
                ``.long()``, so each must be a tensor. Presumably these are
                0-d integer tensors (the export script in this file passes
                int32 scalars) — confirm before passing anything else.

        Returns:
            A tensor of the requested shape, filled with the value 2.
        """
        # Ways of converting the dims that were tried when exporting to ONNX:
        #   Option 1 (active): convert through ``tensor.long()``.
        #       Observed: no warning.
        #   Option 2: convert through ``torch.tensor(dim, dtype=torch.int64)``.
        #       Observed: "TracerWarning: torch.tensor results are registered
        #       as constants in the trace."
        #   Option 3: don't convert, keep int32.
        #       Observed: "UserWarning: The exported ONNX model failed ONNX
        #       shape inference."
        shape = [dim.long() for dim in (d1, d2, d3)]
        return torch.full(shape, 2)
dummy_model = DummyModel()

# Three int32 scalar dims, each equal to 2. int32 is the dtype whose handling
# the options documented in DummyModel.forward compare.
d1, d2, d3 = (torch.tensor(2, dtype=torch.int32) for _ in range(3))

# Export the model to ONNX, naming the graph inputs after the dims.
torch.onnx.export(
    dummy_model,
    (d1, d2, d3),
    f="my_filename.onnx",
    verbose=True,
    input_names=["d1", "d2", "d3"],
    output_names=["casted_data"],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment