Created
May 17, 2020 05:39
-
-
Save kounoike/c5a1ebb554abbe41960d5c06046374c7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
# -*- encoding: utf-8 -*- | |
from logger import setup_logger | |
from model import BiSeNet | |
import torch | |
import os | |
import os.path as osp | |
import numpy as np | |
from PIL import Image | |
import torchvision.transforms as transforms | |
import cv2 | |
from facenet_pytorch import MTCNN | |
def cv2pil(image):
    """Convert an OpenCV-style array into a PIL image.

    A copy of ``image`` is taken first, so the input array is never
    modified. Two-dimensional (monochrome) arrays are passed through
    unchanged; three-channel arrays are converted BGR -> RGB and
    four-channel (with alpha) arrays BGRA -> RGBA. Any other channel
    count is handed to PIL as-is.
    """
    # Colour-conversion code to apply for each supported channel count.
    conversion_by_channels = {
        3: cv2.COLOR_BGR2RGB,    # colour
        4: cv2.COLOR_BGRA2RGBA,  # colour with transparency
    }

    pixels = image.copy()
    if pixels.ndim != 2:  # monochrome input needs no channel reordering
        code = conversion_by_channels.get(pixels.shape[2])
        if code is not None:
            pixels = cv2.cvtColor(pixels, code)
    return Image.fromarray(pixels)
def change_color(image, parsing, part, color=(230, 50, 20)):
    """Replace the hue of one parsed region of a BGR image.

    The hue channel of ``image`` (in OpenCV's HSV representation) is set
    to the hue of ``color`` while saturation and value are kept. Pixels
    whose label in ``parsing`` differs from ``part`` are then restored
    from the original image, so only the selected region changes.

    Args:
        image: BGR image array as accepted by ``cv2.cvtColor``.
        parsing: Label map that can index ``image`` as a boolean mask
            after comparison with ``part`` (i.e. same height and width).
        part: Label value of the region to recolor.
        color: Target colour as a ``(b, g, r)`` triple.

    Returns:
        A new BGR image; ``image`` itself is not modified.
    """
    b, g, r = color

    # Only the hue of the target colour is needed, so convert a single
    # pixel instead of a full-frame solid-colour image.
    swatch = np.array([[[b, g, r]]], dtype=image.dtype)
    target_hue = cv2.cvtColor(swatch, cv2.COLOR_BGR2HSV)[0, 0, 0]

    image_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    image_hsv[:, :, 0] = target_hue
    changed = cv2.cvtColor(image_hsv, cv2.COLOR_HSV2BGR)

    # Keep the original pixels everywhere outside the requested part.
    outside_part = parsing != part
    changed[outside_part] = image[outside_part]
    return changed
def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'):
    """Blend a colour-coded parsing map over an RGB image.

    Args:
        im: RGB image (anything ``np.array`` accepts, e.g. a PIL image).
        parsing_anno: 2-D array of integer class labels.
        stride: Scale factor applied to ``parsing_anno`` (nearest
            neighbour) before blending; the scaled map must match the
            size of ``im`` for the blend to succeed.
        save_im: When true, write the results to disk.
        save_path: Destination of the blended image. The raw label map is
            written next to it with the extension replaced by ``.png``.

    Returns:
        The blended BGR ``uint8`` image (40% image, 60% colour map).

    Raises:
        IndexError: If a label is not covered by the colour table.
    """
    # Colour table indexed by label. Label 0 is drawn white (see below),
    # so entry 0 of this list is never used as-is.
    part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
                   [255, 0, 85], [255, 0, 170],
                   [0, 255, 0], [85, 255, 0], [170, 255, 0],
                   [0, 255, 85], [0, 255, 170],
                   [0, 0, 255], [85, 0, 255], [170, 0, 255],
                   [0, 85, 255], [0, 170, 255],
                   [255, 255, 0], [255, 255, 85], [255, 255, 170],
                   [255, 0, 255], [255, 85, 255], [255, 170, 255],
                   [0, 255, 255], [85, 255, 255], [170, 255, 255]]

    vis_im = np.array(im).astype(np.uint8)
    vis_parsing_anno = np.asarray(parsing_anno).astype(np.uint8)
    vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride,
                                  interpolation=cv2.INTER_NEAREST)

    # One vectorised palette lookup instead of a Python loop per class.
    # Label 0 maps to white; an out-of-range label raises IndexError.
    palette = np.array(part_colors, dtype=np.uint8)
    palette[0] = 255
    vis_parsing_anno_color = palette[vis_parsing_anno]

    vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4,
                             vis_parsing_anno_color, 0.6, 0)

    # Save result or not
    if save_im:
        # splitext handles extensions of any length, unlike slicing off
        # the last four characters.
        label_path = osp.splitext(save_path)[0] + '.png'
        cv2.imwrite(label_path, vis_parsing_anno)
        cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100])

    return vis_im
_ESC_KEY = 27
# Display modes that evaluate() knows how to render.
_SUPPORTED_MODES = (0, 1, 2, 3, 4, 9)
# mode -> sequence of (label, BGR colour) recolourings applied to the frame.
_RECOLOR_BY_MODE = {
    3: ((16, [0, 0, 255]),),
    4: ((12, [255, 0, 0]), (13, [255, 0, 0]), (17, [0, 255, 255])),
}


def _make_label_lut(labels):
    """Return a 256x1 uint8 lookup table that is 1 for ``labels``, else 0."""
    lut = np.zeros((256, 1), dtype=np.uint8)
    lut[list(labels)] = 1
    return lut


def _build_mode_luts():
    """Return the per-mode keep-masks (as lookup tables) for modes 1-4."""
    # Label names, in order, as listed by the original author
    # (presumably label 1 is 'skin', 2 is 'l_brow', ... - TODO confirm):
    # atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
    #         'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']
    keep_1_to_18 = _make_lut_alias = _make_label_lut(range(1, 19))
    return {
        1: _make_label_lut(range(1, 20)),
        2: _make_label_lut([*range(1, 14), 17, 18]),
        3: keep_1_to_18,
        4: keep_1_to_18,  # mode 4 shares the mask of mode 3
    }


def _next_mode(key, mode):
    """Return the mode selected by digit ``key``, or ``mode`` if not applicable."""
    if ord("0") <= key <= ord("9"):
        requested = key - ord("0")
        # Ignore digits with no renderer instead of switching to a mode
        # that would leave the displayed frame stale or undefined.
        if requested in _SUPPORTED_MODES:
            return requested
    return mode


def evaluate(cam_id=0, cp='model_final_diss.pth'):
    """Run the interactive webcam face-parsing demo.

    Frames are read from camera ``cam_id``, parsed by a 19-class BiSeNet
    loaded from ``res/cp/<cp>``, and shown in a window named "camera".
    Digit keys switch the display mode and Esc quits:

    * ``0`` - raw camera frame (no inference).
    * ``1`` - keep pixels with labels 1-19, everything else is replaced by
      a solid background whose channel 1 is 255 (green in BGR order).
    * ``2`` - as ``1`` but keeping only labels 1-13, 17 and 18.
    * ``3`` - keep labels 1-18 and recolor label 16 towards ``[0, 0, 255]``.
    * ``4`` - keep labels 1-18 and recolor labels 12 and 13 towards
      ``[255, 0, 0]`` and label 17 towards ``[0, 255, 255]``.
    * ``9`` - colour-coded parsing overlay from ``vis_parsing_maps``.

    Other digits are ignored. The model runs on CUDA when available and
    on the CPU otherwise.

    Args:
        cam_id: Device index (or source) passed to ``cv2.VideoCapture``.
        cp: Checkpoint file name inside ``res/cp``.

    Raises:
        RuntimeError: If the camera cannot be opened.
    """
    n_classes = 19
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net = BiSeNet(n_classes=n_classes)
    net.to(device)
    save_pth = osp.join('res/cp', cp)
    # map_location lets a checkpoint saved on a GPU load on CPU-only hosts.
    net.load_state_dict(torch.load(save_pth, map_location=device))
    net.eval()

    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    luts = _build_mode_luts()
    mode = 1
    img_bg = None  # built lazily so it always matches the frame shape

    cam = cv2.VideoCapture(cam_id)
    if not cam.isOpened():
        cam.release()
        # Without this check the read loop below would spin forever.
        raise RuntimeError('Cannot open camera %r' % (cam_id,))

    try:
        with torch.no_grad():
            while True:
                ret, img_cv = cam.read()
                if not ret:
                    # Best effort: skip frames that failed to be grabbed.
                    continue

                if mode == 0:
                    img_draw = img_cv
                else:
                    image = cv2pil(cv2.resize(img_cv, (512, 512)))
                    img = torch.unsqueeze(to_tensor(image), 0).to(device)
                    out = net(img)[0]
                    # Reduce to a label map on the device first so only one
                    # plane, not all class scores, is copied to the host.
                    parsing = out.squeeze(0).argmax(0).cpu().numpy().astype(np.uint8)

                    if mode == 9:
                        img_draw = vis_parsing_maps(image, parsing, stride=1, save_im=False)
                    else:
                        resized_parsing = cv2.resize(parsing, img_cv.shape[1::-1],
                                                     interpolation=cv2.INTER_NEAREST)
                        for part, color in _RECOLOR_BY_MODE.get(mode, ()):
                            img_cv = change_color(img_cv, resized_parsing, part, color)

                        if img_bg is None or img_bg.shape != img_cv.shape:
                            img_bg = np.zeros_like(img_cv)
                            img_bg[:, :, 1] = 255

                        p_lut = cv2.LUT(resized_parsing, luts[mode])[:, :, None]
                        img_draw = np.where(p_lut == 1, img_cv, img_bg)

                cv2.imshow("camera", img_draw)
                k = cv2.waitKey(1)
                if k == _ESC_KEY:
                    break
                mode = _next_mode(k, mode)
    finally:
        # Always give the camera and the window back, even on errors.
        cam.release()
        cv2.destroyAllWindows()
if __name__ == "__main__":
    # Script entry point: run the demo on camera 0 with the
    # '79999_iter.pth' checkpoint.
    evaluate(0, '79999_iter.pth')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment