Skip to content

Instantly share code, notes, and snippets.

@rsomani95
Last active January 8, 2021 09:38
Show Gist options
  • Save rsomani95/e521ad9a365abf402c661db1cdabe5f5 to your computer and use it in GitHub Desktop.
Save rsomani95/e521ad9a365abf402c661db1cdabe5f5 to your computer and use it in GitHub Desktop.
Icevision -- `draw_sample` with uniform color map and custom font
from icevision.visualize.draw_data import draw_bbox
# import a bunch of other stuff here
# Per-class RGB colors, loaded once at import time from a JSON file that maps
# a class name to a list of three channel values. Each entry is stored as a
# float NumPy array.
cmap_path = "zz_color_map_coco.json"
with open(cmap_path, encoding="utf-8") as f:
    # The builtin `float` replaces the `np.float` alias, which was deprecated
    # in NumPy 1.20 and removed in 1.24 (it was always the same float64 dtype).
    COLOR_MAP_COCO = {
        name: np.asarray(rgb, dtype=float) for name, rgb in json.load(f).items()
    }
def as_rgb_tuple(x: Union[np.ndarray, tuple, list, str]) -> tuple:
    """Convert an RGB color into a tuple for PIL compatibility.

    Args:
        x: Either a sequence of exactly three channel values (NumPy array,
            tuple or list) or a color string, which is handed to
            ``PIL.ImageColor.getrgb``.

    Returns:
        For an array, a tuple of three Python ``int`` values (fractional
        parts are truncated). A list is converted to a tuple, a tuple is
        returned as-is, and a string yields whatever
        ``PIL.ImageColor.getrgb`` returns.

    Raises:
        ValueError: If a sequence does not hold exactly three values, or if
            ``x`` is of an unsupported type.
    """
    if isinstance(x, str):
        return PIL.ImageColor.getrgb(x)

    if not isinstance(x, (np.ndarray, tuple, list)):
        raise ValueError(f"Expected {{np.ndarray|list|tuple|str}}, got {type(x)}")

    if len(x) != 3:
        raise ValueError(f"Expected 3 (RGB) numbers, got {len(x)}")

    if isinstance(x, np.ndarray):
        # `np.int` was removed in NumPy 1.24; the builtin `int` is the same
        # dtype. `.tolist()` yields plain Python ints instead of NumPy scalars.
        return tuple(x.astype(int).tolist())

    return x if isinstance(x, tuple) else tuple(x)
def _draw_label(
    img: np.ndarray,
    caption: str,
    x: int,
    y: int,
    color: Union[np.ndarray, list, tuple, str],
    font_path="DIN Alternate Bold.ttf",
    font_size: int = 20,
    return_as_pil_img: bool = False,
) -> Union[PIL.Image.Image, np.ndarray]:
    """Draw a text caption onto an image.

    Args:
        img: Image array; it is converted with ``PIL.Image.fromarray``, so it
            must be in a layout/dtype that function accepts (presumably
            HxWx3 uint8 -- confirm against callers).
        caption: Text to render.
        x: Horizontal anchor; the text starts 10 pixels to the right of it.
        y: Vertical anchor; the text starts 5 pixels below it.
        color: Text color in any form accepted by ``as_rgb_tuple``.
        font_path: TrueType font to load. Pass ``None`` to fall back to
            PIL's built-in default font.
        font_size: Size passed to ``PIL.ImageFont.truetype``. It is not
            applied when the default font is used.
        return_as_pil_img: Return the PIL image instead of a NumPy array.

    Returns:
        The annotated image, as a PIL image if ``return_as_pil_img`` is set,
        otherwise as a new NumPy array.
    """
    if font_path is None:
        # No font file requested: use the font that ships with PIL.
        font = PIL.ImageFont.load_default()
    else:
        font = PIL.ImageFont.truetype(font_path, size=font_size)

    canvas = PIL.Image.fromarray(img)
    text_origin = (x + 10, y + 5)
    ImageDraw.Draw(canvas).text(
        text_origin, caption, font=font, fill=as_rgb_tuple(color)
    )

    return canvas if return_as_pil_img else np.array(canvas)
def draw_label(
    img: np.ndarray,
    label: int,
    confidence: Optional[float],
    color,
    font_path: Union[str, Path] = "../fonts/DIN Alternate Bold.ttf",
    font_size: int = 20,
    return_as_pil_img: bool = True,
    class_map: Optional[ClassMap] = None,
    bbox=None,
    mask=None,
    font_scale: float = 1.0,
):
    """Build the caption for one annotation and draw it on the image.

    The caption is anchored at the first two values of ``bbox.xyxy`` when a
    bbox is given, otherwise at the position of the maximum of ``mask.data``,
    otherwise at the image origin.

    Args:
        img: Image array to draw on.
        label: Class id of the annotation.
        confidence: Score in [0, 1], rendered as a percentage. May be
            ``None`` (e.g. for samples without scores), in which case only
            the class name is drawn.
        color: Text color, forwarded to ``_draw_label``.
        font_path: Font file forwarded to ``_draw_label``.
        font_size: Font size forwarded to ``_draw_label``.
        return_as_pil_img: Return a PIL image instead of a NumPy array.
        class_map: Used to translate ``label`` via ``class_map.get_id``
            (presumably into the class name -- confirm against ClassMap).
            Without it the numeric id is used as the caption.
        bbox: Optional bounding box exposing ``xyxy``.
        mask: Optional mask exposing ``data``; the unpacking into ``y, x``
            assumes ``mask.data`` is 2-D.
        font_scale: Unused; kept so existing callers keep working.

    Returns:
        Whatever ``_draw_label`` returns for the given ``return_as_pil_img``.
    """
    # Find the label position based on bbox or mask.
    if bbox is not None:
        x, y, _, _ = bbox.xyxy
    elif mask is not None:
        y, x = np.unravel_index(mask.data.argmax(), mask.data.shape)
    else:
        x, y = 0, 0

    name = class_map.get_id(label) if class_map is not None else str(label)
    caption = name.capitalize()
    # `confidence` is None when the sample carries no scores; formatting it
    # unconditionally would raise a TypeError.
    if confidence is not None:
        caption = f"{caption}: {confidence*100:.2f}%"

    return _draw_label(
        img=img,
        caption=caption,
        x=int(x),
        y=int(y),
        color=color,
        font_path=font_path,
        font_size=font_size,
        return_as_pil_img=return_as_pil_img,
    )
def draw_mask(
    img: np.ndarray,
    mask: MaskArray,
    color: Tuple[int, int, int],
    blend: float = 0.5,
    erode_strength: int = 7,
):
    """Tint the masked pixels of ``img`` and outline the mask in solid color.

    ``img`` is modified in place and also returned.

    Args:
        img: Image array to paint on.
        mask: Object whose ``data`` array marks the masked pixels with
            non-zero values (assumed to index the leading axes of ``img``
            -- TODO confirm).
        color: Color to paint with; it is cast to integers first.
        blend: Weight of the original pixel values inside the mask; the
            color receives ``1 - blend``.
        erode_strength: Side length of the square kernel used to erode the
            mask when computing the outline.

    Returns:
        The same ``img`` array.
    """
    rgb = np.asarray(color, dtype=int)
    mask_data = mask.data

    # Fill: mix the color into every pixel the mask covers.
    covered = np.where(mask_data)
    img[covered] = blend * img[covered] + (1 - blend) * rgb

    # Outline: whatever the erosion removes from the mask is its border.
    kernel = np.ones((erode_strength, erode_strength), np.uint8)
    eroded = cv2.erode(mask_data, kernel, iterations=1)
    outline = np.where(mask_data - eroded)
    img[outline] = rgb

    return img
def draw_sample(
    sample,
    class_map: Optional[ClassMap] = None,
    denormalize_fn: Optional[callable] = None,
    display_label: bool = True,
    display_bbox: bool = True,
    display_mask: bool = True,
    display_keypoints: bool = True,
    font_path: Union[str, Path] = "../fonts/DIN Alternate Bold.ttf",
    font_size: int = 20,
    label_color: Union[np.ndarray, list, tuple, str] = "#C4C4C4",
    return_as_pil_img: bool = False,
    mask_blend: float = 0.42,
    erode_strength: int = 7,
    font_scale: float = 1.0,
    color_map: Optional[Dict[str, Tuple[int]]] = COLOR_MAP_COCO,
    exclude_labels: Optional[list] = None,
    include_only: Optional[list] = None,
):
    """Main function to call to plot results.

    Draws the masks, bboxes, keypoints and labels of ``sample`` on a copy of
    ``sample["img"]``.

    Args:
        sample: Mapping with an ``"img"`` entry and optional ``"labels"``,
            ``"bboxes"``, ``"scores"``, ``"masks"`` and ``"keypoints"``
            sequences, which are walked in parallel.
        class_map: Translates a label id via ``class_map.get_id``. Without
            it, ``str(label)`` is used for filtering and color lookup.
        denormalize_fn: Optional callable applied to the image copy before
            drawing.
        display_label: Draw the caption of each annotation.
        display_bbox: Draw bounding boxes.
        display_mask: Draw masks.
        display_keypoints: Draw keypoints.
        font_path: Font file used for captions.
        font_size: Font size used for captions.
        label_color: Caption text color.
        return_as_pil_img: Return a PIL image instead of a NumPy array.
        mask_blend: ``blend`` value forwarded to ``draw_mask``.
        erode_strength: ``erode_strength`` value forwarded to ``draw_mask``.
        font_scale: Forwarded to ``draw_label``.
        color_map: Maps a class name to its color. Annotations whose class
            is missing from the map, or all annotations when this is
            ``None``, get a random color.
        exclude_labels: Class names to skip. Ignored when ``include_only``
            is given.
        include_only: If given, only these class names are drawn.

    Returns:
        The annotated image, as a NumPy array or, if ``return_as_pil_img``
        is set, as a PIL image.
    """
    img = sample["img"].copy()
    if denormalize_fn is not None:
        img = denormalize_fn(img)

    for label, bbox, conf, mask, keypoints in itertools.zip_longest(
        sample.get("labels", []),
        sample.get("bboxes", []),
        sample.get("scores", []),
        sample.get("masks", []),
        sample.get("keypoints", []),
    ):
        # Random color by default.
        color = (np.random.random(3) * 0.6 + 0.4) * 255

        # `is not None` rather than truthiness, so label id 0 is handled too.
        if label is not None:
            if class_map is not None:
                label_str = class_map.get_id(label)
            else:
                label_str = str(label)

            # `include_only` takes precedence over `exclude_labels`.
            if include_only is not None:
                if label_str not in include_only:
                    continue
            elif exclude_labels is not None and label_str in exclude_labels:
                continue

            # Keep the random color for classes the map does not know.
            if color_map is not None and label_str in color_map:
                color = np.asarray(color_map[label_str], dtype=float)

        if display_mask and mask is not None:
            img = draw_mask(
                img=img,
                mask=mask,
                color=color,
                blend=mask_blend,
                erode_strength=erode_strength,
            )
        if display_bbox and bbox is not None:
            img = draw_bbox(img=img, bbox=bbox, color=color)
        if display_keypoints and keypoints is not None:
            img = draw_keypoints(img=img, kps=keypoints, color=color)
        if display_label and label is not None:
            # Always stay in NumPy inside the loop: the drawing helpers above
            # index into `img`, which a PIL image does not support.
            img = draw_label(
                img=img,
                label=label,
                bbox=bbox,
                mask=mask,
                class_map=class_map,
                color=label_color,
                confidence=conf,
                font_scale=font_scale,
                font_path=font_path,
                font_size=font_size,
                return_as_pil_img=False,
            )

    if return_as_pil_img:
        # Convert once, after all drawing, so the flag is honored even when
        # no label was drawn.
        import PIL.Image

        return PIL.Image.fromarray(img)
    return img
{
"N/A": [
235,
12,
255
],
"airplane": [
61,
230,
250
],
"apple": [
204,
5,
255
],
"background": [
163,
0,
255
],
"backpack": [
11,
200,
200
],
"banana": [
92,
255,
0
],
"baseball bat": [
0,
255,
173
],
"baseball glove": [
0,
214,
255
],
"bear": [
0,
255,
92
],
"bed": [
204,
70,
3
],
"bench": [
0,
133,
255
],
"bicycle": [
4,
200,
3
],
"bird": [
51,
255,
0
],
"boat": [
163,
255,
0
],
"book": [
102,
8,
255
],
"bottle": [
255,
5,
153
],
"bowl": [
255,
102,
0
],
"broccoli": [
41,
0,
255
],
"bus": [
0,
255,
20
],
"cake": [
250,
10,
15
],
"car": [
255,
0,
143
],
"carrot": [
41,
255,
0
],
"cat": [
51,
0,
255
],
"cell phone": [
173,
0,
255
],
"chair": [
140,
140,
140
],
"clock": [
255,
0,
102
],
"couch": [
8,
184,
170
],
"cow": [
0,
224,
255
],
"cup": [
0,
255,
235
],
"dining table": [
20,
255,
0
],
"dog": [
153,
255,
0
],
"donut": [
10,
0,
255
],
"elephant": [
0,
112,
255
],
"fire hydrant": [
70,
184,
160
],
"fork": [
0,
255,
133
],
"frisbee": [
255,
41,
10
],
"giraffe": [
245,
0,
255
],
"hair drier": [
6,
230,
230
],
"handbag": [
71,
0,
255
],
"horse": [
255,
31,
0
],
"hot dog": [
255,
235,
0
],
"keyboard": [
255,
153,
0
],
"kite": [
173,
255,
0
],
"knife": [
112,
9,
255
],
"laptop": [
10,
255,
71
],
"microwave": [
20,
0,
255
],
"motorcycle": [
10,
190,
212
],
"mouse": [
31,
0,
255
],
"orange": [
220,
220,
220
],
"oven": [
255,
163,
0
],
"parking meter": [
11,
102,
255
],
"person": [
156,
159,
20
],
"pizza": [
0,
163,
255
],
"potted plant": [
133,
0,
255
],
"refrigerator": [
0,
194,
255
],
"remote": [
0,
122,
255
],
"sandwich": [
120,
120,
80
],
"scissors": [
255,
112,
0
],
"sheep": [
0,
245,
255
],
"sink": [
184,
0,
255
],
"skateboard": [
0,
102,
200
],
"skis": [
255,
0,
20
],
"snowboard": [
0,
204,
255
],
"spoon": [
230,
230,
230
],
"sports ball": [
6,
51,
255
],
"stop sign": [
71,
255,
0
],
"suitcase": [
0,
92,
255
],
"surfboard": [
235,
255,
7
],
"teddy bear": [
150,
5,
61
],
"tennis racket": [
255,
71,
0
],
"tie": [
255,
9,
92
],
"toaster": [
120,
120,
120
],
"toilet": [
255,
255,
0
],
"toothbrush": [
0,
0,
255
],
"traffic light": [
0,
41,
255
],
"train": [
224,
255,
8
],
"truck": [
255,
122,
8
],
"tv": [
255,
0,
204
],
"umbrella": [
160,
150,
20
],
"vase": [
122,
0,
255
],
"wine glass": [
0,
10,
255
],
"zebra": [
0,
82,
255
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment