@ProGamerGov
Last active February 28, 2020 20:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ProGamerGov/4fe325efe31b5a650c24a56151dd1952 to your computer and use it in GitHub Desktop.
Save ProGamerGov/4fe325efe31b5a650c24a56151dd1952 to your computer and use it in GitHub Desktop.
This script converts a PyTorch model generated by MMdnn into a state dict that can be loaded into an edited model class.
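import importlib.util
import sys

import torch

# Import the edited model class. Replace 'model_class_name' with the name of the
# class script, and replace 'ModelName' with the name of the class in that script.
from model_class_name import ModelName

model_name = ""  # Put the model file name here, without the .py/.pth extension

# Register the MMdnn-generated code as the module 'MainModel' so that torch.load
# can unpickle the full model object (replaces the deprecated imp.load_source).
spec = importlib.util.spec_from_file_location("MainModel", model_name + ".py")
MainModel = importlib.util.module_from_spec(spec)
sys.modules["MainModel"] = MainModel
spec.loader.exec_module(MainModel)

# weights_only=False is needed on PyTorch 2.6+ to load a whole pickled model;
# the argument exists from PyTorch 1.13 onward, so omit it on older versions.
the_model = torch.load(model_name + ".pth", weights_only=False)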
# Replace 'ModelName' with the name of the imported model class
cnn = ModelName()
# Print the loaded model to check that its layers look right
print(the_model)
# Extract the state dict (the parameter tensors keyed by layer name) and save it
the_model_sd = the_model.state_dict()
torch.save(the_model_sd, model_name + '_statedict.pth')
# Verify that the saved state dict loads into the edited model class
cnn.load_state_dict(torch.load(model_name + "_statedict.pth"))
cnn.eval()
# Optionally, rename the saved file to drop the '_statedict' suffix.
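The conversion works because a state dict stores only parameter tensors keyed by layer name, so it loads into any class whose layer names and shapes match. The sketch below shows that round trip under stated assumptions: `TinyNet` is a hypothetical stand-in for both the MMdnn model and the edited class, `round_trip` is a hypothetical helper, and an in-memory buffer replaces the `_statedict.pth` file. It uses `load_state_dict(strict=False)` so that all mismatched key names are reported together instead of failing on the first one, then compares outputs on a random input.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    # Hypothetical stand-in for the converted MMdnn model and the edited class
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 4, kernel_size=3)
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = x.mean(dim=(2, 3))  # global average pool to shape (N, 4)
        return self.fc(x)


def round_trip(source: nn.Module, target: nn.Module) -> nn.Module:
    """Hypothetical helper: save source's state dict in memory, load it into target."""
    buffer = io.BytesIO()
    torch.save(source.state_dict(), buffer)
    buffer.seek(0)
    state = torch.load(buffer)
    # strict=False returns the mismatched names instead of raising immediately
    result = target.load_state_dict(state, strict=False)
    if result.missing_keys or result.unexpected_keys:
        raise RuntimeError(
            f"missing: {result.missing_keys}, unexpected: {result.unexpected_keys}"
        )
    return target.eval()


torch.manual_seed(0)
original = TinyNet().eval()
restored = round_trip(original, TinyNet())

x = torch.randn(1, 3, 8, 8)
with torch.no_grad():
    same = torch.allclose(original(x), restored(x))
print(same)  # True
```

If the edited class renamed any layers, the printed missing and unexpected keys show which names in the state dict need to be mapped before loading.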