Created
May 26, 2022 07:32
-
-
Save ditwoo/5bd26a09355e3132ede9c6f7bdfba103 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def draw_msra_gaussian(heatmap, channel, center, sigma=2):
    """Draw a gaussian on a heatmap channel (in-place).

    The target region is set to the element-wise maximum of its current
    values and an unnormalized gaussian (peak value 1.0) centred on
    ``center``. The kernel is a square window of radius ``ceil(6 * sigma)``
    pixels, clipped to the heatmap borders.

    Args:
        heatmap (np.ndarray): heatmap matrix, expected shape ``[C, H, W]``;
            axis 1 is indexed by the y coordinate, axis 2 by the x coordinate.
        channel (int): channel to use for drawing a gaussian.
        center (Tuple[int, int]): ``(x, y)`` gaussian center in pixel
            coordinates; non-integer values are rounded half-up.
        sigma (float): gaussian standard deviation in pixels. Default is ``2``.

    Returns:
        np.ndarray: the same ``heatmap`` object, modified in place.
    """
    # Integer radius: with a fractional radius the int()-truncated window
    # bounds disagreed with the kernel centre and shifted the peak by a pixel.
    tmp_size = int(np.ceil(sigma * 6))
    # floor(v + 0.5) rounds consistently for negative coordinates as well;
    # int() truncates toward zero and would move e.g. -3 to -2.
    mu_x = int(np.floor(center[0] + 0.5))
    mu_y = int(np.floor(center[1] + 0.5))
    _, height, width = heatmap.shape

    # Window corners as (x, y): upper-left inclusive, bottom-right exclusive.
    ul = (mu_x - tmp_size, mu_y - tmp_size)
    br = (mu_x + tmp_size + 1, mu_y + tmp_size + 1)
    if ul[0] >= width or ul[1] >= height or br[0] < 0 or br[1] < 0:
        # The whole window lies outside the heatmap: nothing to draw.
        return heatmap

    size = 2 * tmp_size + 1
    x = np.arange(0, size, 1, np.float32)
    y = x[:, np.newaxis]
    x0 = y0 = size // 2
    g = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma**2))

    # Part of the kernel that falls inside the heatmap ...
    g_x = (max(0, -ul[0]), min(br[0], width) - ul[0])
    g_y = (max(0, -ul[1]), min(br[1], height) - ul[1])
    # ... and the matching heatmap region.
    img_x = (max(0, ul[0]), min(br[0], width))
    img_y = (max(0, ul[1]), min(br[1], height))

    rows = slice(img_y[0], img_y[1])
    cols = slice(img_x[0], img_x[1])
    # Keep the stronger response where gaussians overlap.
    heatmap[channel, rows, cols] = np.maximum(
        heatmap[channel, rows, cols],
        g[g_y[0]:g_y[1], g_x[0]:g_x[1]],
    )
    return heatmap
if __name__ == "__main__":
    # Usage example: draw one gaussian per box at the box centre.
    # Placeholder inputs — replace with real data. Boxes are assumed to be
    # (x1, y1, x2, y2) in normalized [0, 1] coordinates, since their centres
    # are scaled by the heatmap size below.
    boxes = [(0.10, 0.20, 0.30, 0.45), (0.80, 0.90, 0.55, 0.60)]
    channels = [0, 3]

    num_channels = 10
    heatmap_height, heatmap_width = 400, 400
    heatmap = np.zeros(
        (num_channels, heatmap_height, heatmap_width), dtype=np.float32
    )
    for (x1, y1, x2, y2), channel in zip(boxes, channels):
        w = abs(x2 - x1)
        h = abs(y2 - y1)
        # The corner midpoint is correct regardless of corner order;
        # x1 + w / 2 was wrong whenever x2 < x1 (or y2 < y1).
        xc = (x1 + x2) / 2
        yc = (y1 + y2) / 2
        scaled_xc = int(xc * heatmap_width)
        scaled_yc = int(yc * heatmap_height)
        # Draw class centers.
        # NOTE(review): for normalized boxes w * h <= 1, so this clip always
        # yields sigma == 2 — confirm the intended size scaling.
        draw_msra_gaussian(
            heatmap,
            channel,
            (scaled_xc, scaled_yc),
            sigma=np.clip(w * h, 2, 4),
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment