Skip to content

Instantly share code, notes, and snippets.

@jackdh
Created December 18, 2018 11:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jackdh/3e23172c41f9cea64e73f021b7c1827f to your computer and use it in GitHub Desktop.
Save jackdh/3e23172c41f9cea64e73f021b7c1827f to your computer and use it in GitHub Desktop.
# Directory whose files are segmented by the script below.
img_dir = './images'

# Standard ImageNet per-channel mean / std (RGB order), applied after
# ToTensor() has converted the HWC uint8 image into a CHW float tensor.
_imagenet_mean = [.485, .456, .406]
_imagenet_std = [.229, .224, .225]
transform_fn = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_imagenet_mean, _imagenet_std),
])
def get_image(img_path, ctx):
    """Read ``img_path`` and return it as a normalized single-image batch.

    The file is decoded with ``mx.image.imread``, passed through the
    module-level ``transform_fn``, given a leading batch axis of size 1 and
    copied to ``ctx``.

    Args:
        img_path: Path of the image file to load.
        ctx: MXNet context the returned NDArray is placed on.

    Returns:
        An NDArray with a new leading axis of length 1, located on ``ctx``.
    """
    batch = transform_fn(mx.image.imread(img_path)).expand_dims(0)
    return batch.as_in_context(ctx)
def build_data_iter(root_dir, batch_size, ctx, group=2):
    """Load images from ``root_dir`` into up to ``group`` stacked batches.

    Regular files in ``root_dir`` are taken in sorted filename order and split
    into consecutive slices of ``batch_size``. Each slice is loaded with
    ``get_image`` and concatenated along axis 0 into one NDArray on ``ctx``.

    Args:
        root_dir: Directory containing the input images.
        batch_size: Number of images per batch.
        ctx: MXNet context the batches are placed on.
        group: Maximum number of batches to build.

    Returns:
        A list of at most ``group`` NDArrays. The last batch may hold fewer
        than ``batch_size`` images. Batches that would be empty are not
        created, so the list is shorter when the directory runs out of files.
    """
    # Only regular files can be decoded; a subdirectory would make imread fail.
    img_names = sorted(
        name for name in os.listdir(root_dir)
        if os.path.isfile(os.path.join(root_dir, name))
    )
    img_iter = []
    for i in range(group):
        batch_names = img_names[i * batch_size:(i + 1) * batch_size]
        if not batch_names:
            # Nothing left to load; concatenating an empty list would fail.
            break
        # Join with root_dir (the argument), not the module-level img_dir,
        # so the function honours whatever directory it is given.
        group_img = [
            get_image(os.path.join(root_dir, name), ctx) for name in batch_names
        ]
        img_iter.append(mx.nd.concat(*group_img, dim=0))
    return img_iter
# Load the pretrained PSPNet (ResNet-101 backbone, ADE20K classes) directly on
# ``ctx`` so its parameters live on the same device as the input batches.
model = gluoncv.model_zoo.get_model('psp_resnet101_ade', pretrained=True, ctx=ctx)
data_iter = build_data_iter(img_dir, batch_size=8, ctx=ctx, group=1)

# Running counter so every image gets its own mask file instead of each
# iteration overwriting a single output.
mask_index = 0
for databatch in data_iter:
    # The model returns (main_output, auxiliary_output); only the main head
    # is used for the prediction.
    output, _aux_output = model(databatch)
    # argmax over axis 1 (presumably the class axis of (N, C, H, W) logits)
    # yields one label map per image in the batch.
    predictions = mx.nd.argmax(output, 1).asnumpy()
    # get_color_pallete colours one 2-D label map, so handle each image
    # separately rather than squeezing the whole batch into one array.
    for predict in predictions:
        mask = get_color_pallete(predict, 'ade20k')
        mask.save('oframe6_{}.png'.format(mask_index))
        mask_index += 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment