Create mask from image based on face coords
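The script builds a soft mask from detected face positions (dlib + MediaPipe, plus an optional MiDaS depth map written out for inspection) and feeds it to a masked img2img Stable Diffusion pipeline, so regions far from faces are repainted more strongly than the faces themselves. Assuming the file is saved as mask_from_face.py (the gist itself does not fix a name), a typical run looks like:

    python mask_from_face.py photo.jpg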
import cv2
import numpy as np
import dlib
import mediapipe as mp
import argparse
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers import EulerAncestralDiscreteScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from typing import Any, Callable, Dict, List, Optional, Union
import PIL
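
# Assumed dependencies (the original gist does not pin them): opencv-python, numpy,
# dlib, mediapipe, torch, diffusers and Pillow. The MediaPipe face detector also
# expects the `blaze_face_short_range.tflite` model asset in the working directory,
# and main() loads a local diffusers checkpoint ("colorful-v3-1").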

class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
    debug_save = True

    def _make_latent_mask(self, latents, mask):
        if mask is not None:
            latent_mask = []
            if not isinstance(mask, list):
                tmp_mask = [mask]
            else:
                tmp_mask = mask
            _, l_channels, l_height, l_width = latents.shape
            for m in tmp_mask:
                if not isinstance(m, PIL.Image.Image):
                    if len(m.shape) == 2:
                        m = m[..., np.newaxis]
                    if m.max() > 1:
                        m = m / 255.0
                    m = self.image_processor.numpy_to_pil(m)[0]
                if m.mode != "L":
                    m = m.convert("L")
                resized = self.image_processor.resize(m, l_height, l_width)
                if self.debug_save:
                    resized.save("latent_mask.png")
                latent_mask.append(np.repeat(np.array(resized)[np.newaxis, :, :], l_channels, axis=0))
            latent_mask = torch.as_tensor(np.stack(latent_mask)).to(latents)
            latent_mask = latent_mask / latent_mask.max()
            return latent_mask
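
    # Mask convention (inferred from the blending math in __call__ below): after the
    # mask is rescaled by its max, a value near 1 keeps the freshly generated latents,
    # while a value near 0 pins that region to the encoded original image.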
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[
            torch.FloatTensor,
            PIL.Image.Image,
            np.ndarray,
            List[torch.FloatTensor],
            List[PIL.Image.Image],
            List[np.ndarray],
        ] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        mask: Union[
            torch.FloatTensor,
            PIL.Image.Image,
            np.ndarray,
            List[torch.FloatTensor],
            List[PIL.Image.Image],
            List[np.ndarray],
        ] = None,
    ):
r""" | |
The call function to the pipeline for generation. | |
Args: | |
prompt (`str` or `List[str]`, *optional*): | |
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. | |
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): | |
`Image` or tensor representing an image batch to be used as the starting point. Can also accept image | |
latents as `image`, but if passing latents directly it is not encoded again. | |
strength (`float`, *optional*, defaults to 0.8): | |
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a | |
starting point and more noise is added the higher the `strength`. The number of denoising steps depends | |
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising | |
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 | |
essentially ignores `image`. | |
num_inference_steps (`int`, *optional*, defaults to 50): | |
The number of denoising steps. More denoising steps usually lead to a higher quality image at the | |
expense of slower inference. This parameter is modulated by `strength`. | |
guidance_scale (`float`, *optional*, defaults to 7.5): | |
A higher guidance scale value encourages the model to generate images closely linked to the text | |
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. | |
negative_prompt (`str` or `List[str]`, *optional*): | |
The prompt or prompts to guide what to not include in image generation. If not defined, you need to | |
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). | |
num_images_per_prompt (`int`, *optional*, defaults to 1): | |
The number of images to generate per prompt. | |
eta (`float`, *optional*, defaults to 0.0): | |
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies | |
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. | |
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): | |
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make | |
generation deterministic. | |
prompt_embeds (`torch.FloatTensor`, *optional*): | |
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not | |
provided, text embeddings are generated from the `prompt` input argument. | |
negative_prompt_embeds (`torch.FloatTensor`, *optional*): | |
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If | |
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. | |
output_type (`str`, *optional*, defaults to `"pil"`): | |
The output format of the generated image. Choose between `PIL.Image` or `np.array`. | |
return_dict (`bool`, *optional*, defaults to `True`): | |
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a | |
plain tuple. | |
callback (`Callable`, *optional*): | |
A function that calls every `callback_steps` steps during inference. The function is called with the | |
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. | |
callback_steps (`int`, *optional*, defaults to 1): | |
The frequency at which the `callback` function is called. If not specified, the callback is called at | |
every step. | |
cross_attention_kwargs (`dict`, *optional*): | |
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in | |
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). | |
mask (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`, *optional*): | |
A mask with non-zero elements for the area to be inpainted. If not specified, no mask is applied. | |
Examples: | |
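
            A minimal usage sketch; the checkpoint id and file names below are placeholders, not values from this gist:

            ```py
            >>> import PIL
            >>> import numpy as np

            >>> pipe = MaskedStableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
            >>> init_image = PIL.Image.open("input.png").convert("RGB")
            >>> # values near 1 get repainted, values near 0 keep the original pixels
            >>> mask = np.ones((init_image.height, init_image.width), dtype=np.float32)
            >>> mask[: init_image.height // 2, :] = 0.0  # keep the top half untouched
            >>> result = pipe(prompt="a portrait", image=init_image, mask=mask).images[0]
            >>> result.save("out.png")
            ```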

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """
        # code adapted from parent class StableDiffusionImg2ImgPipeline
        # 0. Check inputs. Raise error if not correct
        self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

        # 1. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        device = self._execution_device
        # here `guidance_scale` is defined analogous to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 2. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )

        # 3. Preprocess image
        image = self.image_processor.preprocess(image)

        # 4. set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 5. Prepare latent variables (sampled from the latent distribution of the VAE)
        latents = self.prepare_latents(
            image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
        )
        # mean of the latent distribution, kept as the "original image" reference for masking
        init_latents = [
            self.vae.encode(image.to(device=device, dtype=prompt_embeds.dtype)[i : i + 1]).latent_dist.mean
            for i in range(batch_size)
        ]
        init_latents = torch.cat(init_latents, dim=0)

        # 6. create latent mask
        latent_mask = self._make_latent_mask(latents, mask)

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
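                # Masked blending, the core change versus the stock img2img pipeline:
                # outside the mask the latents are reset to the encoded original image
                # and the predicted noise is zeroed, so only masked regions keep being
                # denoised. Both lines are equivalent to the torch.lerp variants left
                # commented out below.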
                if latent_mask is not None:
                    # latents = torch.lerp(init_latents * self.vae.config.scaling_factor, latents, latent_mask)
                    latents = latents * latent_mask + (init_latents * self.vae.config.scaling_factor) * (1 - latent_mask)
                    # noise_pred = torch.lerp(torch.zeros_like(noise_pred), noise_pred, latent_mask)
                    noise_pred = noise_pred * latent_mask + torch.zeros_like(noise_pred) * (1 - latent_mask)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)
        if output_type != "latent":
            scaled = latents / self.vae.config.scaling_factor
            if latent_mask is not None:
                # blend with the unscaled init latents once more so unmasked regions decode to the original image
                scaled = latents / self.vae.config.scaling_factor * latent_mask + init_latents * (1 - latent_mask)
                # scaled = torch.lerp(init_latents, scaled, latent_mask)
            image = self.vae.decode(scaled, return_dict=False)[0]
            if self.debug_save:
                image_gen = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
                image_gen = self.image_processor.postprocess(image_gen, output_type=output_type, do_denormalize=[True])
                image_gen[0].save("from_latent.png")
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

def extract_faces(image):
    """Return approximate face-center coordinates found by dlib and MediaPipe (BGR input)."""
    # Convert BGR image to RGB image
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    detector = dlib.get_frontal_face_detector()
    dets = detector(rgb_image, 2)
    center_faces = []
    TEXT_COLOR = (0, 255, 0)
    for i, d in enumerate(dets):
        center_faces.append([d.left() + int((d.right() - d.left()) / 2), int(d.top() + (d.bottom() - d.top()) / 2)])
        start_point = d.left(), d.top()
        end_point = d.right(), d.bottom()
        cv2.rectangle(rgb_image, start_point, end_point, TEXT_COLOR, 3)
    cv2.imwrite("result_image_face2.png", rgb_image)

    # second attempt: MediaPipe face detection
    BaseOptions = mp.tasks.BaseOptions
    FaceDetector = mp.tasks.vision.FaceDetector
    FaceDetectorOptions = mp.tasks.vision.FaceDetectorOptions
    VisionRunningMode = mp.tasks.vision.RunningMode

    # Create a face detector instance with the image mode:
    options = FaceDetectorOptions(
        base_options=BaseOptions(model_asset_path='blaze_face_short_range.tflite'),
        running_mode=VisionRunningMode.IMAGE)
    with FaceDetector.create_from_options(options) as detector:
        # MediaPipe expects RGB data, so pass the converted image
        mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_image)
        face_detector_result = detector.detect(mp_image)
        for detection in face_detector_result.detections:
            bbox = detection.bounding_box
            center_faces.append([int(bbox.origin_x + bbox.width / 2), int(bbox.origin_y + bbox.height / 2)])
    return center_faces
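
# Note: both detectors run on the same image, so each face is typically reported
# twice (once per detector). Duplicate centers are harmless for the distance mask
# below, which takes the minimum distance over all reported centers.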

def find_normalized_distances(image, face_coordinates, low, high):
    """For every pixel, compute the distance to the nearest face center, rescaled to [low, high]."""
    height, width, _ = image.shape
    distances = np.zeros((height, width))
    for (center_x, center_y) in face_coordinates:
        for y1 in range(height):
            for x1 in range(width):
                d = np.sqrt((x1 - center_x) ** 2 + (y1 - center_y) ** 2)
                if distances[y1, x1] != 0.0:
                    distances[y1, x1] = min(distances[y1, x1], d)
                else:
                    distances[y1, x1] = d

    # Normalize the distance values between low and high
    old_range = np.max(distances) - np.min(distances)
    new_range = high - low
    if np.max(distances) == 0.0:
        return distances
    normalized_distances = ((distances - np.min(distances)) * new_range) / old_range + low
    return normalized_distances
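
# Worked example: with low=0.4 and high=1.0 (the values main() passes below), a
# pixel at a face center gets mask value 0.4 and the farthest pixel gets 1.0.
# Under the latent-mask convention above (1 = repaint, 0 = keep), the background
# is repainted strongly while faces are partially protected.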

def merge_masks_with_image(base_image, normalized_distances):
    """Attach (1 - distances) as an alpha channel, so pixels near faces stay most opaque."""
    b_channel, g_channel, r_channel = cv2.split(base_image)
    alpha_gradient = ((1 - normalized_distances) * 255).astype(np.uint8)
    img_BGRA = cv2.merge((b_channel, g_channel, r_channel, alpha_gradient))
    return img_BGRA

def get_depth(image):
    """Write a depth-based alpha image (result_depth.png) and return its alpha channel."""
    # Load the pretrained MiDaS model
    model = torch.hub.load("intel-isl/MiDaS", "MiDaS")
    # Set the model to evaluation mode
    model.eval()
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    rgb_image = cv2.resize(rgb_image, (384, 384))  # Resize to match the input size of MiDaS
    # Convert the image to a torch tensor
    image_tensor = torch.from_numpy(rgb_image.astype(np.float32) / 255.0).permute(2, 0, 1).unsqueeze(0)
    # Run the image through MiDaS to get the depth map
    with torch.no_grad():
        depth_map = model.forward(image_tensor)
    # Normalize the depth map to range [0, 1]
    normalized_depth = (depth_map - torch.min(depth_map)) / (torch.max(depth_map) - torch.min(depth_map))
    # Invert the normalized depth and use it as the alpha channel of a square image
    sq_img = np.zeros((384, 384, 4), dtype=np.uint8)
    sq_img[:, :, 3] = (1 - normalized_depth.squeeze().cpu().numpy()) * 255
    height, width, _ = image.shape
    # resize back to the original dimensions
    sq_img = cv2.resize(sq_img, (width, height))
    cv2.imwrite("result_depth.png", sq_img)
    # Ensure the alpha channel is a plain NumPy array
    alpha_array = np.array(sq_img[:, :, 3], dtype=np.uint8)
    return alpha_array
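
# Note: the official MiDaS hub entry point also ships companion transforms
# (torch.hub.load("intel-isl/MiDaS", "transforms")) that include the model's
# expected preprocessing; this script uses a plain resize and /255 scaling
# instead, which is a rougher approximation.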

def main(image_file):
    # Read image (OpenCV loads BGR)
    image = cv2.imread(image_file)
    maxwh = 640
    f1 = maxwh / image.shape[1]
    f2 = maxwh / image.shape[0]
    f = min(f1, f2)  # resizing factor
    dim = (int(image.shape[1] * f), int(image.shape[0] * f))
    image = cv2.resize(image, dim)
    # get depth mask (written to result_depth.png for inspection; not used below)
    get_depth(image)
    # Extract faces
    face_coordinates = extract_faces(image)
    print(face_coordinates)
    # Create transparency mask
    normalized_distances = find_normalized_distances(image, face_coordinates, 0.4, 1.0)
    # Merge mask with image
    image_with_alpha = merge_masks_with_image(image, normalized_distances)
    # Save resulting mask and image
    height, width, _ = image.shape
    print(height, width)
    mask = np.zeros((height, width, 4), dtype=np.uint8)
    alpha_gradient = (normalized_distances * 255).astype(np.uint8)
    cv2.imwrite("result_mask.png", cv2.merge((mask[:, :, 0:3], alpha_gradient)))
    cv2.imwrite("result_image.png", image_with_alpha)

    # Convert to RGB for PIL / diffusers
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    init_image = PIL.Image.fromarray(image_rgb)
    pipe = MaskedStableDiffusionImg2ImgPipeline.from_pretrained("colorful-v3-1")
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to("mps")  # Apple Silicon; use "cuda" or "cpu" elsewhere
    pipe.safety_checker = None
    # Recommended if your computer has < 64 GB of RAM
    # pipe.enable_attention_slicing()
    generator = torch.Generator(device="cpu").manual_seed(0)
    prompt = "portrait man in headphones, Balenciaga style"
    negative_prompt = "deformed, disfigured, poorly drawn, bad anatomy, wrong anatomy, limb, mutated, blurry"
    # First-time "warmup" pass if PyTorch version is 1.13 on mps; results then match the CPU device:
    # _ = pipe(prompt, generator=generator, image=init_image, mask=normalized_distances, num_inference_steps=1)
    image = pipe(prompt=prompt, negative_prompt=negative_prompt, generator=generator, image=init_image,
                 mask=normalized_distances, num_inference_steps=1).images[0]
    image.save("00.png")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Image processing script')
    parser.add_argument('image', type=str, help='Name of the image file')
    args = parser.parse_args()
    main(args.image)