

Gist by @rongtuech (last active August 14, 2021 09:32)
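import argparse

import torch

from tgcn_model import GCN_muti_att  # model definition from the WLASL Pose-TGCN code

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Export a pretrained Pose-TGCN (GCN_muti_att) checkpoint to ONNX.')
    parser.add_argument('-w', '--weight', type=str, required=True,
                        help='path to the pretrained PyTorch state_dict')
    parser.add_argument('-o', '--output_name', type=str, required=True,
                        help='path of the .onnx file to write')
    args = parser.parse_args()

    # These hyperparameters must match the ones the checkpoint was trained with;
    # num_classes = 100 corresponds to the WLASL100 subset.
    num_classes = 100
    net = GCN_muti_att(input_feature=50 * 2, hidden_feature=64,
                       num_class=num_classes, p_dropout=0.3, num_stage=20)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    net.load_state_dict(torch.load(args.weight, map_location='cpu'))
    net.eval()  # disable dropout so the exported graph is the inference graph

    # Dummy input used to trace the graph: batch of 1, 55 pose keypoints
    # (graph nodes), each with 50 sampled frames x 2 coordinates (x, y).
    dummy_input = torch.randn(1, 55, 50 * 2)
    input_names = ['data']
    output_names = ['output']
    torch.onnx.export(net, dummy_input, args.output_name,
                      verbose=True,
                      input_names=input_names,
                      output_names=output_names)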
Comment from @rongtuech (author):

Copy this script into the Pose-TGCN folder of the WLASL repository (the folder containing tgcn_model.py, which the script imports) to export a pretrained PyTorch model to ONNX.
