likely shape for ONNX for RNN
import json

import torch

# NOTE: Circuit (the model to export) is assumed to be defined elsewhere;
# it is not part of this snippet.


def main():
    torch_model = Circuit()
    # Input to the model
    shape = [3, 2, 2]  # change to [batch_size, seq_len, input_size]
    x = 0.1 * torch.rand(1, *shape, requires_grad=True)
    y = 0.1 * torch.rand(1, *shape, requires_grad=True)
    torch_out = torch_model(x, y)
    # Export the model
    torch.onnx.export(
        torch_model,  # model being run
        (x, y),  # model input (or a tuple for multiple inputs)
        "network.onnx",  # where to save the model (can be a file or file-like object)
        export_params=True,  # store the trained parameter weights inside the model file
        opset_version=10,  # the ONNX version to export the model to
        do_constant_folding=True,  # whether to execute constant folding for optimization
        input_names=['input'],  # the model's input names (only the first input is named here)
        output_names=['output'],  # the model's output names
        dynamic_axes={'input': {0: 'batch_size'},  # variable length axes
                      'output': {0: 'batch_size'}},
    )
    # Flatten the inputs and outputs into plain lists so they can be written as JSON
    d = x.detach().numpy().reshape([-1]).tolist()
    dy = y.detach().numpy().reshape([-1]).tolist()
    data = dict(input_shapes=[shape, shape],
                input_data=[d, dy],
                output_data=[o.detach().numpy().reshape([-1]).tolist() for o in torch_out])
    # Serialize data into file:
    with open("input.json", 'w') as f:
        json.dump(data, f)


if __name__ == "__main__":
    main()
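
The script above writes each input as a flat list next to its declared shape, so a quick consistency check can catch a mismatch before the file is used. This is only a sketch: `check_input_json` is a hypothetical helper (not part of the gist or of any library), and it assumes the `input_shapes` / `input_data` layout produced above. The leading dimension of 1 added by `torch.rand(1, *shape)` does not change the element count, so each flat list should hold exactly `prod(shape)` values.

```python
import json
import math


def check_input_json(data):
    """Hypothetical helper: verify that every flattened input matches its declared shape."""
    shapes = data["input_shapes"]
    inputs = data["input_data"]
    if len(shapes) != len(inputs):
        raise ValueError(f"{len(shapes)} shapes but {len(inputs)} inputs")
    for i, (shape, flat) in enumerate(zip(shapes, inputs)):
        expected = math.prod(shape)  # number of elements implied by the shape
        if len(flat) != expected:
            raise ValueError(f"input {i}: expected {expected} values, got {len(flat)}")
    return True


# Stand-in payload with the same layout as input.json (zeros instead of random values).
example = {
    "input_shapes": [[3, 2, 2], [3, 2, 2]],
    "input_data": [[0.0] * 12, [0.0] * 12],
}
# Round-trip through JSON to mimic loading the file from disk.
check_input_json(json.loads(json.dumps(example)))
```

To check a real file, pass the result of `json.load(open("input.json"))` to the helper.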
@bshambaugh (author) commented:

x, y, d, and dy may need to change, because the input and the hidden layer might not have the same shape.

You might need separate shapes:

shape_input = [3, 2, 2]  # change to [batch_size, seq_len, input_size]
shape_hidden_layer = [3, 2, 2]  # change to [batch_size, seq_len, input_size]
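
A minimal sketch of how the `input.json` payload changes once `shape_input` and `shape_hidden_layer` differ. The sizes below are hypothetical, and the hidden-state layout `[num_layers, batch_size, hidden_size]` is an assumption based on how `torch.nn.RNN` shapes its initial hidden state; the actual model may expect something else. Plain Python lists stand in for the flattened tensors so the sketch runs without PyTorch.

```python
import json
import math
import random


def random_flat(shape, scale=0.1):
    """Return prod(shape) random floats in [0, scale), standing in for a flattened tensor."""
    return [scale * random.random() for _ in range(math.prod(shape))]


# Hypothetical sizes, chosen only for illustration.
batch_size, seq_len, input_size = 3, 2, 2
num_layers, hidden_size = 1, 5

shape_input = [batch_size, seq_len, input_size]
# Assumption: hidden state laid out as [num_layers, batch_size, hidden_size].
shape_hidden_layer = [num_layers, batch_size, hidden_size]

d = random_flat(shape_input)          # plays the role of d
dy = random_flat(shape_hidden_layer)  # plays the role of dy

# Each input now carries its own shape instead of sharing a single one.
data = dict(input_shapes=[shape_input, shape_hidden_layer],
            input_data=[d, dy])

payload = json.dumps(data)
```

In the export script, `x` and `y` would be built from `shape_input` and `shape_hidden_layer` respectively, and `input_shapes` would list both instead of repeating one shape.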
