Skip to content

Instantly share code, notes, and snippets.

@masahi
Created December 22, 2020 11:20
Show Gist options
  • Save masahi/ea002c85e7d665d40eeb5c6422490e63 to your computer and use it in GitHub Desktop.
Save masahi/ea002c85e7d665d40eeb5c6422490e63 to your computer and use it in GitHub Desktop.
import numpy as np
import cv2
import torch
import torchvision
# Side length, in pixels, that the test image is resized to (square).
in_size = 300
# Matches the shape of the array get_input() returns: one 3-channel image of
# in_size x in_size, channels first. Not referenced by the code visible here;
# presumably kept for tracing/export elsewhere -- TODO confirm.
input_shape = (1, 3) + (in_size,) * 2
def download(url, path):
    """Fetch ``url`` and write its contents to the local file ``path``.

    Any existing file at ``path`` is overwritten. Errors raised by urllib are
    not caught and propagate to the caller. Returns None.
    """
    from urllib.request import urlretrieve

    urlretrieve(url, path)
def get_input(
    img_path="test_street_small.jpg",
    img_url=(
        "https://raw.githubusercontent.com/dmlc/web-data/"
        "master/gluoncv/detection/street_small.jpg"
    ),
    size=None,
):
    """Download the sample street image and return it as a model-ready array.

    The image is fetched from ``img_url`` into ``img_path`` (overwriting any
    existing file), loaded with OpenCV, resized to ``size`` x ``size``,
    converted from OpenCV's BGR channel order to RGB, scaled by 1/255 and laid
    out channels-first with a leading batch axis of 1.

    Args:
        img_path: Local file the image is downloaded to and read from.
        img_url: URL to download the image from.
        size: Target side length in pixels; defaults to the module-level
            ``in_size`` (read at call time).

    Returns:
        A float32 ``numpy.ndarray`` of shape ``(1, 3, size, size)``. Values
        lie in [0, 1] for the 8-bit images ``cv2.imread`` loads by default.

    Raises:
        OSError: if OpenCV cannot read or decode the downloaded file.
    """
    if size is None:
        size = in_size

    download(img_url, img_path)

    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with a clear message rather than on `.astype` below.
        raise OSError(f"cv2.imread could not read or decode {img_path!r}")

    img = img.astype("float32")
    # cv2.resize takes dsize as (width, height); square here, so order is moot.
    img = cv2.resize(img, (size, size))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # HWC -> CHW; dividing a float32 array by a Python float keeps float32.
    img = np.transpose(img / 255.0, [2, 0, 1])
    # Add the leading batch dimension: (3, H, W) -> (1, 3, H, W).
    return np.expand_dims(img, axis=0)
# Build torchvision's Mask R-CNN (ResNet-50 FPN) with pretrained=True and put
# it in eval mode; nn.Module.eval() returns the module itself, so it chains.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favor of
# `weights=`; confirm against the installed version before changing it.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True).eval()

# (1, 3, in_size, in_size) float32 array from get_input(), wrapped as a tensor
# that shares memory with the NumPy array.
img = get_input()
inp = torch.from_numpy(img)

# Inference only: disable autograd tracking for the forward pass.
# NOTE(review): torchvision documents detection models as taking a list of
# [C, H, W] tensors; a batched 4-D tensor is passed here -- confirm intended.
with torch.no_grad():
    out = model(inp)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment