Skip to content

Instantly share code, notes, and snippets.

@mihow
Last active Nov 5, 2020
Embed
What would you like to do?
Export PyTorch model to TorchScript / JIT format from fastai learner
import os
import time
import pathlib
import tarfile
import tempfile
import torch
# Root directory for project artifacts. Taken from the PROJECT_PATH
# environment variable when set; otherwise the relative path "." is used.
PROJECT_PATH = pathlib.Path(os.getenv("PROJECT_PATH", "."))
def _model_device(module):
    """Return the device of *module*'s first parameter or buffer, or CPU if it has none."""
    for tensor in module.parameters():
        return tensor.device
    for tensor in module.buffers():
        return tensor.device
    return torch.device("cpu")


def export_model(model, classes, export_dir=None, input_shape=(3, 256, 256)):
    """
    Export a model to a TorchScript (JIT) tarball together with its class list.

    The model is traced with an all-ones input of shape ``(1, *input_shape)``
    and saved with ``torch.jit.save``. The archive
    ``<export_dir>/model-<unix timestamp>.tar.gz`` contains:

    - ``model.pkl``: the traced TorchScript model
    - ``classes.txt``: one class name per line, ordered by class index

    Side effects: the model is converted to float32 in place
    (``model.float()``). Its train/eval mode is restored after export.

    @TODO add file to tarball with metadata about the model.
    Version shape of image, who trained, friendly name, etc.
    Perhaps include a link to the training page on W&B or CometML and our model zoo.

    :param model: ``torch.nn.Module`` to export.
    :param classes: iterable of class names, ordered by output index.
    :param export_dir: directory (str or path-like) for the tarball. Falsy
        values fall back to ``PROJECT_PATH / "models"``. Created if missing.
    :param input_shape: per-sample shape of the tracing input, in
        ``(channels, height, width)`` order. Defaults to ``(3, 256, 256)``.
    :returns: ``pathlib.Path`` of the written tarball.

    >>> export_model(model, classes)  # doctest: +SKIP
    """
    timestamp = str(int(time.time()))
    # Accept str as well as Path; keep the original "falsy -> default" rule.
    export_dir = pathlib.Path(export_dir or PROJECT_PATH / "models")
    export_dir.mkdir(exist_ok=True, parents=True)
    tarball_path = export_dir / f"model-{timestamp}.tar.gz"

    was_training = model.training
    # "eval()" is needed to predict a single image (batch size of 1).
    # https://discuss.pytorch.org/t/error-expected-more-than-1-value-per-channel-when-training/26274
    model_raw = model.float().eval()

    # @TODO not sure if image size should be fixed or always use the current
    # model's image size that we are exporting,
    # e.g. learner.data.single_ds[0][0].shape
    print("Example item shape:", *input_shape)
    # Build the example input on the model's own device. Otherwise tracing
    # fails when a CPU model is exported on a CUDA machine (or vice versa).
    example_input = torch.ones(1, *input_shape, device=_model_device(model_raw))

    try:
        # Intermediate files go in a directory that is removed afterwards, so
        # repeated exports do not leave orphaned temp files (or open handles).
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = os.path.join(tmp_dir, "model.pkl")
            classes_path = os.path.join(tmp_dir, "classes.txt")

            print("Exporting model")
            model_jit = torch.jit.trace(model_raw, example_input)
            torch.jit.save(model_jit, model_path)

            # Save list of classes, ordered by index!
            with open(classes_path, "w", encoding="utf-8") as f:
                f.writelines(f"{c}\n" for c in classes)

            # arcname stores the files at the archive root (no temp directories).
            with tarfile.open(tarball_path, "w:gz") as tar:
                tar.add(model_path, arcname="model.pkl")
                tar.add(classes_path, arcname="classes.txt")
    finally:
        # Hand the model back in the train/eval mode it arrived in.
        model.train(was_training)

    print("Model and classes saved to", tarball_path)
    return tarball_path
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment