Skip to content

Instantly share code, notes, and snippets.

@stealthinu
Last active March 17, 2023 10:17
Show Gist options
  • Save stealthinu/99c4865d6bfe63440c26aa97a45bc579 to your computer and use it in GitHub Desktop.
Save stealthinu/99c4865d6bfe63440c26aa97a45bc579 to your computer and use it in GitHub Desktop.
onnx2tf randn_like issue test
import torch
class OnnxTestRandnLike(torch.nn.Module):
    """Toy module that scales its input by noise drawn with ``torch.randn_like``.

    Exported to ONNX by this script to exercise how ``randn_like`` is
    converted (see the gist title: an onnx2tf ``randn_like`` issue test).
    """

    def forward(self, x):
        """Return ``noise * x`` where ``noise = torch.randn_like(x)``.

        ``randn_like`` draws standard-normal samples matching ``x``'s size,
        dtype and device.
        """
        noise = torch.randn_like(x)
        return noise * x
class OnnxTestRandn(torch.nn.Module):
    """Toy module that scales its input by noise drawn with ``torch.randn``.

    Counterpart to ``OnnxTestRandnLike``: same arithmetic, but the noise is
    created from an explicit size, so the exported ONNX graph uses a
    different random op.
    """

    def forward(self, x):
        """Return ``noise * x`` where ``noise = torch.randn(x.size())``.

        No dtype/device is passed to ``torch.randn``, so the noise uses
        PyTorch's default dtype and device rather than ``x``'s. This is kept
        as-is because the export graph is what is under test.
        """
        shape = x.size()
        noise = torch.randn(shape)
        return noise * x
def export_onnx(x, net, file, *, opset_version=17, verbose=False):
    """Export ``net`` to an ONNX file using ``x`` as the example input.

    Args:
        x: Example input passed to ``torch.onnx.export`` as the model args
            (a single tensor in this script).
        net: The ``torch.nn.Module`` to export.
        file: Destination for the ``.onnx`` file; a ``str`` or path-like.
        opset_version: ONNX opset to target. Defaults to 17, the value this
            script always used.
        verbose: Forwarded to ``torch.onnx.export``.

    The graph input is named ``"x"`` and the output ``"y"``; no
    ``dynamic_axes`` are passed. Prints ``Done: <file>`` when finished.
    """
    torch.onnx.export(
        net,
        x,
        file,
        opset_version=opset_version,
        verbose=verbose,
        input_names=["x"],
        output_names=["y"],
    )
    # f-string instead of "+" so path-like destinations (e.g. pathlib.Path)
    # don't raise TypeError after a successful export.
    print(f"Done: {file}")
def main():
    """Export both noise-test modules to ONNX files in the working directory."""
    cases = (
        (OnnxTestRandnLike, "randnlike4.onnx"),
        (OnnxTestRandn, "randn4.onnx"),
    )
    for model_cls, onnx_path in cases:
        # A fresh random example input of shape (1, 2, 4) for each export.
        export_onnx(torch.rand(1, 2, 4), model_cls(), onnx_path)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment