Skip to content

Instantly share code, notes, and snippets.

@bparaj
Last active April 22, 2021 20:50
Show Gist options
  • Save bparaj/d3093eb88a43193a91b51cda28261208 to your computer and use it in GitHub Desktop.
Script to compute and visualize intersection over union (IoU) for example rectangle pairs, using matplotlib.
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
def get_iou(b1, b2):
    """Return the intersection over union (IoU) of two axis-aligned boxes.

    Each box is a 4-tuple ``(x, y, width, height)`` where ``(x, y)`` is the
    bottom-left corner. The boxes may be passed in either order and need not
    overlap. Width and height are assumed to be non-negative.

    Args:
        b1: First box as ``(x, y, width, height)``.
        b2: Second box as ``(x, y, width, height)``.

    Returns:
        The IoU as a float: ``0.0`` when the boxes are disjoint or only touch
        along an edge, ``1.0`` when they coincide, and the intersection area
        divided by the union area otherwise.
    """
    b1_x, b1_y, b1_w, b1_h = b1
    b2_x, b2_y, b2_w, b2_h = b2

    # Overlap along each axis is the gap between the later of the two start
    # edges and the earlier of the two end edges. It is clamped at 0 so that
    # disjoint (or merely touching) boxes give an empty intersection. This
    # handles every relative placement, so no ordering precondition is needed.
    xing_w = max(0, min(b1_x + b1_w, b2_x + b2_w) - max(b1_x, b2_x))
    xing_h = max(0, min(b1_y + b1_h, b2_y + b2_h) - max(b1_y, b2_y))
    xing_area = xing_w * xing_h

    union_area = b1_w * b1_h + b2_w * b2_h - xing_area
    if union_area <= 0:
        # Both boxes have zero area, so IoU would be 0/0. Treat identical
        # degenerate boxes as a perfect match and anything else as no overlap.
        same_box = (b1_x, b1_y, b1_w, b1_h) == (b2_x, b2_y, b2_w, b2_h)
        return 1.0 if same_box else 0.0
    return xing_area / union_area
def get_overlapping_boxes(degree):
    """Return a hard-coded pair of example boxes for a named overlap level.

    Each box is a 4-int tuple: (bot_x, bot_y, width, height).

    Args:
        degree: One of ``"none"``, ``"bad"``, ``"good"``, ``"wow"`` or
            ``"perfect"``.

    Returns:
        A ``(b1, b2)`` tuple of boxes.

    Raises:
        ValueError: If ``degree`` is not one of the names above. Previously
            an unknown name crashed with ``UnboundLocalError`` at ``return``.
    """
    # All levels except "perfect" share the same first box and move the second.
    reference = (6, 8, 40, 50)
    pairs = {
        "none": (reference, (53, 30, 40, 50)),
        "bad": (reference, (30, 25, 40, 50)),
        "good": (reference, (13, 11, 40, 50)),
        "wow": (reference, (8, 9, 40, 50)),
        "perfect": ((3, 4, 40, 50), (3, 4, 40, 50)),
    }
    try:
        return pairs[degree]
    except KeyError:
        raise ValueError(
            f"unknown overlap degree {degree!r}; expected one of {sorted(pairs)}"
        ) from None
if __name__ == "__main__":
    # One overlap level per panel of the 2x2 grid. "perfect" is also defined
    # by get_overlapping_boxes but would need a fifth panel.
    iou_type = ["none", "bad", "good", "wow"]

    # Create figure and axes
    fig, axes = plt.subplots(nrows=2, ncols=2)
    # axes.flat walks the 2x2 grid row by row, which is the same order as
    # indexing with [idx // 2][idx % 2].
    for axis, d in zip(axes.flat, iou_type):
        b1, b2 = get_overlapping_boxes(d)
        iou = get_iou(b1, b2)

        # Boxes are (x, y, width, height); (x, y) is Rectangle's anchor corner.
        rec1 = Rectangle((b1[0], b1[1]), b1[2], b1[3], ec="red", lw=2, fill=False)
        rec2 = Rectangle((b2[0], b2[1]), b2[2], b2[3], ec="blue", lw=2, fill=False)
        axis.add_patch(rec1)
        axis.add_patch(rec2)

        axis.set_xlim(0, 100)
        axis.set_ylim(0, 100)
        # Equal scaling keeps the 0-100 square square, so the drawn overlap
        # is not distorted relative to the IoU in the title.
        axis.set_aspect("equal")
        axis.set_title(f"iou = {iou:.2}")
        axis.get_xaxis().set_visible(False)
        axis.get_yaxis().set_visible(False)

    plt.tight_layout()
    fig.savefig("iou_visualizations.png")
    # Release the figure once it has been written to disk.
    plt.close(fig)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment