@kohya-ss
Created January 22, 2024 23:48
Do something like chroma-key compositing
each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets)
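# The block marked "from here" / "up to here" below is added at the top of the denoising loop
# of the surrounding generation script; latents, timesteps, num_latent_input and self.scheduler
# come from the enclosing method.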
for i, t in enumerate(tqdm(timesteps)):
    # ↓ from here
    # test: chroma key like composition
    if latents.shape[0] == 4:
        # run this script with batch size 4
        # sample prompts for ANIMAGINE XL V3.0: the 2nd prompt has no details because it is only used for making the mask
        # green surface of green screen --n color, artifact, object, shadow, frame --d 1
        # 1girl, serafuku, standing, cowboy shot, green background, masterpiece, best quality --n nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name --d 1
        # fine art of medieval street, daylight, crowds --n 1girl, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name --d 1
        # 1girl, serafuku, standing, white sailor color, red neckerchief, black hair, blunt bangs, @_@, cowboy shot, at medieval street, masterpiece, best quality --n nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name --d 1
        # batch index 0: mono color background
        # batch index 1: foreground with background color
        # batch index 2: background image
        # batch index 3: foreground image
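        # latents has shape (batch, 4, h, w), so each latents[k] below is a single (4, h, w) latent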
        from torchvision.transforms import GaussianBlur

        sample_back = latents[0]
        sample_fore = latents[1]
        abs_diff = torch.abs(sample_back - sample_fore)  # 4,h,w
        ch_sum = torch.sum(abs_diff, dim=0, keepdim=True)
        print(f"[{i}] mean of diff: {ch_sum.mean()}")
        foreground_mask = ch_sum
        # foreground_mask = torch.cat([ch_sum, ch_sum, ch_sum], dim=0)  # 3,h,w  TODO: 1 channel is enough

        # apply gaussian filter
        foreground_mask = foreground_mask.unsqueeze(0)
        foreground_mask = GaussianBlur(kernel_size=3, sigma=1)(foreground_mask)
        foreground_mask = foreground_mask.squeeze(0)

        # convert to binary: the 0.5 quantile is the median, so roughly half of the latent pixels
        # become foreground; the right value depends on the image, especially on how strongly
        # the background color prompt works
        threshold = torch.quantile(foreground_mask.float(), 0.5)
        foreground_mask = (foreground_mask > threshold).float()
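
        # max_pool2d dilates the binary mask (filling pinholes); the repeated avg_pool2d then
        # blurs it and the final 0.5 threshold pulls the boundary back in, roughly undoing the
        # dilation while keeping the holes filled.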
        # inflate mask n times to fill holes
        n = 3
        for _ in range(n):
            foreground_mask = torch.nn.functional.max_pool2d(foreground_mask, kernel_size=3, stride=1, padding=1)
        # deflate mask n times to shrink
        for _ in range(n):
            foreground_mask = torch.nn.functional.avg_pool2d(foreground_mask, kernel_size=3, stride=1, padding=1)
        foreground_mask = (foreground_mask > 0.5).float()

        # # save mask image
        # mask_image = foreground_mask
        # mask_image = mask_image.cpu().permute(1, 2, 0).float().numpy()
        # mask_image = (mask_image * 255).round().astype("uint8")
        # mask_image = np.concatenate([mask_image, mask_image, mask_image], axis=2)
        # mask_image = Image.fromarray(mask_image)
        # mask_image.save(f"logs\\mask_image_{i}.png")

        # copy background image to foreground
        sample_bg_img = latents[2]
        sample_fg_img = latents[3]
        foreground_mask = foreground_mask.repeat((4, 1, 1))
        sample_fg_img = sample_fg_img * foreground_mask + sample_bg_img * (1 - foreground_mask)
        latents[3] = sample_fg_img
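        # Writing the composite back into latents[3] inside the denoising loop lets the remaining
        # steps continue to refine and harmonize the pasted regions.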
    # ↑ up to here

    # expand the latents if we are doing classifier free guidance
    latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
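
# ---------------------------------------------------------------------------
# Standalone sketch (not part of the original snippet): the same mask-and-composite
# pipeline applied to dummy latents outside the denoising loop, just to show the
# shapes involved. The function name and tensor sizes are illustrative assumptions.
import torch
import torch.nn.functional as F
from torchvision.transforms import GaussianBlur


def chroma_key_composite(latents: torch.Tensor, n: int = 3) -> torch.Tensor:
    # latents: (4, C, H, W), ordered as in the snippet above:
    #   0: plain background color, 1: foreground on that color,
    #   2: background image, 3: foreground image
    diff = torch.abs(latents[0] - latents[1]).sum(dim=0, keepdim=True)  # (1, H, W)
    diff = GaussianBlur(kernel_size=3, sigma=1)(diff.unsqueeze(0)).squeeze(0)

    mask = (diff > torch.quantile(diff.float(), 0.5)).float()
    for _ in range(n):  # inflate to fill holes
        mask = F.max_pool2d(mask, kernel_size=3, stride=1, padding=1)
    for _ in range(n):  # deflate to shrink back
        mask = F.avg_pool2d(mask, kernel_size=3, stride=1, padding=1)
    mask = (mask > 0.5).float()

    # paste the masked foreground onto the background image latent
    return latents[3] * mask + latents[2] * (1 - mask)  # (C, H, W)


if __name__ == "__main__":
    dummy = torch.randn(4, 4, 64, 64)  # 4 samples of SD-style 4-channel latents
    print(chroma_key_composite(dummy).shape)  # torch.Size([4, 64, 64])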