Skip to content

Instantly share code, notes, and snippets.

@kounoike
Created May 17, 2020 05:39
Show Gist options
  • Save kounoike/c5a1ebb554abbe41960d5c06046374c7 to your computer and use it in GitHub Desktop.
Save kounoike/c5a1ebb554abbe41960d5c06046374c7 to your computer and use it in GitHub Desktop.
#!/usr/bin/python
# -*- encoding: utf-8 -*-
from logger import setup_logger
from model import BiSeNet
import torch
import os
import os.path as osp
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import cv2
from facenet_pytorch import MTCNN
def cv2pil(image):
    """Convert an OpenCV-style ndarray into a PIL Image.

    2-D (grayscale) arrays are passed through as-is; 3-channel arrays are
    reordered BGR -> RGB and 4-channel arrays BGRA -> RGBA. Any other channel
    count is handed to ``Image.fromarray`` unchanged. The conversion works on
    a copy of ``image``.
    """
    rgb = image.copy()
    if rgb.ndim != 2:
        channel_count = rgb.shape[2]
        if channel_count == 3:  # color
            rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
        elif channel_count == 4:  # color + alpha
            rgb = cv2.cvtColor(rgb, cv2.COLOR_BGRA2RGBA)
    return Image.fromarray(rgb)
def change_color(image, parsing, part, color=(230, 50, 20)):
    """Recolor one segmented region of a BGR image by replacing its hue.

    Only the hue channel is swapped; each pixel keeps its own saturation and
    value, so the shading of the region survives the recolor.

    Args:
        image: BGR image as an ndarray (assumed H x W x 3 uint8 camera frame —
            TODO confirm for other dtypes).
        parsing: per-pixel label map with the same height/width as ``image``.
        part: label value in ``parsing`` whose pixels are recolored.
        color: target color as a (b, g, r) triple; only its hue is used.

    Returns:
        A new BGR image; pixels whose label is not ``part`` are copied
        unchanged from ``image``.
    """
    b, g, r = color
    # A flat color has the same hue at every pixel, so convert one 1x1 pixel
    # instead of building and converting a full-frame image of that color.
    target_pixel = np.array([[[b, g, r]]], dtype=image.dtype)
    target_hue = cv2.cvtColor(target_pixel, cv2.COLOR_BGR2HSV)[0, 0, 0]

    image_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    image_hsv[:, :, 0] = target_hue  # channel 0 of HSV is hue
    changed = cv2.cvtColor(image_hsv, cv2.COLOR_HSV2BGR)

    # Restore every pixel that does not belong to the requested part.
    outside = parsing != part
    changed[outside] = image[outside]
    return changed
def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'):
    """Blend a color-coded parsing map over an image for visualization.

    Args:
        im: RGB image (anything ``np.array`` accepts, e.g. a PIL image).
        parsing_anno: 2-D array of per-pixel class labels (0..23 supported;
            larger labels raise IndexError).
        stride: factor by which ``parsing_anno`` is upscaled (nearest
            neighbour) before blending; the upscaled map must match the size
            of ``im`` for ``cv2.addWeighted``.
        save_im: when True, write the upscaled label map and the blended
            image to disk.
        save_path: path of the blended image; the raw label map is written to
            the same path with its extension replaced by ``.png``.

    Returns:
        BGR uint8 image: 40% ``im`` + 60% color map. Label 0 is drawn white.

    Raises:
        OSError: if ``save_im`` is True and OpenCV fails to write a file
            (e.g. the target directory does not exist).
    """
    # One color per label, 24 entries (labels 0..23). Label 0's entry is
    # replaced by white in the palette below.
    part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
                   [255, 0, 85], [255, 0, 170],
                   [0, 255, 0], [85, 255, 0], [170, 255, 0],
                   [0, 255, 85], [0, 255, 170],
                   [0, 0, 255], [85, 0, 255], [170, 0, 255],
                   [0, 85, 255], [0, 170, 255],
                   [255, 255, 0], [255, 255, 85], [255, 255, 170],
                   [255, 0, 255], [255, 85, 255], [255, 170, 255],
                   [0, 255, 255], [85, 255, 255], [170, 255, 255]]
    # Row i of the palette is the color for label i; label 0 is white.
    palette = np.array(part_colors, dtype=np.uint8)
    palette[0] = 255

    vis_im = np.array(im).astype(np.uint8)
    vis_parsing_anno = cv2.resize(parsing_anno.astype(np.uint8), None, fx=stride, fy=stride,
                                  interpolation=cv2.INTER_NEAREST)
    # Single vectorized lookup instead of one np.where pass per label.
    vis_parsing_anno_color = palette[vis_parsing_anno]

    vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)

    if save_im:
        label_path = osp.splitext(save_path)[0] + '.png'
        if not cv2.imwrite(label_path, vis_parsing_anno):
            raise OSError(f"failed to write label map to {label_path!r}")
        if not cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100]):
            raise OSError(f"failed to write visualization to {save_path!r}")
    return vis_im
def _build_mode_luts():
    """Build the 256x1 uint8 ``cv2.LUT`` tables used by display modes 1-4.

    Entry ``i`` is 1 when pixels with parsing label ``i`` are kept and 0 when
    they are replaced by the background. Assumed label names (0 = background,
    TODO confirm against the checkpoint): 1 skin, 2 l_brow, 3 r_brow, 4 l_eye,
    5 r_eye, 6 eye_g, 7 l_ear, 8 r_ear, 9 ear_r, 10 nose, 11 mouth, 12 u_lip,
    13 l_lip, 14 neck, 15 neck_l, 16 cloth, 17 hair, 18 hat.
    """
    def keep(labels):
        table = np.zeros((256, 1), dtype=np.uint8)
        table[list(labels)] = 1
        return table

    all_parts = keep(range(1, 19))
    return {
        1: keep(range(1, 20)),
        # Labels 14-16 (assumed neck / necklace / cloth) are dropped.
        2: keep([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 17, 18]),
        3: all_parts,
        4: all_parts,
    }


def _mode_from_key(key, mode):
    """Return the display mode selected by ``key`` (a ``cv2.waitKey`` result).

    A digit key switches to that mode only if ``evaluate`` can render it
    (0, 1, 2, 3, 4 or 9); any other key, including "no key" (-1), keeps
    ``mode``.
    """
    if ord("0") <= key <= ord("9"):
        requested = key - ord("0")
        if requested in (0, 1, 2, 3, 4, 9):
            return requested
    return mode


def evaluate(cam_id=0, cp='model_final_diss.pth'):
    """Run live face parsing on a camera feed and show it in a "camera" window.

    Keys: ESC quits; digit keys switch the display mode:
      0 - raw camera frame, no inference.
      1, 2 - keep only pixels whose label is enabled in that mode's LUT and
             paint the rest green.
      3 - like mode 1's mask (labels 1-18), plus recolor label 16 with the
          hue of BGR (0, 0, 255).
      4 - labels 1-18 kept, labels 12 and 13 recolored with the hue of
          BGR (255, 0, 0) and label 17 with the hue of BGR (0, 255, 255).
      9 - color-coded parsing map blended over the 512x512 network input.
    Other digits are ignored.

    Args:
        cam_id: device index or source passed to ``cv2.VideoCapture``.
        cp: checkpoint file name, looked up under ``res/cp``.

    Raises:
        RuntimeError: if the camera cannot be opened.
    """
    n_classes = 19
    net = BiSeNet(n_classes=n_classes)
    net.cuda()
    save_pth = osp.join('res/cp', cp)
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    net.load_state_dict(torch.load(save_pth))
    net.eval()
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    luts = _build_mode_luts()

    cam = cv2.VideoCapture(cam_id)
    if not cam.isOpened():
        raise RuntimeError(f"could not open camera {cam_id!r}")

    mode = 1
    # Green background for modes 1-4; built from the real frame shape because
    # CAP_PROP_FRAME_WIDTH/HEIGHT can report 0 or disagree with actual frames.
    img_bg = None
    try:
        with torch.no_grad():
            while True:
                ret, img_cv = cam.read()
                if not ret:
                    # A lost/closed camera keeps failing; stop instead of
                    # spinning forever.
                    break

                if mode == 0:
                    img_draw = img_cv
                else:
                    image = cv2pil(cv2.resize(img_cv, (512, 512)))
                    img = torch.unsqueeze(to_tensor(image), 0).cuda()
                    out = net(img)[0]
                    parsing = out.squeeze(0).cpu().numpy().argmax(0).astype(np.uint8)

                    if mode == 9:
                        img_draw = vis_parsing_maps(image, parsing, stride=1, save_im=False)
                    else:  # modes 1-4
                        # Nearest-neighbour keeps labels as exact integers.
                        resized_parsing = cv2.resize(parsing, img_cv.shape[1::-1],
                                                     interpolation=cv2.INTER_NEAREST)
                        if mode == 3:
                            img_cv = change_color(img_cv, resized_parsing, 16, (0, 0, 255))
                        elif mode == 4:
                            for part, color in ((12, (255, 0, 0)),
                                                (13, (255, 0, 0)),
                                                (17, (0, 255, 255))):
                                img_cv = change_color(img_cv, resized_parsing, part, color)

                        if img_bg is None or img_bg.shape != img_cv.shape:
                            img_bg = np.zeros_like(img_cv)
                            img_bg[:, :, 1] = 255
                        keep_mask = cv2.LUT(resized_parsing, luts[mode])[:, :, None]
                        img_draw = np.where(keep_mask == 1, img_cv, img_bg)

                cv2.imshow("camera", img_draw)
                k = cv2.waitKey(1)
                if k == 27:  # ESC
                    break
                mode = _mode_from_key(k, mode)
    finally:
        cam.release()
        cv2.destroyAllWindows()
if __name__ == "__main__":
    # Default camera (index 0) with the '79999_iter.pth' checkpoint.
    evaluate(cp='79999_iter.pth', cam_id=0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment