Skip to content

Instantly share code, notes, and snippets.

Last active January 23, 2022 03:45
Show Gist options
  • Save mihow/c22adf52fcb9c07ce67dcf4d2495cedd to your computer and use it in GitHub Desktop.
Save mihow/c22adf52fcb9c07ce67dcf4d2495cedd to your computer and use it in GitHub Desktop.
PyTorch model to JIT / TorchScript
import os
import time
import pathlib
import tarfile
import tempfile
import torch
# Base directory for project artifacts. Taken from the PROJECT_PATH
# environment variable when set, otherwise the current working directory.
PROJECT_PATH = pathlib.Path(os.getenv("PROJECT_PATH", "."))
def export_model(model, classes, export_dir=None):
    """
    Export model to pickled jit version.

    The model is switched to eval mode, traced with a dummy 1x3x256x256
    input, and written together with a text file of class names into a
    gzipped tarball named ``model-<unix timestamp>.tar.gz``. Inside the
    archive the files are called ``model.pkl`` and ``classes.txt``.

    @TODO add file to tarball with metadata about the model.
    Version shape of image, who trained, friendly name, etc.
    Perhaps include a link to the training page on W&B or CometML and our model zoo.

    Args:
        model: The model to export. ``model.eval()`` is called on it, so it
            is left in eval mode afterwards.
        classes: A 1-dimensional list of class names, ordered by index.
        export_dir: Directory (``str`` or ``pathlib.Path``) that receives the
            tarball. Defaults to ``PROJECT_PATH / "models"``; created if
            missing.

    Returns:
        ``pathlib.Path`` of the tarball that was written.

    Example (not executed as a doctest; requires a trained model)::

        data = get_data(256, 256, 0.9999)
        export_model(model, classes)
    """
    timestamp = str(int(time.time()))
    # Accept plain strings as well as Path objects for the target directory.
    export_dir = pathlib.Path(export_dir or PROJECT_PATH / "models")
    export_dir.mkdir(exist_ok=True, parents=True)
    tarball_path = export_dir / f"model-{timestamp}.tar.gz"

    # "eval()" is needed to predict a single image? batch size of 1
    model_raw = model.eval()
    # This is just a 1-dimensional list of class names, ordered by index
    model_classes = classes

    # @TODO not sure if image size should be fixed or always use the current
    # model's image size that we are exporting
    # I've generally been using 3x256x256
    channels, height, width = 3, 256, 256
    print("Example item shape:", channels, height, width)
    example_input = torch.ones(1, channels, height, width)

    # Trace on whatever device the model's weights live on; unconditionally
    # moving the input to CUDA breaks tracing for a CPU model on a GPU host.
    first_param = next(model_raw.parameters(), None)
    if first_param is not None:
        example_input = example_input.to(first_param.device)
    elif torch.cuda.is_available():
        example_input = example_input.cuda()

    # Create jit model
    print("Exporting model")
    model_jit = torch.jit.trace(model_raw, example_input)

    # Stage both files in a temporary directory that is removed on exit, so
    # no intermediate files are left behind once the tarball is written.
    with tempfile.TemporaryDirectory() as staging:
        staging_dir = pathlib.Path(staging)
        model_path = staging_dir / "model.pkl"
        classes_path = staging_dir / "classes.txt"

        # Save model
        torch.jit.save(model_jit, str(model_path))

        # Save list of classes, ordered by index!
        with open(classes_path, "w", encoding="utf-8") as f:
            f.writelines(f"{c}\n" for c in model_classes)

        # Create tar archive with the exported model and classes text file
        with tarfile.open(tarball_path, "w:gz") as tar:
            tar.add(model_path, arcname="model.pkl")  # Don't save directories
            tar.add(classes_path, arcname="classes.txt")

    print("Model and classes saved to", tarball_path)
    return tarball_path
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment