Created
November 11, 2021 12:53
-
-
Save crouchggj/9935d5f1e7cca2e8f43d6c68237fd2fb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import cv2 | |
import time | |
import sys | |
import os | |
sys.path.append(os.path.join(os.getcwd(), "samples/python/common/")) | |
sys.path.append(os.path.join(os.getcwd(), "samples/python/common/atlas_utils")) | |
print('System init success.') | |
from atlas_utils.acl_resource import AclResource | |
from constants import * | |
from acl_model import Model | |
from atlas_utils.acl_image import AclImage | |
from atlas_utils.acl_resource import AclResource | |
from atlas_utils.acl_dvpp import Dvpp | |
# Offline models loaded through acl_model.Model in the __main__ section:
# the HiFill inpainting network ...
MODEL_PATH = "./model/hifill.om"
# ... and a standalone BatchMatMul operator used by matmul_om_large for the
# (1024 x 1024) @ (1024 x 27648) attention product (the file name appears to
# encode those shapes).
MODEL_MATMUL_PATH = "./model/matmul/0_BatchMatMul_0_0_1_1_1024_1024_0_0_1_1_1024_27648_0_0_1_1_1024_27648.om"

# Side length, in pixels, of the square low-resolution image fed to the network.
INPUT_SIZE = 512
MODEL_WIDTH = INPUT_SIZE
MODEL_HEIGHT = INPUT_SIZE

# The image is divided into an ATTENTION_SIZE x ATTENTION_SIZE grid of patches;
# the attention matrix is computed over those patches.
ATTENTION_SIZE = 32
# Scale factor of the high-resolution working image relative to INPUT_SIZE
# (MULTIPLE * INPUT_SIZE = 6 * 512 = 3072 pixels per side).
MULTIPLE = 6

# The model inputs are converted to this dtype.
NPTYPE_FLOAT32 = np.float32
# `extract_image_patches` (used by `resize_ave` below) is defined further down
# in this module, next to `matmul_om_large`. Python resolves the name when
# `resize_ave` is called, not when it is defined, so a single definition suffices.
def resize_ave(img, MULTIPLE):
    """Down-sample an image by averaging non-overlapping MULTIPLE x MULTIPLE blocks.

    The parameter is named ``MULTIPLE`` (shadowing the module constant of the
    same name) and is kept that way so keyword callers keep working.

    Args:
        img: (H, W, C) image; H and W must be divisible by ``MULTIPLE``
            (otherwise the reshape in ``extract_image_patches`` fails).
        MULTIPLE: side length of each averaged block, in pixels.

    Returns:
        float32 array of shape (H // MULTIPLE, W // MULTIPLE, C).
    """
    blocks = extract_image_patches(img.astype(NPTYPE_FLOAT32), MULTIPLE)
    # blocks has shape (H/m, W/m, m, m, C): average over the two in-block axes.
    return blocks.mean(axis=(2, 3))
def pre_process(raw_img, raw_mask):
    """Build the high-resolution and network-resolution inputs from a raw image/mask.

    Args:
        raw_img: original image (read with cv2.imread in __main__).
        raw_mask: mask image; it is scaled to [0, 1] by dividing by 255.

    Returns:
        Tuple ``(large_img, large_mask, small_img, small_mask)``:
          * large_img, large_mask -- image and scaled mask resized to
            MULTIPLE * INPUT_SIZE per side (3072 x 3072), float32;
          * small_img -- large_img block-averaged down to INPUT_SIZE
            (512 x 512), made C-contiguous;
          * small_mask -- scaled mask resized to 512 x 512 and inverted
            (1 - mask) so that the hole region is 1 and the background 0.
    """
    mask_unit = raw_mask.astype(NPTYPE_FLOAT32) / 255.
    img_f32 = raw_img.astype(NPTYPE_FLOAT32)
    large_side = MULTIPLE * INPUT_SIZE

    # Step 1: resize image and mask to the working resolution (6 * 512 = 3072).
    large_img = cv2.resize(img_f32, (large_side, large_side),
                           interpolation=cv2.INTER_LINEAR)
    large_mask = cv2.resize(mask_unit, (large_side, large_side),
                            interpolation=cv2.INTER_NEAREST)

    # Step 2: bring both down to the 512 x 512 network input. The image is
    # block-averaged from the large version; the mask is resized directly
    # from the raw mask and inverted (hole -> 1, background -> 0).
    small_img = np.ascontiguousarray(resize_ave(large_img, MULTIPLE))
    small_mask = 1. - cv2.resize(mask_unit, (INPUT_SIZE, INPUT_SIZE),
                                 interpolation=cv2.INTER_NEAREST)

    return large_img, large_mask, small_img, small_mask
def extract_image_patches(img, multiple):
    """Split an image into a grid of non-overlapping square patches.

    Used in two places in this file:
      * ``resize_ave`` calls it with ``multiple`` = MULTIPLE (6) to average
        6 x 6 blocks;
      * ``residual_aggregate`` calls it with ``multiple`` = 96 to cut the
        3072 x 3072 residual into a 32 x 32 grid of 96 x 96 patches. Each
        patch then holds 96 * 96 * 3 = 27648 values and is later flattened
        into one row of a 1024 x 27648 matrix.

    Args:
        img: array of shape (H, W, C).
        multiple: patch side length in pixels; must divide both H and W.

    Returns:
        Array of shape (H // multiple, W // multiple, multiple, multiple, C),
        where ``out[i, j]`` equals
        ``img[i*multiple:(i+1)*multiple, j*multiple:(j+1)*multiple]``.

    Raises:
        ValueError: if ``img`` is not 3-D, ``multiple`` is not positive, or
            H / W is not divisible by ``multiple``.
    """
    h, w, c = img.shape  # raises ValueError for non-3-D input
    if multiple <= 0:
        raise ValueError(f"multiple must be positive, got {multiple}")
    if h % multiple or w % multiple:
        # np.reshape would also fail here, but with an opaque size message.
        raise ValueError(
            f"image size {h}x{w} is not divisible by patch size {multiple}"
        )
    patches = np.reshape(img, [h // multiple, multiple, w // multiple, multiple, c])
    # (rows, m, cols, m, C) -> (rows, cols, m, m, C): group each patch's pixels.
    return np.transpose(patches, [0, 2, 1, 3, 4])
def matmul_om_large(attention, residual, model=None, grid_size=None):
    """Multiply the patch attention matrix with the flattened residual patches.

    The attention matrix (1024 x 1024 in this script) is multiplied with the
    residual patch matrix (1024 x 27648) using the offline BatchMatMul model.

    Args:
        attention: attention values with (grid_size**2)**2 elements; reshaped
            to (num_patches, num_patches).
        residual: flattened patches; in this script shape (1, 1024, 96*96*3).
            Reshaped to (num_patches, patch_len).
        model: object whose ``execute([a, b])`` returns a list whose first
            element is the product ``a @ b``. Defaults to the module-level
            ``matmul_model`` created in ``__main__`` (the original behaviour).
        grid_size: patches per grid side. Defaults to ATTENTION_SIZE.

    Returns:
        Array of shape (grid_size, grid_size, patch_len).

    NOTE(review): the .om operator at MODEL_MATMUL_PATH is presumably compiled
    for exactly 1024 x 1024 @ 1024 x 27648 (its file name lists those sizes);
    other grid/patch sizes likely need a different model. Please confirm.
    """
    if model is None:
        model = matmul_model  # global defined in the __main__ section
    if grid_size is None:
        grid_size = ATTENTION_SIZE
    num_patches = grid_size * grid_size
    attention_2d = attention.reshape(num_patches, num_patches)
    # One row per patch; the row length (96*96*3 = 27648 here) is derived.
    residual_2d = residual.reshape(num_patches, -1)
    product = model.execute([attention_2d, residual_2d])[0]
    return product.reshape(grid_size, grid_size, -1)
def reconstruct_residual_from_patches(residual, multiple, grid_size=None):
    """Stitch a square grid of flattened square patches back into one image.

    This inverts ``extract_image_patches`` followed by per-patch flattening.

    Args:
        residual: array holding grid_size * grid_size patches of
            multiple x multiple pixels (e.g. shape (32, 32, 96*96*3)).
        multiple: patch side length in pixels.
        grid_size: patches per grid side. Defaults to ATTENTION_SIZE.

    Returns:
        Array of shape (grid_size*multiple, grid_size*multiple, C). The channel
        count C is inferred (3 for the residual in this script) instead of
        being hard-coded.
    """
    if grid_size is None:
        grid_size = ATTENTION_SIZE
    patches = np.reshape(residual, [grid_size, grid_size, multiple, multiple, -1])
    # (rows, cols, m, m, C) -> (rows, m, cols, m, C) so each image row is contiguous.
    patches = np.transpose(patches, [0, 2, 1, 3, 4])
    side = grid_size * multiple
    return np.reshape(patches, [side, side, -1])
def residual_aggregate(residual, attention):
    """Mix high-resolution residual patches with the patch attention matrix.

    The residual is cut into an ATTENTION_SIZE x ATTENTION_SIZE grid of square
    patches, and each patch is flattened into one row. The rows are multiplied
    by the attention matrix (``matmul_om_large``), which fills every patch,
    including those inside the mask, with a weighted mix of the others. The
    result is stitched back into an image.

    Args:
        residual: high-resolution residual, (3072, 3072, 3) in this script
            (MULTIPLE * INPUT_SIZE per side).
        attention: attention values reshaped to 1024 x 1024 by ``matmul_om_large``.

    Returns:
        Aggregated residual of shape
        (ATTENTION_SIZE * patch, ATTENTION_SIZE * patch, 3), i.e. 3072 x 3072 x 3.
    """
    # Patch side length: 6 * 512 // 32 = 96 pixels.
    patch_side = MULTIPLE * INPUT_SIZE // ATTENTION_SIZE
    grid = extract_image_patches(residual, patch_side)  # (32, 32, 96, 96, 3)
    rows, cols = grid.shape[0], grid.shape[1]
    # One wide row per patch: (1, 1024, 96*96*3).
    flat_patches = np.reshape(grid, [1, rows * cols, -1])
    mixed = matmul_om_large(attention, flat_patches)
    return reconstruct_residual_from_patches(mixed, patch_side)
if __name__ == "__main__":
    # Initialise the ACL runtime before loading the offline models.
    acl_resource = AclResource()
    acl_resource.init()
    print("Init Resource Done")

    model = Model(MODEL_PATH)
    # Must stay a module-level name: matmul_om_large() reads `matmul_model`
    # as a global.
    matmul_model = Model(MODEL_MATMUL_PATH)

    image_path = './image/15.jpeg'
    mask_path = './image/15_mask.jpeg'
    raw_img = cv2.imread(image_path)
    raw_mask = cv2.imread(mask_path)
    # cv2.imread returns None (instead of raising) when a file cannot be read.
    if raw_img is None:
        raise FileNotFoundError(f"cannot read image: {image_path}")
    if raw_mask is None:
        raise FileNotFoundError(f"cannot read mask: {mask_path}")
    print(type(raw_img))
    print(raw_img.shape)
    print(type(raw_mask))
    print(raw_mask.shape)

    img_large, mask_large, img_512, mask_512 = pre_process(raw_img, raw_mask)
    print("Preprocess success! ")

    # Keep a single mask channel and move it to the front:
    # (512, 512, 1) -> (1, 512, 512), i.e. CHW.
    # NOTE(review): the image is passed in HWC layout while the mask is CHW;
    # presumably this matches the .om model's input spec. Please confirm.
    mask_512_chw = mask_512[:, :, 0:1].transpose(2, 0, 1).copy()
    print(mask_512_chw.shape)
    resultList = model.execute([img_512, mask_512_chw])
    inpainted_512 = resultList[0]
    attention = resultList[1]
    mask_512_new = resultList[2]
    print("Model execute finish !")

    h, w = raw_img.shape[:2]
    large_size = (INPUT_SIZE * MULTIPLE, INPUT_SIZE * MULTIPLE)
    low_base = cv2.resize(inpainted_512[0].astype(NPTYPE_FLOAT32),
                          large_size, interpolation=cv2.INTER_LINEAR)
    low_large = cv2.resize(img_512.astype(NPTYPE_FLOAT32),
                           large_size, interpolation=cv2.INTER_LINEAR)
    # Residual = high-resolution image minus the up-sampled 512x512 image
    # (detail the low-resolution pass cannot represent), kept where mask_large
    # is non-zero.
    residual = (img_large - low_large) * mask_large
    residual = residual_aggregate(residual, attention[0])
    res_large = np.clip(low_base + residual, 0., 255.)
    res_raw = cv2.resize(res_large, (w, h), interpolation=cv2.INTER_LINEAR)

    mask = cv2.resize(mask_512_new[0].astype(NPTYPE_FLOAT32), (w, h),
                      interpolation=cv2.INTER_LINEAR)
    mask = np.expand_dims(mask, axis=2)
    # Blend: use the inpainted result where mask is 1, the original image where it is 0.
    res_raw = res_raw * mask + raw_img * (1. - mask)
    # Nothing here guarantees the model mask lies in [0, 1]. Clip so that
    # out-of-range values saturate instead of wrapping around in the uint8 cast.
    res_raw = np.clip(res_raw, 0., 255.).astype(np.uint8)

    out_path = './out.jpg'
    if not cv2.imwrite(out_path, res_raw):
        raise OSError(f"failed to write result to {out_path}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment