Last active
May 1, 2019 05:45
-
-
Save dalmia/0934f019eff262dafd4a91ad0a720448 to your computer and use it in GitHub Desktop.
SSD NMS
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# loc, conf, and priors are defined earlier.
# The decode function takes the default box coordinates and the predicted
# offsets as input and returns the box coordinates with respect to the image.
# The definitions of nms and decode can be found here: https://github.com/amdegroot/ssd.pytorch/blob/master/layers/box_utils.py
class Detect(Function):
    """Final SSD layer used at test time.

    Decodes the location predictions into boxes, runs per-class non-maximum
    suppression on the boxes whose confidence exceeds ``conf_thresh`` and
    keeps at most ``top_k`` detections per image.

    NOTE(review): this class defines ``__init__`` and an instance-level
    ``forward`` on a ``Function`` subclass. Presumably ``Function`` is
    ``torch.autograd.Function``, whose legacy (non-static ``forward``) style
    is rejected by recent PyTorch releases -- confirm against the PyTorch
    version this is run with.
    """

    def __init__(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh):
        """Store the detection hyper-parameters.

        Args:
            num_classes: number of per-prior confidence scores (classes).
            bkg_label: index of the background class. It is only stored;
                ``forward`` always skips class index 0.
            top_k: maximum number of detections kept, both per class (passed
                to ``nms``) and per image (final filtering).
            conf_thresh: scores must be strictly greater than this value for
                a box to be considered.
            nms_thresh: threshold forwarded to ``nms``.
        """
        self.num_classes = num_classes
        self.background_label = bkg_label
        self.top_k = top_k
        # Parameters used in nms.
        self.nms_thresh = nms_thresh
        self.conf_thresh = conf_thresh
        # `cfg` is a configuration mapping defined outside this block.
        self.variance = cfg['variance']

    def forward(self, loc_data, conf_data, prior_data):
        """Turn raw SSD predictions into a fixed-size detection tensor.

        Args:
            loc_data: (tensor) location predictions from the loc layers.
                Indexed per image and handed to ``decode``; presumably
                shaped [batch, num_priors, 4] -- TODO confirm.
            conf_data: (tensor) confidence predictions from the conf layers.
                Any shape that can be viewed as
                [batch, num_priors, num_classes].
            prior_data: (tensor) prior boxes from the priorbox layers.
                Its first dimension is taken as ``num_priors``; presumably
                shaped [num_priors, 4] -- TODO confirm.

        Returns:
            Tensor of shape [batch, num_classes, top_k, 5]. Each filled row
            holds the score followed by the four decoded box values; unused
            rows stay zero. At most ``top_k`` rows per image are non-zero.
        """
        num = loc_data.size(0)  # batch size
        num_priors = prior_data.size(0)
        output = torch.zeros(num, self.num_classes, self.top_k, 5)
        # Rearranged to [batch, num_classes, num_priors] so that a whole
        # class can be selected with a single index.
        conf_preds = conf_data.view(num, num_priors,
                                    self.num_classes).transpose(2, 1)

        # Decode predictions into bboxes.
        for i in range(num):
            decoded_boxes = decode(loc_data[i], prior_data, self.variance)
            # For each class, perform nms
            conf_scores = conf_preds[i].clone()

            # Class index 0 is skipped (treated as background).
            for cl in range(1, self.num_classes):
                c_mask = conf_scores[cl].gt(self.conf_thresh)
                scores = conf_scores[cl][c_mask]
                if scores.size(0) == 0:
                    continue
                l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)
                boxes = decoded_boxes[l_mask].view(-1, 4)
                # idx of highest scoring and non-overlapping boxes per class
                ids, count = nms(boxes, scores, self.nms_thresh, self.top_k)
                output[i, cl, :count] = \
                    torch.cat((scores[ids[:count]].unsqueeze(1),
                               boxes[ids[:count]]), 1)

        # Keep only the `top_k` highest-scoring detections per image across
        # all classes. `output` is freshly allocated and contiguous, so the
        # view shares its storage and the in-place fill below is visible in
        # `output`. (Indexing with a boolean mask and then calling `fill_`
        # would only modify a temporary copy.)
        flt = output.view(num, -1, 5)
        _, idx = flt[:, :, 0].sort(1, descending=True)
        _, rank = idx.sort(1)  # rank[b, j] = position of row j when sorted
        drop = (rank >= self.top_k).unsqueeze(-1).expand_as(flt)
        flt.masked_fill_(drop, 0)
        return output
# Build the detection layer: background label 0, keep the top 200 detections,
# confidence threshold 0.01, NMS threshold 0.45.
detect = Detect(num_classes, 0, 200, 0.01, 0.45)
# `num_classes` is the same name passed to Detect above; there is no `self`
# in scope here.
output = detect(
    loc.view(loc.size(0), -1, 4),  # loc preds
    softmax(conf.view(conf.size(0), -1, num_classes)),  # conf preds
    priors,  # default boxes
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment