Skip to content

Instantly share code, notes, and snippets.

@Timmate
Created August 7, 2020 12:58
Show Gist options
  • Save Timmate/cc089034dac3f9fbbdcbb889aa2aaa0d to your computer and use it in GitHub Desktop.
Save Timmate/cc089034dac3f9fbbdcbb889aa2aaa0d to your computer and use it in GitHub Desktop.
import tensorflow as tf
import numpy as np
import PIL
from PIL import Image
import os
from scipy.io import loadmat, savemat
from preprocess_img import Preprocess
from load_data import *
from face_decoder import Face3D
import warnings
warnings.filterwarnings('ignore')
def rescale_mask(scaled_mask: np.ndarray, transform_params: list) -> "Image.Image":
    """
    Uncrops and rescales (i.e., resizes) the given scaled and cropped mask back to the
    resolution of the original image using the given transformation parameters.

    The crop window is recomputed the same way the preprocessing step computed it. That
    window may extend beyond the borders of the downscaled image (e.g., when the original
    image is about as small as the crop itself, or the face is close to an edge). In that
    case only the part of the mask that overlaps the downscaled image is kept; the
    remaining area of the result is black.

    Args:
        scaled_mask: The cropped mask as an array of shape (height, width[, channels]),
            e.g. a 224x224x3 `uint8` array.
        transform_params: Sequence `[original_width, original_height, scale, t_x, t_y]`.
            Each entry may be a Python number, a NumPy scalar or a one-element array.

    Returns:
        A PIL image of size `(original_width, original_height)`.

    Raises:
        ValueError: If fewer than five parameters are given, or if the downscaled image
            would have a non-positive width or height.
    """
    # Parse transform params. `.item()` unwraps NumPy scalars and one-element arrays.
    params = [np.asarray(param).item() for param in transform_params]
    if len(params) < 5:
        raise ValueError('transform_params must contain at least 5 values, got %d' % len(params))
    original_image_width = int(params[0])
    original_image_height = int(params[1])
    # The stored scaling parameter is 102 / s; invert it to get `s` back.
    s = (float(params[2]) / 102.0) ** -1
    t = [float(param) for param in params[3:5]]

    # Repeat the computations for downscaling from preprocess_img.py/process_img() to get
    # the parameters needed for uncropping and rescaling the mask.
    # Width and height of the original image after downscaling (truncated like `astype(int32)`).
    scaled_image_width = int(original_image_width / s * 102)
    scaled_image_height = int(original_image_height / s * 102)
    if scaled_image_width <= 0 or scaled_image_height <= 0:
        raise ValueError('downscaled image size must be positive, got %dx%d'
                         % (scaled_image_width, scaled_image_height))

    scaled_mask = np.asarray(scaled_mask)
    mask_height, mask_width = scaled_mask.shape[:2]

    # Top-left corner of the crop window inside the downscaled image. It can be negative,
    # and the window can reach past the right/bottom border.
    left_side_x = int(scaled_image_width / 2 - mask_width / 2
                      + (t[0] - original_image_width / 2) * 102 / s)
    upper_side_y = int(scaled_image_height / 2 - mask_height / 2
                       + (original_image_height / 2 - t[1]) * 102 / s)

    # Clip the crop window to the downscaled image.
    paste_left = max(left_side_x, 0)
    paste_top = max(upper_side_y, 0)
    paste_right = min(left_side_x + mask_width, scaled_image_width)
    paste_bottom = min(upper_side_y + mask_height, scaled_image_height)

    # Uncrop the mask: start from an all-black image of the downscaled size and copy in
    # the part of the mask that falls inside it.
    uncropped_mask = np.zeros(
        shape=(scaled_image_height, scaled_image_width) + scaled_mask.shape[2:],
        dtype=scaled_mask.dtype,
    )
    if paste_right > paste_left and paste_bottom > paste_top:
        uncropped_mask[paste_top:paste_bottom, paste_left:paste_right] = scaled_mask[
            paste_top - upper_side_y:paste_bottom - upper_side_y,
            paste_left - left_side_x:paste_right - left_side_x,
        ]

    # Rescale (i.e., resize) the uncropped mask back to the resolution of the original image.
    uncropped_and_rescaled_mask = Image.fromarray(uncropped_mask).resize(
        (original_image_width, original_image_height)
    )
    return uncropped_and_rescaled_mask
def load_graph(graph_filename):
    """
    Reads a serialized `GraphDef` protobuf from the given file.

    Args:
        graph_filename: Path of the binary graph file to read.

    Returns:
        The parsed `tf.GraphDef`.
    """
    parsed_graph = tf.GraphDef()
    with tf.gfile.GFile(graph_filename, 'rb') as graph_file:
        serialized_graph = graph_file.read()
    parsed_graph.ParseFromString(serialized_graph)
    return parsed_graph
def demo():
    """
    Reconstructs a face for every image in `input/` and saves the rendered reconstruction,
    uncropped and rescaled to the resolution of the original image, to
    `output/rescaled_masks/` under the same file name.

    Every image `<name>.<ext>` (jpg, jpeg or png, case-insensitive) needs a landmarks file
    `<name>.txt` next to it; images without one are reported and skipped.
    """
    INPUT_DIR = 'input'
    OUTPUT_DIR = 'output'
    RESCALED_MASKS_DIR = os.path.join(OUTPUT_DIR, 'rescaled_masks')
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    # Create the directories if they do not exist yet.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(RESCALED_MASKS_DIR, exist_ok=True)

    # read BFM face model
    # transfer original BFM model to our model
    if not os.path.isfile('./BFM/BFM_model_front.mat'):
        transferBFM09()

    # read standard landmarks for preprocessing images
    lm3D = load_lm3d()
    batchsize = 1

    # build reconstruction model
    with tf.Graph().as_default() as graph, tf.device('/cpu:0'):
        FaceReconstructor = Face3D()
        images = tf.placeholder(name='input_imgs', shape=[batchsize, 224, 224, 3], dtype=tf.float32)
        graph_def = load_graph('network/FaceReconModel.pb')
        tf.import_graph_def(graph_def, name='resnet', input_map={'input_imgs:0': images})

        # output coefficients of R-Net (dim = 257)
        coeff = graph.get_tensor_by_name('resnet/coeff:0')

        # reconstructing faces
        FaceReconstructor.Reconstruction_Block(coeff, batchsize)
        face_shape = FaceReconstructor.face_shape_t
        face_texture = FaceReconstructor.face_texture
        face_color = FaceReconstructor.face_color
        landmarks_2d = FaceReconstructor.landmark_p
        recon_img = FaceReconstructor.render_imgs
        tri = FaceReconstructor.facemodel.face_buf

        with tf.Session() as sess:
            # Print some newlines to make the output more visible among warnings.
            print('\n' * 3)

            # Get the list of all files and keep only image files. `splitext` looks at the
            # last suffix only, so names with several dots (or none) are handled correctly.
            image_filenames = [
                filename for filename in sorted(os.listdir(INPUT_DIR))
                if os.path.splitext(filename)[1].lower() in IMAGE_EXTENSIONS
            ]

            for image_filename in image_filenames:
                # load images and corresponding 5 facial landmarks
                image_path = os.path.join(INPUT_DIR, image_filename)
                # Swap only the final extension for `.txt`; a plain string replace would
                # also rewrite the same text elsewhere in the path.
                landmarks_path = os.path.splitext(image_path)[0] + '.txt'
                if not os.path.isfile(landmarks_path):
                    print('skipping', image_filename, '- landmarks file not found:', landmarks_path)
                    continue

                print('reconstructing', image_filename, '...')
                image_pillow, lm = load_img(image_path, landmarks_path)

                # preprocess input image
                input_img, _, transform_params = Preprocess(image_pillow, lm, lm3D)

                coeff_, face_shape_, face_texture_, face_color_, landmarks_2d_, recon_img_, tri_ = \
                    sess.run([coeff, face_shape, face_texture, face_color, landmarks_2d, recon_img, tri],
                             feed_dict={images: input_img})

                # reshape outputs (everything except `recon_img_` is only needed by the
                # optional .mat/.obj export below)
                face_shape_ = np.squeeze(face_shape_, (0))
                face_texture_ = np.squeeze(face_texture_, (0))
                face_color_ = np.squeeze(face_color_, (0))
                landmarks_2d_ = np.squeeze(landmarks_2d_, (0))  # 68 landmarks
                recon_img_ = np.squeeze(recon_img_, (0))

                # ============
                # Rescale (and uncrop) the mask (i.e., the reconstructed image) back to the resolution
                # of the original image and save it.
                mask_np = recon_img_[:, :, :3].astype('uint8')  # drop the alpha channel and convert to `uint8`
                rescaled_mask_pillow = rescale_mask(mask_np, transform_params)
                rescaled_mask_save_path = os.path.join(RESCALED_MASKS_DIR, image_filename)
                rescaled_mask_pillow.save(rescaled_mask_save_path)  # don't use plt.imsave() for that as it outputs something weird
                # ============

                # Uncomment the lines below to save the output .mat and .obj files
                # (define `image_save_path_mat` and `image_save_path_obj` first).
                # save_dict = {'recon_img': recon_img_, 'coeff': coeff_,
                #              'face_shape': face_shape_, 'face_texture': face_texture_,
                #              'face_color': face_color_, 'lm_68p': landmarks_2d_
                #              }
                # savemat(image_save_path_mat, save_dict)
                # save_obj(image_save_path_obj, face_shape_, tri_,
                #          np.clip(face_color_, 0, 255) / 255)  # 3D reconstruction face (in canonical view)
# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    demo()
@zhanghm1995
Copy link

zhanghm1995 commented Mar 5, 2022

Thanks for sharing this code. However,

I found that if the input face image is larger than 224x224, I get a correct image. But when my original input face image is 224x224, the n_missing_pixels_right and n_missing_pixels_bottom variables in the rescale_mask function can be negative, and in that case this code raises an error.

I made some changes to make this code run, but the location of the result does not match the original image very well. I don't know why.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment