Skip to content

Instantly share code, notes, and snippets.

@berak
Last active May 17, 2021 09:29
Show Gist options
  • Save berak/c5d58315a332ba2bf3246b3ad2686c4c to your computer and use it in GitHub Desktop.
Save berak/c5d58315a332ba2bf3246b3ad2686c4c to your computer and use it in GitHub Desktop.
weighted box fusion
template<class _Tp>
struct WeightedBoxesFusion {
typedef Rect_<_Tp> RECT;
struct Box {
int label;
float x1,y1,x2,y2, score, weight;
Box() {}
Box(int label, float x1,float y1,float x2,float y2, float score, float weight)
: label(label), x1(x1), y1(y1), x2(x2), y2(y2), score(score), weight(weight)
{
}
double distance(const Box &b) const {
RECT r1(x1,y1,x2-x1,y2-y1);
RECT r2(b.x1,b.y1,b.x2-b.x1,b.y2-b.y1);
return jaccardDistance(r1,r2);
}
};
WeightedBoxesFusion(double iou_thresh=0.45) : THR_IOU(iou_thresh) {}
vector<Box> B;
vector<pair<vector<Box>,Box>> FL; // combined F and L lists
int num_models = 0;
double weights = 0;
double THR_IOU = (1.0 - 0.55); // jaccardDistance returns (1-iou)
bool addModel(const vector<RECT> &rects, const vector<float> &scores, const vector<int> &labels, float weight, float conf_thresh) {
CV_Assert(rects.size()==scores.size());
CV_Assert(scores.size()==labels.size());
for (size_t i=0; i<rects.size(); i++) {
if (scores[i] < conf_thresh)
continue;
B.push_back(Box(labels[i],
rects[i].x, rects[i].y, rects[i].x+rects[i].width, rects[i].y+rects[i].height,
scores[i] * weight, weight));
}
num_models ++;
weights += weight;
return true;
}
bool fuse(vector<RECT> &rects, vector<float> &scores, vector<int> &labels) {
sort(B.begin(), B.end(), [](const Box &a, const Box &b) {
return a.score > b.score;
});
// 3.
for (const auto b : B) {
double min_d = THR_IOU;
int best = -1;
for (int i=0; i<FL.size(); i++) {
auto &f = FL[i];
if (b.label != f.second.label)
continue;
double d = b.distance(f.second);
if (d < min_d) {
best = i;
min_d = d;
}
}
if (best != -1) {
FL[best].first.push_back(b);
} else {
FL.push_back(make_pair(vector<Box>{b},b));
}
}
// 6.
for (auto f : FL) {
float x1=0, x2=0, y1=0, y2=0, sum_score=0;
for (auto q : f.first) {
x1 += q.score * q.x1; y1 += q.score * q.y1;
x2 += q.score * q.x2; y2 += q.score * q.y2;
sum_score += q.score;
}
x1 /= sum_score;
y1 /= sum_score;
x2 /= sum_score;
y2 /= sum_score;
rects.push_back(RECT(x1,y1,x2-x1,y2-y1));
scores.push_back(sum_score / weights);
labels.push_back(f.second.label);
}
return true;
}
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment