Skip to content

Instantly share code, notes, and snippets.

@giangnguyen2412
Created September 25, 2020 04:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save giangnguyen2412/ede6391ccb7b0328ca6aa14e03a0a479 to your computer and use it in GitHub Desktop.
Save giangnguyen2412/ede6391ccb7b0328ca6aa14e03a0a479 to your computer and use it in GitHub Desktop.
Convert a pretrained ResNet50 model from PyTorch to TensorFlow using ONNX
## Run this in the Python CLI
import torch
import torchvision
torch.set_num_threads(1)
from torchvision.models import *
from visualisation.core.utils import device, image_net_postprocessing
from torch import nn
from operator import itemgetter
from visualisation.core.utils import imshow
import glob
import matplotlib.pyplot as plt
import numpy as np
import torch
from utils import *
from PIL import Image
# glob.glob() returns paths in file-system-dependent order; sort them so that
# indexing into the list (e.g. test_image_paths[8]) picks the same file on
# every run and on every machine.
test_image_paths = sorted(glob.glob('/home/dexter/Downloads/Exp1/Exp1-1 100/*.*'))
from torchvision.transforms import ToTensor, Resize, Compose, ToPILImage
from visualisation.core import *
from visualisation.core.utils import image_net_preprocessing
size = 224
# Pre-process the image and convert it into a tensor: resize, center-crop to
# size x size, convert to a float tensor and normalise each channel with the
# mean/std that torchvision documents for its pretrained ImageNet models.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size),
    torchvision.transforms.CenterCrop(size),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
])
trained_model = resnet50(pretrained=True)
# Switch to inference mode explicitly instead of relying on the exporter's
# default, so the exported graph is the evaluation-time network.
trained_model.eval()
# The image only serves as an example input for the export. Force three RGB
# channels so grayscale, palette or RGBA files do not break the 3-channel
# Normalize step above.
img = Image.open(test_image_paths[8]).convert('RGB')
# Add the leading batch dimension expected by the model: (C, H, W) -> (1, C, H, W).
x = transform(img).unsqueeze(0)
torch.onnx.export(trained_model, x, 'resnet50.onnx')
## Disable CLI and run
import onnx
from onnx_tf.backend import prepare
# Read the ONNX graph from 'resnet50.onnx' and hand it to onnx-tf's prepare()
# to obtain the TensorFlow representation (`tf_rep`) used further down.
# NOTE: `model` is rebound to the PyTorch network later in this script; only
# `tf_rep` is used after that point.
model = onnx.load('resnet50.onnx')
tf_rep = prepare(model)
import torch
import torchvision
torch.set_num_threads(1)
from torchvision.models import *
from visualisation.core.utils import device, image_net_postprocessing
from torch import nn
from operator import itemgetter
from visualisation.core.utils import imshow
import glob
import matplotlib.pyplot as plt
import numpy as np
import torch
from utils import *
from PIL import Image
# Your directory of test images goes here. We expect the original model and
# the converted model to give the same label on the same images.
# glob.glob() returns paths in file-system-dependent order; sort them so that
# indexing into the list (e.g. test_image_paths[10]) is reproducible.
test_image_paths = sorted(glob.glob('/home/dexter/Downloads/Exp1/Exp1-1 100/*.*'))
from torchvision.transforms import ToTensor, Resize, Compose, ToPILImage
from visualisation.core import *
from visualisation.core.utils import image_net_preprocessing
size = 224
# Pre-process the image and convert it into a tensor: resize, center-crop to
# size x size, convert to a float tensor and normalise each channel with the
# mean/std that torchvision documents for its pretrained ImageNet models.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size),
    torchvision.transforms.CenterCrop(size),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
])
# Reference PyTorch model, in inference mode.
model = resnet50(pretrained=True).to(device)
model.eval()
# Your test image. Force three RGB channels so grayscale, palette or RGBA
# files do not break the 3-channel Normalize step above.
img = Image.open(test_image_paths[10]).convert('RGB')
x = transform(img).unsqueeze(0).to(device)
# Pure inference: no gradients are needed, so skip autograd bookkeeping.
with torch.no_grad():
    out = model(x)
# Turn logits into class probabilities and keep the single best class.
p = torch.nn.functional.softmax(out, dim=1)
score, index = torch.topk(p, 1)
input_category_id = index[0][0].item()  # 716 in the original author's run
predicted_confidence = score[0][0].item()
## Now run the converted model
# Pass the input as a NumPy array rather than a torch tensor; x is moved to
# the CPU first because .numpy() is only available for CPU tensors.
output = tf_rep.run(x.cpu().numpy())
# Keep the converted model's label and print both, so the comparison is
# visible when this runs as a script and not only in an interactive session.
converted_category_id = int(np.argmax(output))  # 716 in the original author's run
print('PyTorch label:', input_category_id,
      '| converted model label:', converted_category_id)
tf_rep.export_graph('resnet50.pb')  # Save the model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment