Skip to content

Instantly share code, notes, and snippets.

@zhreshold
Created January 14, 2019 23:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zhreshold/9c28c176908c4289dfa5572140ed2dd6 to your computer and use it in GitHub Desktop.
Demo script: export a Gluon network to a symbol (JSON) file, then load it back as a SymbolBlock and run detection.
"""SSD Demo script."""
import os
import argparse
import mxnet as mx
import gluoncv as gcv
from gluoncv.data.transforms import presets
from matplotlib import pyplot as plt
def parse_args(argv=None):
    """Parse command-line options for the SSD export/import demo.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. When ``None`` (the default) argparse falls back
        to ``sys.argv[1:]``, which matches the original behaviour.

    Returns
    -------
    argparse.Namespace
        Parsed options with string attributes ``network``, ``images``,
        ``gpus`` and ``pretrained``.
    """
    parser = argparse.ArgumentParser(description='Test with SSD networks.')
    parser.add_argument('--network', type=str, default='ssd_300_vgg16_atrous_voc',
                        help="Base network name")
    parser.add_argument('--images', type=str, default='',
                        help='Test images, use comma to split multiple.')
    # This script only runs inference, so the GPUs are used for testing, not training.
    parser.add_argument('--gpus', type=str, default='',
                        help='GPUs to run inference on, you can specify 1,3 for example.')
    # The caller treats 'true'/'1'/'yes'/'t' as "use model-zoo weights" and any
    # other value as a path handed to net.load_parameters().
    parser.add_argument('--pretrained', type=str, default='True',
                        help="'True' to load model-zoo pretrained weights, otherwise "
                             "a path to previously saved parameters.")
    return parser.parse_args(argv)
if __name__ == '__main__':
    args = parse_args()

    # Context list: one mx.gpu per id in --gpus, falling back to the CPU
    # when no GPU ids were given.
    ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
    ctx = [mx.cpu()] if not ctx else ctx

    # Grab a sample image if none was specified.
    if not args.images.strip():
        gcv.utils.download("https://cloud.githubusercontent.com/assets/3307514/" +
                           "20012568/cbc2d6f6-a27d-11e6-94c3-d35a9cb47609.jpg", 'street.jpg')
        image_list = ['street.jpg']
    else:
        image_list = [x.strip() for x in args.images.split(',') if x.strip()]

    # --pretrained is either a truthy flag (model-zoo weights) or a path that
    # is passed to net.load_parameters().
    if args.pretrained.lower() in ['true', '1', 'yes', 't']:
        net = gcv.model_zoo.get_model(args.network, pretrained=True)
    else:
        net = gcv.model_zoo.get_model(args.network, pretrained=False, pretrained_base=False)
        net.load_parameters(args.pretrained)
    net.set_nms(0.45, 200)

    # Export to JSON and load it back with SymbolBlock.
    # Class names are captured before `net` is rebound to the SymbolBlock,
    # which presumably does not expose `.classes`.
    class_names = net.classes
    print('Export to JSON...')
    gcv.utils.export_block(args.network, net, preprocess=False, layout='CHW')
    print('Load back from JSON with SymbolBlock')
    # These file names are the ones export_block is expected to produce for
    # this prefix (symbol file plus epoch-0 params).
    net = mx.gluon.SymbolBlock.imports('{}-symbol.json'.format(args.network),
                                       ['data'], '{}-0000.params'.format(args.network))
    net.collect_params().reset_ctx(ctx=ctx)

    for image in image_list:
        x, img = presets.ssd.load_test(image, short=512)
        # Inference runs on the first context only.
        x = x.as_in_context(ctx[0])
        ids, scores, bboxes = [xx[0].asnumpy() for xx in net(x)]
        # Use a fresh axes for every image. Reusing the previous image's axes
        # would draw onto a figure that plt.show() has already consumed, so
        # later images would be overlaid or never displayed.
        gcv.utils.viz.plot_bbox(img, bboxes, scores, ids,
                                class_names=class_names, ax=None)
        plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment