Skip to content

Instantly share code, notes, and snippets.

@andcarnivorous
Created October 2, 2019 18:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save andcarnivorous/2701f6d9d27b0076153efecb9c1f51df to your computer and use it in GitHub Desktop.
Save andcarnivorous/2701f6d9d27b0076153efecb9c1f51df to your computer and use it in GitHub Desktop.
Streaming image classification with mobilenet_v2 in PyTorch
import torch
import cv2
import numpy as np
import json
from torchvision import transforms
from PIL import Image
# Streaming image classification: grab frames from a camera, classify each one
# with a pretrained MobileNetV2, print the predicted label, and show the frame.
# Press "q" in the preview window to quit.

# Use the GPU when one is available; otherwise fall back to the CPU so the
# script still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torch.hub.load('pytorch/vision', 'mobilenet_v2', pretrained=True).to(device)
model.eval()  # inference mode: fixes dropout / batch-norm behaviour

# Mapping from class index to label; it is indexed with str(class_index) below.
with open("labels.json") as labels_file:
    labels = json.load(labels_file)

# Resize + center-crop to 224x224, convert to a tensor, then normalize each
# channel with the mean/std values listed here.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

cap = cv2.VideoCapture(2)
try:
    # Inference only, so skip autograd bookkeeping for every frame.
    with torch.no_grad():
        while True:
            # Capture frame-by-frame.
            ret, frame = cap.read()
            if not ret:
                # No frame was delivered (camera missing, unplugged, or the
                # stream ended); stop instead of crashing on a None frame.
                print("Could not read a frame from the camera; stopping.")
                break

            # OpenCV returns frames in BGR channel order, while PIL and the
            # per-channel normalization above work in RGB order.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame2 = Image.fromarray(rgb_frame)
            frame2 = preprocess(frame2).to(device)

            # The model expects a batch dimension, hence unsqueeze(0).
            output = model(frame2.unsqueeze(0))
            best = output.argmax()
            print(best)
            output = int(best.item())
            print(labels[str(output)])

            # Show the untouched BGR frame; that is what imshow expects.
            cv2.imshow("frame", frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
finally:
    # Always free the camera and close the preview window, even on errors.
    cap.release()
    cv2.destroyAllWindows()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment