**Update: have a look at torch.package instead** (https://pytorch.org/docs/1.9.0/package.html) -- Original description: Save and load a PyTorch model (both code and weights): minimum working example, based on the inner workings of TorchServe.
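For comparison, a minimal torch.package sketch (requires PyTorch ≥ 1.9; the archive name my_model.pt, the package and resource names, and the interned module name "model" are placeholders, not taken from this gist):

    from torch.package import PackageExporter, PackageImporter

    # Export: bundle the source of the model's module together with the pickled model instance
    with PackageExporter("my_model.pt") as exporter:
        exporter.intern("model")  # package the code of the module that defines the model class
        exporter.save_pickle("my_package", "model.pkl", model)

    # Import: re-create the model from the bundled code and pickle, independent of the local source tree
    importer = PackageImporter("my_model.pt")
    model = importer.load_pickle("my_package", "model.pkl")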
import importlib.util
import inspect
import json
from pathlib import Path
from typing import Optional, Union
import zipfile
import torch
from torch import nn


def _name_from(path: Optional[Union[str, Path]]) -> Optional[str]:
    return None if path is None else Path(path).name


def _pymodule_from(path: Path) -> object:
    """Load and return the python module from the given ``*.py`` file path."""
    # Following https://stackoverflow.com/questions/67631/ (20201008)
    spec = importlib.util.spec_from_file_location(name=path.stem, location=path)
    pymodule = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(pymodule)
    return pymodule


def _model_class_from(pymodule: object, name: Optional[str]) -> object:
    """
    Load and return the model class (``torch.nn.Module`` subclass) with the given name from the given module (name is
    not necessary if there is only one ``torch.nn.Module`` subclass in the given module).
    """
    # Following TorchServe: ``ts.utils.util.list_classes_from_module()`` (20201008)
    predicate = lambda m: (inspect.isclass(m) and m.__module__ == pymodule.__name__ and issubclass(m, nn.Module) and
                           (name is None or m.__name__ == name))
    classes = [c[1] for c in inspect.getmembers(pymodule, predicate)]
    if not classes:
        raise ValueError("No ``torch.nn.Module`` subclass " + ("" if name is None else f"named '{name}' ") +
                         f"found in given module (module '{pymodule.__name__}' in '{pymodule.__file__}')")
    elif len(classes) > 1:
        raise ValueError(f"Multiple subclasses of ``torch.nn.Module`` found in given module "
                         f"(module '{pymodule.__name__}' in '{pymodule.__file__}'; "
                         f"candidates are {', '.join(c.__name__ for c in classes)})")
    return classes[0]


def save(*,
         name: str,
         version: str,
         code_path: Path,
         archive_path: Path,
         class_name: Optional[str] = None,
         weights_path: Optional[Path] = None
         ):
    """
    Save the model (both code and, if provided, weights) to a zip archive (similar to but not the same as TorchServe's
    ``*.mar`` archive).

    :param name: name of the model
    :param version: version of the model
    :param code_path: ``*.py`` file that contains the model class as a ``torch.nn.Module`` subclass
    :param archive_path: path to which to write the model's zip archive
    :param class_name: optional name of the model class (necessary if there are multiple ``torch.nn.Module``
        subclasses in the code file)
    :param weights_path: optional model weights, need to be loadable via
        ``torch.nn.Module.load_state_dict(torch.load(...))``
    """
    manifest_data = {  # "name" and "version" are only informative in this example
        "name": name,
        "version": version,
        "code_file": _name_from(code_path),
        "weights_file": _name_from(weights_path),
        "class_name": class_name
    }
    with zipfile.ZipFile(archive_path, mode="w") as archive:
        archive.writestr("META-INF/MANIFEST.json", data=json.dumps(manifest_data, indent=2))
        archive.write(code_path, arcname=_name_from(code_path))
        if weights_path:
            archive.write(weights_path, arcname=_name_from(weights_path))


def load(*, archive_path: Path, extract_dir: Optional[Path] = None) -> nn.Module:
    """
    Load a model from a zip archive created with ``save()``, instantiate it and, if provided in the archive, load the
    model's weights.

    :param archive_path: path of the model archive (zip file) to be loaded
    :param extract_dir: optional directory into which to extract the model archive (use the archive's parent directory
        if not given; in any case, create a new subdirectory for the archive content)
    :return: instance of the loaded model
    """
    pymodule_dir = (archive_path.parent if extract_dir is None else extract_dir) / archive_path.stem
    with zipfile.ZipFile(archive_path, mode="r") as archive:
        archive.extractall(pymodule_dir)
    manifest_data = json.loads((pymodule_dir / "META-INF" / "MANIFEST.json").read_text(encoding="utf-8"))
    pymodule = _pymodule_from(pymodule_dir / manifest_data["code_file"])
    model_class = _model_class_from(pymodule, manifest_data["class_name"])
    model = model_class()
    weights_file = manifest_data["weights_file"]
    if weights_file:
        model.load_state_dict(torch.load(pymodule_dir / weights_file))
    model.eval()
    return model


if __name__ == "__main__":
    from hashlib import sha256
    from textwrap import dedent
    import urllib.request  # ``urllib.request`` must be imported explicitly; ``import urllib`` alone is not enough
    import matplotlib.pyplot as plt
    import numpy as np
    from PIL import Image
    from scipy.misc import face
    from torch.nn.functional import softmax
    from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor

    base_dir = Path(__file__).parent
    name, version = "mymodel", "0.0.1"
    code = base_dir / "model.py"
    weights = base_dir / "densenet161-8d451a50.pth"
    archive = base_dir / "myarchive.zip"
    index_to_name = base_dir / "index_to_name.json"  # Mapping from class index to human-readable class label

    # Create the model code and save it to disk on the fly (in an actual project we would not do that) -- code from
    # TorchServe: examples\image_classifier\densenet_161\model.py (20201009)
    code.write_text(dedent(
        """
        import re
        from torchvision.models.densenet import DenseNet


        class ImageClassifier(DenseNet):

            def __init__(self):
                super().__init__(48, (6, 12, 36, 24), 96)

            def load_state_dict(self, state_dict, strict=True):
                pattern = re.compile(r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$")
                for key in list(state_dict.keys()):
                    res = pattern.match(key)
                    if res:
                        new_key = res.group(1) + res.group(2)
                        state_dict[new_key] = state_dict[key]
                        del state_dict[key]
                return super().load_state_dict(state_dict, strict)
        """
    ))

    # Get the trained weights, if necessary
    if not weights.exists():
        print("Download weights ...")
        url = r"https://download.pytorch.org/models/" + weights.name
        urllib.request.urlretrieve(url, filename=weights)
    assert sha256(weights.read_bytes()).hexdigest()[:10] == "8d451a50ba"

    # Get the index to label mapping, if necessary
    if not index_to_name.exists():
        print("Download mapping of class indices to class labels ...")
        url = r"https://raw.githubusercontent.com/pytorch/serve/master/examples/image_classifier/index_to_name.json"
        urllib.request.urlretrieve(url, filename=index_to_name)
    assert sha256(index_to_name.read_bytes()).hexdigest()[:10] == "a1e7a966a1"
    name_for_class_index = json.loads(index_to_name.read_text(encoding="utf-8"))
print(f"Archive {name} v{version} to {archive.name} ...")
save(name=name, version=version, code_path=code, archive_path=archive, weights_path=weights)
print(f"Reload {name} v{version} from {archive.name} ...")
model = load(archive_path=archive)
print("\nClassify demo image ``scipy.misc.face()`` ...")
input_image = Image.fromarray(face()) # Actually a racoon, but racoon is not among the trained classes
prep = Compose([Resize(256), CenterCrop(224), ToTensor(), Normalize(mean=[.49, .46, .41], std=[.23, .22, .23])])
labels = model(prep(input_image).unsqueeze_(0))
labels = softmax(labels.squeeze(0), dim=0).cpu().detach().numpy()
class_indices = np.argsort(labels)[::-1] # Class indices in order of their probability
names_and_probs = [(name_for_class_index[f"{idx}"][1], 100 * labels[idx]) for idx in class_indices]
print("\n".join(f"{name} ({prob:.2f}%)" for name, prob in names_and_probs[:5]))
plt.imshow(input_image)
plt.title(f"{names_and_probs[0][0]} ({names_and_probs[0][1]:.2f}%)")
plt.show()

spezold commented Oct 9, 2020
What actually happens:

  • Code of an image classification model is provided in a source code file called model.py (created on the fly for simplicity)
  • Trained weights for the model are provided in a file called densenet161-8d451a50.pth (downloaded from pytorch.org, if necessary)
  • To get a textual representation of the class labels, a JSON file with index-to-name mappings is provided as index_to_name.json (downloaded from the TorchServe repository, if necessary)
  • Code, weights, and meta information are archived to a zip file called myarchive.zip (the resulting archive layout and manifest are sketched after this list)
  • The zip file is extracted again
  • The extracted code is loaded as a Python module
  • The classification model is instantiated from the loaded module
  • The model instance's weights are initialized with the extracted weights from the archive
  • The model instance is used to classify a demo image (scipy.misc.face()), and the five most probable classes are shown. Unfortunately the raccoon in the demo image is misclassified as "badger" – but this is because raccoon is not among the trained classes.
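
For illustration, the archive created in this example has the following layout:

    myarchive.zip
    ├── META-INF/MANIFEST.json
    ├── model.py
    └── densenet161-8d451a50.pth

and META-INF/MANIFEST.json, as written by save() from its manifest_data dict, contains ("class_name" is null because no class name is passed here):

    {
      "name": "mymodel",
      "version": "0.0.1",
      "code_file": "model.py",
      "weights_file": "densenet161-8d451a50.pth",
      "class_name": null
    }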

The essential functions here are save() (to archive the model) and load() (to load and instantiate the model).
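
A minimal usage sketch of these two functions (all file names below, including tiny_net.py with its single nn.Module subclass and the state dict tiny_net.pth, are hypothetical placeholders):

    from pathlib import Path

    # Archive code + weights ...
    save(name="tinynet", version="0.1.0",
         code_path=Path("tiny_net.py"),      # must contain the model as a torch.nn.Module subclass
         archive_path=Path("tiny_net.zip"),
         weights_path=Path("tiny_net.pth"))  # optional: a state dict saved via torch.save()

    # ... and later restore a ready-to-use model instance (extracted into ./tiny_net/ next to the archive)
    model = load(archive_path=Path("tiny_net.zip"))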
