Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
# Directory scanned by build_data_iter(); every entry in it is loaded as an image.
img_dir = './images'

# Per-image preprocessing applied by get_image() before inference.
# ToTensor converts the HWC uint8 image returned by mx.image.imread into a CHW
# float32 tensor scaled to [0, 1]; Normalize operates on that CHW float layout,
# so ToTensor must run first. The mean/std values are the standard ImageNet
# statistics.
transform_fn = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
])
def get_image(img_path, ctx):
    """Load one image file and return it as a single-item batch on ``ctx``.

    The file is decoded with ``mx.image.imread``, passed through the
    module-level ``transform_fn`` pipeline, given a new leading axis of
    size 1 via ``expand_dims(0)``, and finally copied to ``ctx``.
    """
    batch = transform_fn(mx.image.imread(img_path)).expand_dims(0)
    return batch.as_in_context(ctx)
def build_data_iter(root_dir, batch_size, ctx, group=2):
    """Load images from ``root_dir`` into a list of batched NDArrays.

    Directory entries are sorted by name so batch contents are reproducible,
    then split into consecutive chunks of ``batch_size``. At most ``group``
    chunks are produced. Each image is loaded with ``get_image`` (a 1-item
    batch on ``ctx``), and each chunk's items are concatenated along axis 0.

    Args:
        root_dir: Directory whose entries are all loaded as images.
        batch_size: Maximum number of images per batch.
        ctx: MXNet context each image is placed on.
        group: Maximum number of batches to build.

    Returns:
        A list of up to ``group`` NDArrays. The list is shorter when
        ``root_dir`` holds fewer than ``group * batch_size`` images. The last
        batch may contain fewer than ``batch_size`` images.
    """
    # os.listdir order is arbitrary; sort so batches are deterministic.
    img_names = sorted(os.listdir(root_dir))
    img_iter = []
    for i in range(group):
        batch_names = img_names[i * batch_size:(i + 1) * batch_size]
        if not batch_names:
            # Ran out of images: mx.nd.concat() cannot concatenate nothing.
            break
        # Resolve paths against root_dir (the parameter), not a global directory.
        group_img = [get_image(os.path.join(root_dir, name), ctx)
                     for name in batch_names]
        img_iter.append(mx.nd.concat(*group_img, dim=0))
    return img_iter
# Pretrained 'psp_resnet101_ade' segmentation model from the GluonCV model zoo.
# NOTE(review): ``ctx`` is not defined in the visible code; it must be set
# before this point (e.g. ``ctx = mx.cpu()`` or ``ctx = mx.gpu(0)``).
# Load the weights on the same context as the input batches so a GPU ctx does
# not collide with CPU-resident parameters.
model = gluoncv.model_zoo.get_model('psp_resnet101_ade', pretrained=True, ctx=ctx)
data_iter = build_data_iter(img_dir, batch_size=8, ctx=ctx, group=1)
for batch_idx, databatch in enumerate(data_iter):
    # The model returns two outputs here; only the first (the main prediction)
    # is used. Presumably the second is PSPNet's auxiliary-head output.
    output, _aux = model(databatch)
    # argmax over axis 1 (assumed to be the class axis of an NCHW score map)
    # gives one label map per image. Iterate per image instead of squeezing the
    # whole batch: for a batch larger than 1, squeeze would leave a 3-D array
    # that cannot be colorized as a single mask.
    predictions = mx.nd.argmax(output, 1).asnumpy()
    for img_idx, predict in enumerate(predictions):
        mask = get_color_pallete(predict, 'ade20k')
        # One file per image so later masks do not overwrite earlier ones.
        mask.save('mask_{}_{}.png'.format(batch_idx, img_idx))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.