import numpy as np
import cv2
import time
import sys
import os
sys.path.append(os.path.join(os.getcwd(), "samples/python/common/"))
sys.path.append(os.path.join(os.getcwd(), "samples/python/common/atlas_utils"))
print('System init success.')
from atlas_utils.acl_resource import AclResource
from constants import *
from acl_model import Model
from atlas_utils.acl_image import AclImage
from atlas_utils.acl_dvpp import Dvpp
MODEL_PATH = "./model/hifill.om"
MODEL_MATMUL_PATH = "./model/matmul/0_BatchMatMul_0_0_1_1_1024_1024_0_0_1_1_1024_27648_0_0_1_1_1024_27648.om"
# Input image size expected by the model
MODEL_WIDTH = 512
MODEL_HEIGHT = 512
INPUT_SIZE = 512
# The image is divided into a 32 x 32 grid, which is used to compute the attention matrix
ATTENTION_SIZE = 32
MULTIPLE = 6
# The model input must be float32 data
NPTYPE_FLOAT32 = np.float32
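# Pipeline overview (derived from the main block below):
#   1. pre_process: resize the raw image/mask to 3072x3072 (MULTIPLE * INPUT_SIZE)
#      and down-sample to 512x512 for the HiFill network input.
#   2. hifill.om: coarse 512x512 inpainting plus a 1024x1024 attention matrix.
#   3. BatchMatMul .om + residual_aggregate: transfer the masked residual of the
#      3072x3072 image into the hole region using the attention matrix.
#   4. Blend the result back to the original resolution and save it.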
# NOTE: this function is redefined below with a docstring; the later definition
# is the one in effect at runtime (both are identical).
def extract_image_patches(img, multiple):
    # img shape: (6*512, 6*512, 3)
    h, w, c = img.shape
    # reshape to (512, 6, 512, 6, 3)
    img = np.reshape(img, [h // multiple, multiple, w // multiple, multiple, c])
    # transpose to (512, 512, 6, 6, 3)
    img = np.transpose(img, [0, 2, 1, 3, 4])
    return img
def resize_ave(img, multiple):
    # down-sample by averaging each multiple x multiple block of pixels
    img = img.astype(NPTYPE_FLOAT32)
    img_patches = extract_image_patches(img, multiple)
    # result shape: (512, 512, 3)
    img = np.mean(img_patches, axis=(2, 3))
    return img
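# Note (added for clarity): since the down-sampling factor is an exact integer,
# this block averaging should match what cv2.resize(img, (INPUT_SIZE, INPUT_SIZE),
# interpolation=cv2.INTER_AREA) produces; the explicit patch form just keeps the
# intermediate (512, 512, 6, 6, 3) layout visible.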
def pre_process(raw_img, raw_mask):
    raw_mask = raw_mask.astype(NPTYPE_FLOAT32) / 255.
    raw_img = raw_img.astype(NPTYPE_FLOAT32)
    # 1. resize the raw image & mask to the designated size (6*512, 6*512)
    large_img = cv2.resize(raw_img, (MULTIPLE * INPUT_SIZE, MULTIPLE * INPUT_SIZE), interpolation=cv2.INTER_LINEAR)
    large_mask = cv2.resize(raw_mask, (MULTIPLE * INPUT_SIZE, MULTIPLE * INPUT_SIZE), interpolation=cv2.INTER_NEAREST)
    # 2. down-sample the large image & mask to 512x512
    small_img = resize_ave(large_img, MULTIPLE)
    small_mask = cv2.resize(raw_mask, (INPUT_SIZE, INPUT_SIZE), interpolation=cv2.INTER_NEAREST)
    # set the hole region to 1. and the background to 0.
    small_mask = 1. - small_mask
    small_img = np.ascontiguousarray(small_img)
    # 3. return the large image & mask (3072x3072) and the small image & mask (512x512)
    return large_img, large_mask, small_img, small_mask
def extract_image_patches(img, multiple):
    '''
    Cut the detail image into equal-sized patches:
    1. split the detail image into a 32 x 32 grid of patches;
    2. each patch is 96 x 96 pixels (3072 / 32 = 96);
    3. since the detail image has 3 channels, each patch holds 96 x 96 x 3 = 27648 values;
    4. the 32 x 32 patches are laid out in order, giving 1024 patches in total;
    5. flattening each patch into a row of 27648 values yields a 1024 x 27648 matrix.
    '''
    h, w, c = img.shape
    img = np.reshape(img, [h // multiple, multiple, w // multiple, multiple, c])
    img = np.transpose(img, [0, 2, 1, 3, 4])
    return img
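# Shape walk-through (for the call in residual_aggregate, where multiple = 96):
#   (3072, 3072, 3) --reshape--> (32, 96, 32, 96, 3) --transpose--> (32, 32, 96, 96, 3)
# i.e. a 32 x 32 grid of 96 x 96 x 3 patches, later flattened to (1, 1024, 27648).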
def matmul_om_large(attention, residual):
    # multiply the attention matrix (1024 x 1024) by the reshaped detail matrix (1024 x 27648);
    # matmul_model is the BatchMatMul .om model loaded in the main block
    attention_reshape = attention.reshape(1024, 1024)
    residual_reshape = residual.reshape(1024, 96 * 96 * 3)
    matmul_ret = matmul_model.execute([attention_reshape, residual_reshape])
    return matmul_ret[0].reshape(ATTENTION_SIZE, ATTENTION_SIZE, 3072 * 9)
def reconstruct_residual_from_patches(residual, multiple):
    # inverse of extract_image_patches: (32, 32, 96, 96, 3) -> (3072, 3072, 3)
    residual = np.reshape(residual, [ATTENTION_SIZE, ATTENTION_SIZE, multiple, multiple, 3])
    residual = np.transpose(residual, [0, 2, 1, 3, 4])
    return np.reshape(residual, [ATTENTION_SIZE * multiple, ATTENTION_SIZE * multiple, 3])
def residual_aggregate(residual, attention):
    # residual: (3072, 3072, 3); patch size MULTIPLE * INPUT_SIZE // ATTENTION_SIZE = 96
    # split the large image into a 32 x 32 grid of 96 x 96 patches -> (32, 32, 96, 96, 3)
    residual = extract_image_patches(residual, MULTIPLE * INPUT_SIZE // ATTENTION_SIZE)
    # flatten each patch -> (1, 1024, 96*96*3): one wide row per patch
    residual = np.reshape(residual, [1, residual.shape[0] * residual.shape[1], -1])
    # multiply the attention matrix by the residual to obtain the repaired residual
    # (in particular, every pixel inside the mask is filled from the background patches)
    residual = matmul_om_large(attention, residual)
    residual = reconstruct_residual_from_patches(residual, MULTIPLE * INPUT_SIZE // ATTENTION_SIZE)
    return residual
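# For reference only (not part of the original sample): the BatchMatMul .om model
# invoked in matmul_om_large appears to compute a plain matrix product, so the same
# step can be sketched on the host in NumPy as below. The function name is
# illustrative and is never called by this script.
def matmul_numpy_reference(attention, residual):
    # attention: any array reshapeable to (1024, 1024);
    # residual: the (1, 1024, 27648) tensor built in residual_aggregate
    attention_2d = attention.reshape(1024, 1024)
    residual_2d = residual.reshape(1024, 96 * 96 * 3)
    out = attention_2d @ residual_2d  # (1024, 27648)
    return out.reshape(ATTENTION_SIZE, ATTENTION_SIZE, 96 * 96 * 3)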
if __name__ == "__main__":
acl_resource = AclResource()
acl_resource.init()
print("Init Resource Done")
model = Model(MODEL_PATH)
matmul_model = Model(MODEL_MATMUL_PATH)
raw_img = cv2.imread('./image/15.jpeg')
raw_mask = cv2.imread('./image/15_mask.jpeg')
print(type(raw_img))
print(raw_img.shape)
print(type(raw_mask))
print(raw_mask.shape)
img_large, mask_large, img_512, mask_512 = pre_process(raw_img, raw_mask)
print("Preprocess success! ")
mask_512_hwc = mask_512[:,:,0:1]
mask_512_hwc = mask_512_hwc.transpose(2,0,1).copy()
print(mask_512_hwc.shape)
resultList = model.execute([img_512, mask_512_hwc])
inpainted_512 = resultList[0]
attention = resultList[1]
mask_512_new = resultList[2]
print("Model execute finish !")
h, w, c = raw_img.shape
low_base = cv2.resize(inpainted_512[0].astype(NPTYPE_FLOAT32),
(INPUT_SIZE * MULTIPLE, INPUT_SIZE * MULTIPLE), interpolation = cv2.INTER_LINEAR)
low_large = cv2.resize(img_512.astype(NPTYPE_FLOAT32),
(INPUT_SIZE * MULTIPLE, INPUT_SIZE * MULTIPLE), interpolation = cv2.INTER_LINEAR)
residual = (img_large - low_large) * mask_large
residual = residual_aggregate(residual, attention[0])
res_large = low_base + residual
res_large = np.clip(res_large, 0., 255.)
res_raw = cv2.resize(res_large, (w, h), interpolation = cv2.INTER_LINEAR)
mask = cv2.resize(mask_512_new[0].astype(NPTYPE_FLOAT32), (w, h), interpolation = cv2.INTER_LINEAR)
mask = np.expand_dims(mask, axis=2)
mask_img = (res_raw * mask).astype(np.uint8)
res_raw = res_raw * mask + raw_img * (1. - mask)
res_raw = res_raw.astype(np.uint8)
cv2.imwrite('./out.jpg',res_raw)
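# Presumably run from the sample root (e.g. `python3 <this script>`, actual filename
# unknown), with the two .om models under ./model and the test image/mask under
# ./image, matching the hard-coded paths above; the result is written to ./out.jpg.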