Skip to content

Instantly share code, notes, and snippets.

@Jeanvit
Last active July 17, 2021 18:15
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Jeanvit/401af6fa2235bbdf08405884d30cd24d to your computer and use it in GitHub Desktop.
Save Jeanvit/401af6fa2235bbdf08405884d30cd24d to your computer and use it in GitHub Desktop.
import torch
import torch.onnx
import torchvision
import torchvision.models as models
import sys
# Command line: <onnxFileName> <sampleImagePath>
#   onnxFileName    - where the exported ONNX model is written
#   sampleImagePath - image that is classified with the exported model
if len(sys.argv) != 3:
    # sys.exit(str) prints the message to stderr and exits with status 1, so
    # shells/callers can detect the usage error (a bare sys.exit() reports
    # success).
    sys.exit("Please provide 2 arguments: onnxFileName sampleImagePath")
onnx_model_path = sys.argv[1]
sample_image = sys.argv[2]
# Pretrained DenseNet-121 fetched through torch.hub; model card:
# https://pytorch.org/hub/pytorch_vision_densenet/
# eval() puts the network in inference mode and returns the module itself,
# so it can be chained directly onto the load call.
model = torch.hub.load('pytorch/vision:v0.6.0', 'densenet121', pretrained=True).eval()

# torch.onnx.export traces the network by running one forward pass on an
# example input, so it needs a tensor of the expected shape:
# (batch=1, channels=3, height=224, width=224). The values are random; only
# the shape matters for the trace.
dummy_input = torch.randn((1, 3, 224, 224))
torch.onnx.export(
    model,
    dummy_input,
    onnx_model_path,
    verbose=True,
)
import cv2
import numpy as np

# Load the exported model into OpenCV's DNN module and classify one image.
net = cv2.dnn.readNetFromONNX(onnx_model_path)

image = cv2.imread(sample_image)
if image is None:
    # cv2.imread does not raise on a missing/unreadable file; it returns None.
    sys.exit(f"Could not read image: {sample_image}")

# Resize to 224x224, scale pixels to [0, 1] and convert OpenCV's BGR order to
# RGB; the result is an NCHW float32 blob of shape (1, 3, 224, 224).
blob = cv2.dnn.blobFromImage(image, 1.0 / 255, (224, 224), (0, 0, 0), swapRB=True, crop=False)

# torchvision ImageNet models expect per-channel mean/std normalization of the
# RGB [0, 1] input. blobFromImage can only subtract a mean before applying a
# single scalar scale, so apply the normalization afterwards on the NCHW blob.
imagenet_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 3, 1, 1)
imagenet_std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 3, 1, 1)
blob = (blob - imagenet_mean) / imagenet_std

net.setInput(blob)
preds = net.forward()
# preds[0] holds one score per class; the argmax is the predicted class index.
biggest_pred_index = int(np.array(preds)[0].argmax())
print ("Predicted class:", biggest_pred_index)
import requests

# Map the predicted class index to a human-readable ImageNet label.
LABELS_URL = 'https://s3.amazonaws.com/outcome-blog/imagenet/labels.json'
# Without a timeout, requests can block indefinitely on an unresponsive server.
response = requests.get(LABELS_URL, timeout=30)
# Fail loudly on HTTP errors instead of trying to parse an error page as JSON.
response.raise_for_status()
# The JSON keys are strings; convert them to ints so the predicted index can be
# used directly as a key.
labels = {int(key): value for key, value in response.json().items()}
print("The class", biggest_pred_index, "corresponds to", labels[biggest_pred_index])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment