Last active
January 8, 2021 09:38
-
-
Save rsomani95/e521ad9a365abf402c661db1cdabe5f5 to your computer and use it in GitHub Desktop.
Icevision -- `draw_sample` with uniform color map and custom font
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from icevision.visualize.draw_data import draw_bbox | |
# import a bunch of other stuff here | |
# Per-class RGB colours keyed by COCO class name, e.g. "person": [156, 159, 20].
# NOTE: the relative path is resolved against the current working directory.
cmap_path = "zz_color_map_coco.json"
with open(cmap_path, encoding="utf-8") as f:
    COLOR_MAP_COCO = json.load(f)
# Convert each colour list to a float64 array. `np.float` (used originally) was a
# deprecated alias of the builtin `float` and was removed in NumPy 1.24.
COLOR_MAP_COCO = {k: np.asarray(v, dtype=float) for k, v in COLOR_MAP_COCO.items()}
def as_rgb_tuple(x: Union[np.ndarray, tuple, list, str]) -> tuple:
    """Convert a colour specification into an RGB tuple usable as a PIL ``fill``.

    Args:
        x: Either a length-3 ``np.ndarray`` / ``tuple`` / ``list`` of channel
            values, or a colour string understood by
            ``PIL.ImageColor.getrgb`` (e.g. ``"#C4C4C4"``).

    Returns:
        For sequences, a tuple of plain Python ``int``s. Float channels are
        truncated toward zero, the same as ``ndarray.astype(int)``. For
        strings, whatever ``ImageColor.getrgb`` returns.

    Raises:
        ValueError: if a sequence does not contain exactly 3 values, or if
            ``x`` is of an unsupported type.
    """
    if isinstance(x, str):
        # Import the submodule explicitly: a bare `import PIL` does not
        # guarantee that `PIL.ImageColor` has been loaded.
        from PIL import ImageColor

        return ImageColor.getrgb(x)
    if isinstance(x, (np.ndarray, tuple, list)):
        if len(x) != 3:
            raise ValueError(f"Expected 3 (RGB) numbers, got {len(x)}")
        # Normalise to builtin ints. Numpy scalars and floats come from the
        # float colour map and the random colours. PIL's drawing code may
        # reject them. (`np.int`, used originally, was removed in NumPy 1.24.)
        return tuple(int(v) for v in x)
    raise ValueError(f"Expected {{np.ndarray|list|tuple|str}}, got {type(x)}")
# Loaded TrueType fonts keyed by (font_path, font_size). `_draw_label` runs once
# per drawn object, so re-reading the .ttf file on every call is wasted work.
_FONT_CACHE: dict = {}


def _load_font(font_path, font_size: int):
    """Return a cached ``PIL.ImageFont.truetype`` font for this path and size."""
    key = (font_path, font_size)
    font = _FONT_CACHE.get(key)
    if font is None:
        font = PIL.ImageFont.truetype(font_path, size=font_size)
        _FONT_CACHE[key] = font
    return font


def _draw_label(
    img: np.ndarray,
    caption: str,
    x: int,
    y: int,
    color: Union[np.ndarray, list, tuple],
    # font_path = None ## should assign a default PIL font
    font_path="DIN Alternate Bold.ttf",
    font_size: int = 20,
    return_as_pil_img: bool = False,
) -> Union[PIL.Image.Image, np.ndarray]:
    """Render ``caption`` onto ``img`` near the point ``(x, y)``.

    Args:
        img: Image as a numpy array, passed to ``PIL.Image.fromarray``. A
            ``PIL.Image.Image`` is also accepted and is drawn on directly.
        caption: Text to draw.
        x, y: Anchor point. The text is offset by (+10, +5) pixels from it.
        color: Text colour in any form accepted by ``as_rgb_tuple``.
        font_path: Path to a TrueType font file.
        font_size: Font size passed to ``PIL.ImageFont.truetype``.
        return_as_pil_img: If True, return the PIL image. Otherwise return
            ``np.array`` of it.

    Returns:
        The annotated image as a PIL image or a numpy array (see above).
    """
    font = _load_font(font_path, font_size)
    # Nudge the text right/down so it does not sit exactly on the anchor point.
    xy = (x + 10, y + 5)
    pil_img = img if isinstance(img, PIL.Image.Image) else PIL.Image.fromarray(img)
    draw = PIL.ImageDraw.Draw(pil_img)
    draw.text(xy, caption, font=font, fill=as_rgb_tuple(color))
    return pil_img if return_as_pil_img else np.array(pil_img)
def draw_label(
    img: np.ndarray,
    label: int,
    confidence: float,
    color,
    # new
    font_path: Union[str, Path] = "../fonts/DIN Alternate Bold.ttf",
    font_size: int = 20,
    return_as_pil_img: bool = True,
    #
    class_map: Optional[ClassMap] = None,
    bbox=None,
    mask=None,
    font_scale: float = 1.0,
):
    """Draw ``"<Class>: <score>%"`` (or just ``"<Class>"``) for one object.

    The anchor point is chosen as follows:
      * with ``bbox``: the first two values of ``bbox.xyxy``;
      * otherwise with ``mask``: the position of ``mask.data.argmax()``;
      * otherwise ``(0, 0)``.

    Args:
        img: Image to draw on (see ``_draw_label``).
        label: Class id. It is mapped to a name via ``class_map.get_id`` when
            ``class_map`` is given, and otherwise rendered with ``str``.
        confidence: Score in [0, 1], shown as a percentage. It may be ``None``.
            ``draw_sample`` passes ``None`` when the sample has no
            ``"scores"``; the score part of the caption is then omitted.
        color: Text colour (anything accepted by ``as_rgb_tuple``).
        font_path, font_size: Font used for the text.
        return_as_pil_img: Return a PIL image instead of a numpy array.
        class_map: Optional id -> name mapping.
        bbox, mask: Used only to position the caption.
        font_scale: Unused. It is kept for signature compatibility.

    Returns:
        The annotated image, as returned by ``_draw_label``.
    """
    if bbox is not None:
        x, y, _, _ = bbox.xyxy
    elif mask is not None:
        y, x = np.unravel_index(mask.data.argmax(), mask.data.shape)
    else:
        x, y = 0, 0

    name = class_map.get_id(label) if class_map is not None else str(label)
    caption = str(name).capitalize()
    # `zip_longest` in draw_sample yields None when no scores are present.
    if confidence is not None:
        caption = f"{caption}: {confidence * 100:.2f}%"

    return _draw_label(
        img=img,
        caption=caption,
        x=int(x),
        y=int(y),
        color=color,
        font_path=font_path,
        font_size=font_size,
        return_as_pil_img=return_as_pil_img,
    )
def draw_mask(
    img: np.ndarray,
    mask: MaskArray,
    color: Tuple[int, int, int],
    blend: float = 0.5,
    erode_strength: int = 7,
):
    """Tint the masked region of ``img`` and outline it, in place.

    Each pixel where ``mask.data`` is non-zero becomes
    ``blend * pixel + (1 - blend) * color``. The outline is the mask minus its
    erosion with an ``erode_strength`` x ``erode_strength`` kernel, and it is
    painted in solid ``color``.

    Returns:
        The same ``img`` array, modified in place.
    """
    rgb = np.asarray(color, dtype=int)
    region = mask.data

    # Blend the colour into every masked pixel.
    inside = np.where(region)
    img[inside] = blend * img[inside] + (1 - blend) * rgb

    # The outline is what erosion removes from the mask.
    kernel = np.ones((erode_strength, erode_strength), np.uint8)
    outline = region - cv2.erode(region, kernel, iterations=1)
    img[np.where(outline)] = rgb
    return img
def draw_sample(
    sample,
    class_map: Optional[ClassMap] = None,
    denormalize_fn: Optional[callable] = None,
    display_label: bool = True,
    display_bbox: bool = True,
    display_mask: bool = True,
    display_keypoints: bool = True,
    # fontsize: int = 12,
    # new
    font_path: Union[str, Path] = "../fonts/DIN Alternate Bold.ttf",
    font_size: int = 20,
    label_color: Union[np.ndarray, list, tuple, str] = "#C4C4C4",
    return_as_pil_img: bool = False,
    #
    # bbox_thickness: int = 2,
    mask_blend: float = 0.42,
    erode_strength: int = 7,
    font_scale: float = 1.0,
    color_map: Dict[str, Tuple[int]] = COLOR_MAP_COCO,
    #
    exclude_labels: list = None,
    include_only: list = None,
):
    """Main function to call to plot results.

    Draws masks, boxes, keypoints and labels for every object in ``sample`` on
    a copy of ``sample["img"]``.

    Args:
        sample: Mapping with ``"img"`` and, optionally, the parallel sequences
            ``"labels"``, ``"bboxes"``, ``"scores"``, ``"masks"`` and
            ``"keypoints"``. Missing entries are treated as ``None`` for each
            object (via ``itertools.zip_longest``).
        class_map: Maps label ids to class names. Without it, ``str(label)`` is
            used for filtering, colour lookup and captions.
        denormalize_fn: Optional function applied to the copied image first.
        display_label, display_bbox, display_mask, display_keypoints: Toggle
            each kind of annotation.
        font_path, font_size, label_color: Caption font and colour.
        return_as_pil_img: Return a ``PIL.Image.Image`` instead of an ndarray.
        mask_blend: Weight of the original pixel when tinting masks.
        erode_strength: Kernel size used for the mask outline.
        font_scale: Forwarded to ``draw_label``, which ignores it.
        color_map: Class name -> RGB colour. Classes that are missing, or a
            ``None`` map, fall back to a random light colour.
        exclude_labels: Class names to skip. Ignored when ``include_only`` is
            given.
        include_only: If given, draw only objects whose class name is listed.

    Returns:
        The annotated image (ndarray, or PIL image if ``return_as_pil_img``).
    """
    img = sample["img"].copy()
    if denormalize_fn is not None:
        img = denormalize_fn(img)

    for label, bbox, conf, mask, keypoints in itertools.zip_longest(
        sample.get("labels", []),
        sample.get("bboxes", []),
        sample.get("scores", []),
        sample.get("masks", []),
        sample.get("keypoints", []),
    ):
        # Fallback colour: random, fairly light RGB (each channel in [102, 255)).
        color = (np.random.random(3) * 0.6 + 0.4) * 255

        # `is not None` rather than truthiness, so label id 0 is filtered and
        # coloured like any other label.
        if label is not None:
            label_str = (
                class_map.get_id(label) if class_map is not None else str(label)
            )
            if include_only is not None:
                if label_str not in include_only:
                    continue
            elif exclude_labels is not None and label_str in exclude_labels:
                continue
            if color_map is not None and label_str in color_map:
                # `np.float` was removed in NumPy 1.24; builtin float is float64.
                color = np.asarray(color_map[label_str], dtype=float)

        if display_mask and mask is not None:
            img = draw_mask(
                img=img,
                mask=mask,
                color=color,
                blend=mask_blend,
                erode_strength=erode_strength,
            )
        if display_bbox and bbox is not None:
            img = draw_bbox(img=img, bbox=bbox, color=color)
        if display_keypoints and keypoints is not None:
            img = draw_keypoints(img=img, kps=keypoints, color=color)
        if display_label and label is not None:
            # Keep `img` a numpy array inside the loop: draw_mask indexes into
            # it for the next object. Convert to PIL only once, at the end.
            img = draw_label(
                img=img,
                label=label,
                bbox=bbox,
                mask=mask,
                class_map=class_map,
                color=label_color,
                confidence=conf,
                font_scale=font_scale,
                font_path=font_path,
                font_size=font_size,
                return_as_pil_img=False,
            )

    if return_as_pil_img:
        return PIL.Image.fromarray(img)
    return img
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"N/A": [ | |
235, | |
12, | |
255 | |
], | |
"airplane": [ | |
61, | |
230, | |
250 | |
], | |
"apple": [ | |
204, | |
5, | |
255 | |
], | |
"background": [ | |
163, | |
0, | |
255 | |
], | |
"backpack": [ | |
11, | |
200, | |
200 | |
], | |
"banana": [ | |
92, | |
255, | |
0 | |
], | |
"baseball bat": [ | |
0, | |
255, | |
173 | |
], | |
"baseball glove": [ | |
0, | |
214, | |
255 | |
], | |
"bear": [ | |
0, | |
255, | |
92 | |
], | |
"bed": [ | |
204, | |
70, | |
3 | |
], | |
"bench": [ | |
0, | |
133, | |
255 | |
], | |
"bicycle": [ | |
4, | |
200, | |
3 | |
], | |
"bird": [ | |
51, | |
255, | |
0 | |
], | |
"boat": [ | |
163, | |
255, | |
0 | |
], | |
"book": [ | |
102, | |
8, | |
255 | |
], | |
"bottle": [ | |
255, | |
5, | |
153 | |
], | |
"bowl": [ | |
255, | |
102, | |
0 | |
], | |
"broccoli": [ | |
41, | |
0, | |
255 | |
], | |
"bus": [ | |
0, | |
255, | |
20 | |
], | |
"cake": [ | |
250, | |
10, | |
15 | |
], | |
"car": [ | |
255, | |
0, | |
143 | |
], | |
"carrot": [ | |
41, | |
255, | |
0 | |
], | |
"cat": [ | |
51, | |
0, | |
255 | |
], | |
"cell phone": [ | |
173, | |
0, | |
255 | |
], | |
"chair": [ | |
140, | |
140, | |
140 | |
], | |
"clock": [ | |
255, | |
0, | |
102 | |
], | |
"couch": [ | |
8, | |
184, | |
170 | |
], | |
"cow": [ | |
0, | |
224, | |
255 | |
], | |
"cup": [ | |
0, | |
255, | |
235 | |
], | |
"dining table": [ | |
20, | |
255, | |
0 | |
], | |
"dog": [ | |
153, | |
255, | |
0 | |
], | |
"donut": [ | |
10, | |
0, | |
255 | |
], | |
"elephant": [ | |
0, | |
112, | |
255 | |
], | |
"fire hydrant": [ | |
70, | |
184, | |
160 | |
], | |
"fork": [ | |
0, | |
255, | |
133 | |
], | |
"frisbee": [ | |
255, | |
41, | |
10 | |
], | |
"giraffe": [ | |
245, | |
0, | |
255 | |
], | |
"hair drier": [ | |
6, | |
230, | |
230 | |
], | |
"handbag": [ | |
71, | |
0, | |
255 | |
], | |
"horse": [ | |
255, | |
31, | |
0 | |
], | |
"hot dog": [ | |
255, | |
235, | |
0 | |
], | |
"keyboard": [ | |
255, | |
153, | |
0 | |
], | |
"kite": [ | |
173, | |
255, | |
0 | |
], | |
"knife": [ | |
112, | |
9, | |
255 | |
], | |
"laptop": [ | |
10, | |
255, | |
71 | |
], | |
"microwave": [ | |
20, | |
0, | |
255 | |
], | |
"motorcycle": [ | |
10, | |
190, | |
212 | |
], | |
"mouse": [ | |
31, | |
0, | |
255 | |
], | |
"orange": [ | |
220, | |
220, | |
220 | |
], | |
"oven": [ | |
255, | |
163, | |
0 | |
], | |
"parking meter": [ | |
11, | |
102, | |
255 | |
], | |
"person": [ | |
156, | |
159, | |
20 | |
], | |
"pizza": [ | |
0, | |
163, | |
255 | |
], | |
"potted plant": [ | |
133, | |
0, | |
255 | |
], | |
"refrigerator": [ | |
0, | |
194, | |
255 | |
], | |
"remote": [ | |
0, | |
122, | |
255 | |
], | |
"sandwich": [ | |
120, | |
120, | |
80 | |
], | |
"scissors": [ | |
255, | |
112, | |
0 | |
], | |
"sheep": [ | |
0, | |
245, | |
255 | |
], | |
"sink": [ | |
184, | |
0, | |
255 | |
], | |
"skateboard": [ | |
0, | |
102, | |
200 | |
], | |
"skis": [ | |
255, | |
0, | |
20 | |
], | |
"snowboard": [ | |
0, | |
204, | |
255 | |
], | |
"spoon": [ | |
230, | |
230, | |
230 | |
], | |
"sports ball": [ | |
6, | |
51, | |
255 | |
], | |
"stop sign": [ | |
71, | |
255, | |
0 | |
], | |
"suitcase": [ | |
0, | |
92, | |
255 | |
], | |
"surfboard": [ | |
235, | |
255, | |
7 | |
], | |
"teddy bear": [ | |
150, | |
5, | |
61 | |
], | |
"tennis racket": [ | |
255, | |
71, | |
0 | |
], | |
"tie": [ | |
255, | |
9, | |
92 | |
], | |
"toaster": [ | |
120, | |
120, | |
120 | |
], | |
"toilet": [ | |
255, | |
255, | |
0 | |
], | |
"toothbrush": [ | |
0, | |
0, | |
255 | |
], | |
"traffic light": [ | |
0, | |
41, | |
255 | |
], | |
"train": [ | |
224, | |
255, | |
8 | |
], | |
"truck": [ | |
255, | |
122, | |
8 | |
], | |
"tv": [ | |
255, | |
0, | |
204 | |
], | |
"umbrella": [ | |
160, | |
150, | |
20 | |
], | |
"vase": [ | |
122, | |
0, | |
255 | |
], | |
"wine glass": [ | |
0, | |
10, | |
255 | |
], | |
"zebra": [ | |
0, | |
82, | |
255 | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment