Skip to content

Instantly share code, notes, and snippets.

@dnutiu
Created July 1, 2023 07:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dnutiu/a6572f61a6174f3710ec921b624d9def to your computer and use it in GitHub Desktop.
Save dnutiu/a6572f61a6174f3710ec921b624d9def to your computer and use it in GitHub Desktop.
PyTorch model prediction
import torch
import torchvision.transforms as transforms
from PIL import Image

# Single-image inference script: load a saved model, run it on one image and
# print the categories whose score exceeds a cut-off.

# Paths and limits used below (same values the script previously used inline).
model_path = "models/resnet18_5_epochs.pth"
image_path = "Downloads/4.jpg"
categories_path = "models/categories.txt"
max_predictions = 50    # look at no more than this many of the highest scores
# NOTE(review): scores compared against this cut-off are the model's raw
# outputs (no softmax/sigmoid is applied in this script), so this is not a
# probability threshold. Confirm the intended value.
score_threshold = -0.5

# SECURITY: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints that come from a trusted source.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects a fully pickled model object; pass weights_only=False there if this
# checkpoint is trusted.
model = torch.load(model_path, map_location=torch.device('cpu'))
model.eval()  # inference mode for layers such as dropout / batch norm

# Transform a PIL image into a normalized tensor.
# The mean/std values are the usual ImageNet channel statistics; this assumes
# the checkpoint was trained with the same preprocessing -- TODO confirm.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Close the file handle once the pixels are loaded, and force three channels
# so grayscale or RGBA images do not break the 3-channel Normalize above.
with Image.open(image_path) as source_image:
    img = source_image.convert("RGB")

x = transform(img)
x = x.unsqueeze(0)  # add batch dimension

# No gradients are needed for prediction; skip building the autograd graph.
with torch.no_grad():
    output = model(x)

# Load category names, one per line; the line index is used as the class index.
with open(categories_path, "r") as f:
    categories = [s.strip() for s in f]

# Sort scores from highest to lowest, keeping the matching class indices.
data, indices = torch.sort(output, descending=True)

# Slicing (rather than indexing a fixed range) stays in bounds when the model
# has fewer than max_predictions outputs.
top_scores = data[0].tolist()[:max_predictions]
top_class_ids = indices[0].tolist()[:max_predictions]

# Print every candidate whose score is above the cut-off.
for score, class_id in zip(top_scores, top_class_ids):
    if score > score_threshold:
        print(categories[class_id], score)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment