Last active
January 10, 2019 08:32
-
-
Save kris-singh/219f2df5acb26fd4b07459df9b4ab44d to your computer and use it in GitHub Desktop.
SoftNMS
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
#include "cpu/vision.h" | |
template <typename scalar_t> | |
at::Tensor soft_nms_cpu_kernel(const at::Tensor& dets, | |
at::Tensor& scores, | |
const float threshold) { | |
AT_ASSERTM(!dets.type().is_cuda(), "dets must be a CPU tensor"); | |
AT_ASSERTM(!scores.type().is_cuda(), "scores must be a CPU tensor"); | |
AT_ASSERTM(dets.type() == scores.type(), "dets should have the same type as scores"); | |
if (dets.numel() == 0) { | |
return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU)); | |
} | |
auto x1_t = dets.select(1, 0).contiguous(); | |
auto y1_t = dets.select(1, 1).contiguous(); | |
auto x2_t = dets.select(1, 2).contiguous(); | |
auto y2_t = dets.select(1, 3).contiguous(); | |
at::Tensor areas_t = (x2_t - x1_t + 1) * (y2_t - y1_t + 1); | |
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true)); | |
auto ndets = dets.size(0); | |
at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte).device(at::kCPU)); | |
at::Tensor included_t = at::zeros({ndets}, dets.options().dtype(at::kByte).device(at::kCPU)); | |
auto suppressed = suppressed_t.data<uint8_t>(); | |
auto included = included_t.data<uint8_t>(); | |
auto order = order_t.data<int64_t>(); | |
auto x1 = x1_t.data<scalar_t>(); | |
auto y1 = y1_t.data<scalar_t>(); | |
auto x2 = x2_t.data<scalar_t>(); | |
auto y2 = y2_t.data<scalar_t>(); | |
auto areas = areas_t.data<scalar_t>(); | |
float sigma = 0.1; | |
// auto score = scores.data<scalar_t>(); | |
for (int64_t _i = 0; _i < ndets; _i++) { | |
auto i = scores.argmax().data<int64_t>()[0]; | |
if (suppressed[i] == 1 || included[i] == 1) | |
continue; | |
included[i] = 1; | |
scores[i] = 1e-12; | |
auto ix1 = x1[i]; | |
auto iy1 = y1[i]; | |
auto ix2 = x2[i]; | |
auto iy2 = y2[i]; | |
auto iarea = areas[i]; | |
for (int64_t _j = 0; _j < ndets; _j++) { | |
auto j = order[_j]; | |
if (suppressed[j] == 1 || included[i] == 1) | |
continue; | |
auto xx1 = std::max(ix1, x1[j]); | |
auto yy1 = std::max(iy1, y1[j]); | |
auto xx2 = std::min(ix2, x2[j]); | |
auto yy2 = std::min(iy2, y2[j]); | |
auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1 + 1); | |
auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1 + 1); | |
auto inter = w * h; | |
auto ovr = inter / (iarea + areas[j] - inter); | |
scores[j] = scores[j] * std::exp(- std::pow(ovr, 2.0) / sigma); | |
if (scores.data<float>()[j] >= threshold) | |
suppressed[j] = 1; | |
} | |
} | |
return at::nonzero(suppressed_t == 0).squeeze(1); | |
} | |
// CPU entry point for Soft-NMS: instantiates the kernel for the floating
// point element type of `dets` and forwards all arguments unchanged.
// `scores` is passed by non-const reference through to the kernel.
at::Tensor soft_nms_cpu(const at::Tensor& dets,
                        at::Tensor& scores,
                        const float threshold) {
  // The dispatch macro evaluates to the value returned by the lambda for the
  // matching scalar type, so it can be returned directly.
  return AT_DISPATCH_FLOATING_TYPES(dets.type(), "soft_nms", [&] {
    return soft_nms_cpu_kernel<scalar_t>(dets, scores, threshold);
  });
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
#pragma once | |
#include <torch/extension.h> | |
// CPU ROIAlign forward pass. Declaration only; the definition is not
// visible here, so parameter semantics should be confirmed against it.
at::Tensor ROIAlign_forward_cpu(const at::Tensor& input, | |
const at::Tensor& rois, | |
const float spatial_scale, | |
const int pooled_height, | |
const int pooled_width, | |
const int sampling_ratio); | |
// CPU non-maximum suppression. Declaration only; both tensors are taken
// by const reference.
at::Tensor nms_cpu(const at::Tensor& dets, | |
const at::Tensor& scores, | |
const float threshold); | |
// CPU Soft-NMS. Unlike nms_cpu, `scores` is a non-const reference, so the
// callee is allowed to modify the caller's tensor.
at::Tensor soft_nms_cpu(const at::Tensor& dets, | |
at::Tensor& scores, | |
const float threshold); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
#pragma once | |
#include "cpu/vision.h" | |
#ifdef WITH_CUDA | |
#include "cuda/vision.h" | |
#endif | |
/**
 * Device-dispatching front end for Soft-NMS.
 *
 * Only a CPU implementation exists, so CUDA input is rejected with an error
 * and everything else is forwarded to soft_nms_cpu.
 *
 * @param dets      Boxes; must not be a CUDA tensor.
 * @param scores    Scores, passed by non-const reference through to
 *                  soft_nms_cpu.
 * @param threshold Forwarded unchanged to soft_nms_cpu.
 * @return Whatever soft_nms_cpu returns for these arguments.
 *
 * Declared `inline` because the definition lives in a header; without it,
 * including this header from more than one translation unit would produce
 * duplicate definitions.
 */
inline at::Tensor soft_nms(const at::Tensor& dets,
                           at::Tensor& scores,
                           const float threshold) {
  // There is no CUDA kernel in any build configuration, so the check must
  // not depend on WITH_CUDA: reject CUDA tensors up front instead of letting
  // them reach the CPU kernel.
  if (dets.type().is_cuda()) {
    AT_ERROR("Soft NMS Does Not have GPU support");
  }
  return soft_nms_cpu(dets, scores, threshold);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
#include "nms.h" | |
#include "ROIAlign.h" | |
#include "ROIPool.h" | |
#include "soft_nms.h" | |
// Python bindings for the extension module. Each m.def call registers one
// C++ function under the given Python-visible name with a short docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | |
// Suppression operators.
m.def("soft_nms", &soft_nms, "soft-non-maximum suppression"); | |
m.def("nms", &nms, "non-maximum suppression"); | |
// ROIAlign forward/backward.
m.def("roi_align_forward", &ROIAlign_forward, "ROIAlign_forward"); | |
m.def("roi_align_backward", &ROIAlign_backward, "ROIAlign_backward"); | |
// ROIPool forward/backward.
m.def("roi_pool_forward", &ROIPool_forward, "ROIPool_forward"); | |
m.def("roi_pool_backward", &ROIPool_backward, "ROIPool_backward"); | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
import torch | |
from .batch_norm import FrozenBatchNorm2d | |
from .misc import Conv2d | |
from .misc import ConvTranspose2d | |
from .misc import interpolate | |
from .nms import nms | |
from .soft_nms import soft_nms | |
from .roi_align import ROIAlign | |
from .roi_align import roi_align | |
from .roi_pool import ROIPool | |
from .roi_pool import roi_pool | |
from .smooth_l1_loss import smooth_l1_loss | |
# Names exported by this package on `from <package> import *`; every entry is
# imported above.
__all__ = ["nms", "soft_nms", "roi_align", "ROIAlign", "roi_pool", "ROIPool", "smooth_l1_loss", "Conv2d", "ConvTranspose2d", "interpolate", "FrozenBatchNorm2d"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
# from ._utils import _C | |
from maskrcnn_benchmark import _C | |
# Re-export the `soft_nms` attribute of the `_C` module under a module-level
# name (presumably the compiled extension's binding — confirm against the build).
soft_nms = _C.soft_nms | |
# nms.__doc__ = """ | |
# This function performs Non-maximum suppression""" |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment