@dpoulopoulos
Created January 13, 2022 08:22

from PIL import Image
import torch
from torchvision.prototype import models as pm

# The image path is relative to the root of a torchvision source checkout.
img = Image.open("test/assets/encode_jpeg/grace_hopper_517x606.jpg").convert("RGB")

# Step 1: Load a ResNet-50 with weights pre-trained on ImageNet-1K.
weights = pm.ResNet50_Weights.ImageNet1K_V1
model = pm.resnet50(weights=weights)
# Switch to evaluation mode (batch norm uses its running statistics).
model.eval()

# Step 2: Get the inference transforms bundled with these weights.
preprocess = weights.transforms()

# Step 3: Preprocess the image, add a batch dimension and run inference.
# Gradients are not needed for prediction, so autograd is disabled.
batch = preprocess(img).unsqueeze(0)
with torch.no_grad():
    prediction = model(batch).squeeze(0).softmax(0)

# Step 4: Print the top-1 category and its probability.
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")
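
Step 2 delegates preprocessing to `weights.transforms()`. For these weights it is assumed to be the standard ImageNet evaluation preset: resize the shorter side to 256 pixels, center-crop a 224x224 window, scale pixels to [0, 1] and normalize each channel with the usual ImageNet mean and standard deviation. The pure-Python sketch below reproduces only that arithmetic; `resized_size`, `center_crop_box` and `normalize_pixel` are hypothetical helpers, not torchvision functions, and the example assumes the file name's 517x606 means width x height.

```python
# Hypothetical helpers: a pure-Python sketch of the geometry and normalization
# that the ResNet-50 ImageNet eval preset is assumed to apply.

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def resized_size(width, height, short_side=256):
    """Scale so the shorter side equals `short_side`, keeping the aspect ratio."""
    if width <= height:
        return short_side, int(short_side * height / width)
    return int(short_side * width / height), short_side


def center_crop_box(width, height, crop=224):
    """Return (left, top, right, bottom) of a centered crop x crop window."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


def normalize_pixel(rgb_uint8):
    """Map 0-255 RGB values to [0, 1], then standardize each channel."""
    return tuple(
        (value / 255.0 - mean) / std
        for value, mean, std in zip(rgb_uint8, IMAGENET_MEAN, IMAGENET_STD)
    )


# Assumption: 517x606 is width x height.
w, h = resized_size(517, 606)
print((w, h))                 # (256, 300)
print(center_crop_box(w, h))  # (16, 38, 240, 262)
```

The crop is taken after the resize, so the network always receives a 224x224 input regardless of the original aspect ratio.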
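
Steps 3 and 4 turn the model's 1000 output logits into probabilities with a softmax and report the most likely class. A minimal pure-Python sketch of that post-processing follows; the logits and category names are made up for illustration and are not real model output, and `softmax` and `top1` are hypothetical helpers standing in for the tensor operations in the script.

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top1(logits, categories):
    """Return (category_name, probability) for the highest-probability class."""
    probs = softmax(logits)
    class_id = max(range(len(probs)), key=probs.__getitem__)
    return categories[class_id], probs[class_id]


# Made-up logits for three placeholder categories.
categories = ["cat", "dog", "bird"]
logits = [2.0, 1.0, 0.1]
name, score = top1(logits, categories)
print(f"{name}: {100 * score:.1f}%")  # cat: 65.9%
```

Because softmax is monotonic, the argmax of the logits already identifies the same class; the softmax is only needed to report a probability.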