Last active
July 17, 2021 18:15
-
-
Save Jeanvit/401af6fa2235bbdf08405884d30cd24d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Export torchvision's pretrained DenseNet-121 (fetched via torch.hub) to ONNX,
# then run the exported model on one image with OpenCV's DNN module and print
# the predicted class index and its label.
import sys

import torch
import torch.onnx
import torchvision
import torchvision.models as models

# Usage: <script> onnxFileName sampleImagePath
#   onnxFileName     - where the exported ONNX model is written
#   sampleImagePath  - image used to sanity-check the exported model
if len(sys.argv) != 3:
    # Report bad usage on stderr and exit non-zero so shells/CI see a failure
    # (a bare sys.exit() would report success).
    print("Please provide 2 arguments: onnxFileName sampleImagePath", file=sys.stderr)
    sys.exit(1)
onnx_model_path = sys.argv[1]
sample_image = sys.argv[2]
# Load DenseNet-121 with pretrained weights through torch.hub.
# https://pytorch.org/hub/pytorch_vision_densenet/
model = torch.hub.load('pytorch/vision:v0.6.0', 'densenet121', pretrained=True)
# Put the model in inference mode before exporting it.
model.eval()
# torch.onnx.export runs one forward pass through the network to build the
# graph, so it needs a sample input in the shape the model expects:
# a batch of one 3-channel 224x224 image (N, C, H, W). Random values suffice.
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, onnx_model_path, verbose=True)
import cv2
import numpy as np

# Load the exported ONNX model with OpenCV's DNN module as a sanity check.
net = cv2.dnn.readNetFromONNX(onnx_model_path)

image = cv2.imread(sample_image)
if image is None:
    # cv2.imread returns None (rather than raising) when the file is missing
    # or unreadable; without this check blobFromImage fails obscurely.
    sys.exit("Could not read image: " + sample_image)

# torchvision's pretrained ImageNet models expect RGB input scaled to [0, 1]
# and then normalized per channel with these statistics
# (see https://pytorch.org/hub/pytorch_vision_densenet/).
IMAGENET_MEAN = (0.485, 0.456, 0.406)  # R, G, B
IMAGENET_STD = (0.229, 0.224, 0.225)  # R, G, B

# blobFromImage computes (pixel - mean) * scalefactor; with swapRB=True the
# mean is given in RGB order, in the 0-255 pixel range.
# NOTE: the reference torchvision pipeline resizes to 256 and center-crops
# 224; a plain resize to 224x224 is kept here as in the original script.
blob = cv2.dnn.blobFromImage(
    image,
    scalefactor=1.0 / 255,
    size=(224, 224),
    mean=tuple(255 * m for m in IMAGENET_MEAN),
    swapRB=True,
    crop=False,
)
# blobFromImage cannot divide by a per-channel std, so apply it to the
# NCHW blob directly (float32 / float32 keeps the dtype OpenCV expects).
blob = blob / np.array(IMAGENET_STD, dtype=np.float32).reshape(1, 3, 1, 1)

net.setInput(blob)
preds = net.forward()
# Take the scores for the first (only) image in the batch and pick the top class.
biggest_pred_index = int(np.asarray(preds)[0].argmax())
print("Predicted class:", biggest_pred_index)
import requests

# Map the predicted class index to a human-readable ImageNet label.
# NOTE(review): this third-party URL may disappear; consider bundling the labels.
LABELS_URL = 'https://s3.amazonaws.com/outcome-blog/imagenet/labels.json'
# Bound the wait: without a timeout requests can block indefinitely.
response = requests.get(LABELS_URL, timeout=30)
# Fail loudly on HTTP errors instead of trying to parse an error page as JSON.
response.raise_for_status()
# JSON object keys are always strings, so convert them back to int indices.
labels = {int(key): value for key, value in response.json().items()}
print("The class", biggest_pred_index, "corresponds to", labels[biggest_pred_index])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment