Skip to content

Instantly share code, notes, and snippets.

@wkcn
Last active March 3, 2018 09:29
Show Gist options
  • Save wkcn/4a09c142bc9886b45b5a23461bbe4733 to your computer and use it in GitHub Desktop.
Save wkcn/4a09c142bc9886b45b5a23461bbe4733 to your computer and use it in GitHub Desktop.
Testing Code for MultiProposal
import mxnet as mx
from mxnet import nd
import numpy as np
import time
np.random.seed(1234)

# --- MultiProposal operator parameters ---
feature_stride = 16
scales = (8, 16, 32)
ratios = (0.5, 1, 2)
rpn_pre_nms_top_n = 12000
rpn_post_nms_top_n = 2000
threshold = 0.5
rpn_min_size = 16

# --- Benchmark input geometry ---
batch_size = 16
feat_len = 128
H = W = feat_len
# One anchor per (scale, ratio) pair at every feature-map cell.
num_anchors = len(ratios) * len(scales)
count_anchors = H * W * num_anchors

# Total elapsed time (seconds) accumulated by test().
T = 0.0

# Input tensor layouts fed to MultiProposal:
#   cls_prob:  (batch_size, 2 * num_anchors, H, W)
#   bbox_pred: (batch_size, 4 * num_anchors, H, W)
#   im_info:   (batch_size, 3)
ctx = mx.cpu(0)
# nd.random.* draws from MXNet's own generator, which np.random.seed does not
# touch; seed it as well so cls_prob / bbox_pred are reproducible across runs.
mx.random.seed(1234)
# NOTE(review): uniform noise is not a normalized fg/bg probability map;
# presumably acceptable since this script only measures operator speed.
cls_prob = nd.random.uniform(shape=(batch_size, 2 * num_anchors, H, W), ctx=ctx)
bbox_pred = nd.random.uniform(shape=(batch_size, 4 * num_anchors, H, W), ctx=ctx)
print ("start")
# Each image is feat_len * feature_stride pixels square, so the H x W feature
# map at this stride spans it exactly. For varied image sizes, draw the size
# with np.random.randint(100, feat_len * feature_stride, size=(2,)) instead.
im_len = feat_len * feature_stride
# One row per image: [height, width, scale], scale drawn from [0.70, 0.99].
# Built in one call rather than writing the NDArray row by row.
im_info = nd.array(
    [[im_len, im_len, np.random.randint(70, 100) / 100.0]
     for _ in range(batch_size)],
    dtype=np.float32,
    ctx=ctx,
)
def test(rpn_pre_nms_top_n, rpn_post_nms_top_n):
    """Run one MultiProposal forward pass and time it.

    Uses the module-level inputs (cls_prob, bbox_pred, im_info) and operator
    settings (feature_stride, scales, ratios, threshold, rpn_min_size). The
    elapsed time is added to the global accumulator ``T``.

    Args:
        rpn_pre_nms_top_n: forwarded unchanged to ``MultiProposal``.
        rpn_post_nms_top_n: forwarded unchanged to ``MultiProposal``.

    Returns:
        Elapsed seconds for this call (float).
    """
    global T
    # perf_counter is monotonic and high-resolution. Fall back to time.time
    # on interpreters that lack it (e.g. Python 2).
    clock = getattr(time, "perf_counter", time.time)
    start = clock()
    # NOTE(review): the score input is passed as ``cls_score``; some MXNet
    # releases name it ``cls_prob`` -- confirm against the installed version.
    # Unpacking into two names also checks that output_score=True yields
    # two outputs.
    multi_proposal, multi_score = mx.nd.contrib.MultiProposal(
        cls_score=cls_prob,
        bbox_pred=bbox_pred,
        im_info=im_info,
        feature_stride=feature_stride,
        scales=scales,
        ratios=ratios,
        rpn_pre_nms_top_n=rpn_pre_nms_top_n,
        rpn_post_nms_top_n=rpn_post_nms_top_n,
        threshold=threshold,
        rpn_min_size=rpn_min_size,
        output_score=True)
    # Block until all pending MXNet work finishes, so the measurement covers
    # the computation and not just the call returning.
    nd.waitall()
    dt = clock() - start
    T += dt
    return dt
print ("Testing...")
lastT = T
# (rpn_pre_nms_top_n, rpn_post_nms_top_n) pairs exercised, in order, each round.
_configs = (
    (rpn_pre_nms_top_n, rpn_post_nms_top_n),
    (rpn_pre_nms_top_n, 1500),
    (1000, 500),
)
for _ in range(10):
    for pre_n, post_n in _configs:
        test(pre_n, post_n)
    # Cumulative time so far, then the time spent in this round alone.
    print (T, T - lastT)
    lastT = T
print ("test ok")
print (T)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment