Skip to content

Instantly share code, notes, and snippets.

@Norod
Created June 1, 2021 15:01
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Norod/610ee5e70791c2cecee8980c31711764 to your computer and use it in GitHub Desktop.
Save Norod/610ee5e70791c2cecee8980c31711764 to your computer and use it in GitHub Desktop.
Rename a tensor (a node input/output value) in an ONNX model
import onnx
# Rename a tensor (value name) everywhere it is referenced in an ONNX graph.
# Note: this renames the *edge* name seen in node inputs/outputs, not node.name.


def _rename_in_list(names, old_name, new_name):
    """Replace every occurrence of *old_name* in the mutable sequence *names*.

    Works on plain lists and on protobuf repeated string fields (both support
    item assignment). Returns the number of replacements made.
    """
    count = 0
    for idx, name in enumerate(names):
        if name == old_name:
            names[idx] = new_name
            count += 1
    return count


def _log_node(node):
    """Print a node's name, inputs and outputs (as they are before renaming)."""
    print('-' * 60)
    print(node.name)
    print(node.input)
    print(node.output)


def rename_tensor(graph, old_name, new_name, verbose=True):
    """Rename every reference to tensor *old_name* in *graph*, in place.

    The following are updated:
      * node inputs and outputs,
      * graph inputs, outputs, value_info and initializer entries,
      * subgraphs attached to node attributes (``g`` / ``graphs``, e.g. the
        branches of If or the body of Loop/Scan), recursively.

    NOTE(review): *new_name* is not checked for collisions. Renaming onto a
    name that already exists produces an invalid graph. Sparse initializers
    are not handled.

    Args:
        graph: an ``onnx.GraphProto`` (or any object exposing the same fields).
        old_name: the tensor name to replace.
        new_name: the replacement name.
        verbose: print each matching node/entry, as the original script did.

    Returns:
        The number of references that were renamed (0 if none matched).
    """
    renamed = 0
    for node in graph.node:
        if old_name in node.input or old_name in node.output:
            if verbose:
                _log_node(node)
            renamed += _rename_in_list(node.input, old_name, new_name)
            renamed += _rename_in_list(node.output, old_name, new_name)
        # Control-flow subgraphs may read the tensor from the enclosing scope.
        # On protobuf, an unset `g` reads as an empty GraphProto, so recursing
        # into it is harmless: there is nothing to match or write.
        for attr in getattr(node, 'attribute', ()):
            subgraphs = list(getattr(attr, 'graphs', ()))
            single = getattr(attr, 'g', None)
            if single is not None:
                subgraphs.append(single)
            for sub in subgraphs:
                renamed += rename_tensor(sub, old_name, new_name, verbose)

    # Graph-level entries are keyed by .name. Initializers must be included,
    # otherwise renaming a weight leaves nodes pointing at a missing tensor.
    for field in ('input', 'output', 'value_info', 'initializer'):
        for entry in getattr(graph, field, ()):
            if entry.name == old_name:
                if verbose:
                    # Print a summary, not the entry: initializers can be huge.
                    print('-' * 60)
                    print(f'graph.{field}: {old_name!r} -> {new_name!r}')
                entry.name = new_name
                renamed += 1
    return renamed


def main(input_path='./input.onnx', output_path='model_mod.onnx',
         old_name='inp', new_name='inst'):
    """Load *input_path*, rename *old_name* to *new_name*, save to *output_path*.

    Defaults reproduce the original script: rename 'inp' to 'inst'.
    """
    onnx_model = onnx.load(input_path)
    count = rename_tensor(onnx_model.graph, old_name, new_name)
    if count == 0:
        print(f'warning: no tensor named {old_name!r} found in {input_path}')
    onnx.save(onnx_model, output_path)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment