Skip to content

Instantly share code, notes, and snippets.

@piercelamb
Created December 19, 2022 23:03
Show Gist options
  • Save piercelamb/57e3d2b40ba97672c4c43246f2d20238 to your computer and use it in GitHub Desktop.
Save piercelamb/57e3d2b40ba97672c4c43246f2d20238 to your computer and use it in GitHub Desktop.
convert_to_torchscript
def convert_to_torchscript(model_path, best_model, train_data, hyperparams):
    """Trace ``best_model`` on CPU with TorchScript and save it to ``model_path``.

    A single example (``train_data[0]``) is turned into a batch of size one
    and passed to ``torch.jit.trace`` as positional example inputs, ordered
    according to ``ordered_model_input_keys()``.

    Args:
        model_path: Destination passed to ``torch.jit.save``.
        best_model: Module to trace. NOTE: ``nn.Module.cpu()`` and
            ``nn.Module.eval()`` act on the module in place, so the caller's
            model is left on CPU in eval mode afterwards.
        train_data: Indexable dataset; ``train_data[0]`` is used as a mapping
            from input name to tensor (presumably unbatched, since a leading
            batch dimension is added here -- confirm against the dataset).
        hyperparams: Mapping; ``hyperparams['model_name']`` is read only when
            ``ordered_model_input_keys()`` does not return an ``OrderedDict``.

    Returns:
        None. The traced module is written to ``model_path``.
    """
    cpu_model = best_model.cpu()
    cpu_model.eval()
    sample_instance = train_data[0]

    ordered_input_keys = ordered_model_input_keys()
    if not isinstance(ordered_input_keys, OrderedDict):
        # Presumably a per-model mapping of OrderedDicts: select this model's
        # key order and the matching part of the sample.
        # NOTE(review): SOURCE lost its indentation; this assumes the
        # get_model_specific_batch call belongs inside this branch -- confirm.
        model_name = hyperparams['model_name']
        ordered_input_keys = ordered_input_keys[model_name]
        sample_instance = get_model_specific_batch(sample_instance, model_name)

    # torch.jit.trace unpacks example_inputs as positional arguments, so their
    # order must match the model's forward() signature. The OrderedDict values
    # give the input names in that order; unsqueeze(0) adds a batch dim of 1.
    example_inputs = tuple(
        sample_instance[key].unsqueeze(0).cpu()
        for key in ordered_input_keys.values()
    )

    traced_cpu = torch.jit.trace(
        func=cpu_model,
        example_inputs=example_inputs,
        strict=False,  # allows dicts to be used as outputs
        check_trace=False,  # when traced model is checked, an error is produced due to name mangling
    )
    torch.jit.save(traced_cpu, model_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment