Created January 22, 2024 23:48
-
-
Save kohya-ss/06009889641d1f8995b56238b3b45778 to your computer and use it in GitHub Desktop.
クロマキー合成っぽいことをやる (Chroma-key-like compositing, done on the latents during sampling.)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) | |
for i, t in enumerate(tqdm(timesteps)): | |
# >>> Experimental: chroma-key-like latent composition (start) >>>
# Runs on every denoising step, but only when the batch size is exactly 4.
# Expected batch layout (one prompt per batch index, see the sample prompts below):
#   index 0: plain single-colour background ("green screen")
#   index 1: foreground subject on that same background colour (only used to build the mask)
#   index 2: background scene
#   index 3: foreground subject; everything outside the mask is replaced by latent 2 each step
# Pixels where latents 0 and 1 differ are treated as foreground. This presumably relies on
# the four samples sharing one seed (the sample prompts all use `--d 1`) -- TODO confirm.
if latents.shape[0] == 4:
    # Sample prompts for ANIMAGINE XL V3.0. The 2nd prompt intentionally has few details,
    # because it is only used for making the mask:
    # green surface of green screen --n color, artifact, object, shadow, frame --d 1
    # 1girl, serafuku, standing, cowboy shot, green background, masterpiece, best quality --n nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name --d 1
    # fine art of medieval street, daylight, crowds --n 1girl, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name --d 1
    # 1girl, serafuku, standing, white sailor color, red neckerchief, black hair, blunt bangs, @_@, cowboy shot, at medieval street, masterpiece, best quality --n nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name --d 1
    from torchvision.transforms import GaussianBlur

    # Quantile of the blurred difference map used as the foreground threshold, i.e. roughly
    # the fraction of latent pixels classified as background. The right value depends on
    # the image, especially on how strongly the background-colour prompt takes effect.
    mask_quantile = 0.5
    # Number of 3x3 dilation steps (and the same number of erosion steps) used to close holes.
    morph_iterations = 3

    sample_back = latents[0]
    sample_fore = latents[1]
    # Absolute difference summed over the latent channels: (4, h, w) -> (1, h, w).
    # A single channel is enough because the mask is broadcast over all channels below.
    # Converted to float32 here: torch.quantile needs a float/double input, and blurring or
    # pooling in fp16 may not be supported on every device.
    abs_diff = torch.abs(sample_back - sample_fore)
    ch_sum = torch.sum(abs_diff, dim=0, keepdim=True).float()
    print(f"[{i}] mean of diff: {ch_sum.mean()}")

    # Lightly blur the difference map to suppress isolated noisy pixels. GaussianBlur gets a
    # (1, 1, h, w) batch; the extra leading dim is dropped again afterwards.
    foreground_mask = GaussianBlur(kernel_size=3, sigma=1)(ch_sum.unsqueeze(0)).squeeze(0)

    # Binarize: pixels whose difference exceeds the chosen quantile are foreground (1.0).
    threshold = torch.quantile(foreground_mask, mask_quantile)
    foreground_mask = (foreground_mask > threshold).float()

    # Morphological closing: dilate, then erode by the same amount. This fills holes inside
    # the foreground without growing its outline. Dilation is a 3x3 max-pool. Erosion is a
    # 3x3 min-pool, written as 1 - max_pool(1 - x); max_pool2d pads with -inf, so pixels
    # outside the image are ignored. (An avg-pool followed by a single 0.5 threshold would
    # leave straight edges in place, i.e. it would not undo the dilation.)
    for _ in range(morph_iterations):
        foreground_mask = torch.nn.functional.max_pool2d(foreground_mask, kernel_size=3, stride=1, padding=1)
    for _ in range(morph_iterations):
        foreground_mask = 1 - torch.nn.functional.max_pool2d(1 - foreground_mask, kernel_size=3, stride=1, padding=1)

    # Debug: uncomment to save the mask of each step as an image.
    # mask_image = foreground_mask
    # mask_image = mask_image.cpu().permute(1, 2, 0).float().numpy()
    # mask_image = (mask_image * 255).round().astype("uint8")
    # mask_image = np.concatenate([mask_image, mask_image, mask_image], axis=2)
    # mask_image = Image.fromarray(mask_image)
    # mask_image.save(f"logs\\mask_image_{i}.png")

    # Composite: keep latent 3 where the mask is 1 and take the background latent 2 elsewhere.
    # The (1, h, w) mask broadcasts over every latent channel, so no channel count is hard-coded.
    sample_bg_img = latents[2]
    sample_fg_img = latents[3]
    foreground_mask = foreground_mask.to(sample_fg_img.dtype)
    latents[3] = sample_fg_img * foreground_mask + sample_bg_img * (1 - foreground_mask)
# <<< Experimental: chroma-key-like latent composition (end) <<<
# expand the latents if we are doing classifier free guidance | |
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) | |
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.