Skip to content

Instantly share code, notes, and snippets.

@city96
Created April 18, 2024 18:01
Show Gist options
  • Save city96/2db92a207f3f93b1170dbd412688f812 to your computer and use it in GitHub Desktop.
Save city96/2db92a207f3f93b1170dbd412688f812 to your computer and use it in GitHub Desktop.
Test if PNG contains background hidden by alpha channel
# Find images with content hidden by alpha channel
# Ref https://github.com/kohya-ss/sd-scripts/issues/1269
# City96 | 2024
import os
import torch
import torchvision
from tqdm import tqdm
from torch.multiprocessing import Pool
# Root directory of the source image dataset (walked recursively)
SRC_PATH = r"Z:\booru\Test\images"
# Directory that flagged images are copied to (RGB layers only, alpha
# channel dropped). 'None' to disable
DST_PATH = r"E:\booru\Test\hidden"
# Path of the CSV log (columns: name, conf) listing flagged files.
# 'None' to disable
LOG_PATH = r"E:\booru\Test\hidden.csv"
# Flag threshold: an image is reported when its background deviation score
# is strictly greater than this value (pixel values are scaled to 0-1)
MIN_STD = 0.1
# Maximum number of images to collect for checking. 'None' to disable
LIMIT = None
def get_target_files(src_path, exts=(".png",), limit=LIMIT):
    """
    Collect paths of files under ``src_path`` whose extension is in ``exts``.

    Args:
        src_path: Root directory, searched recursively with ``os.walk``.
        exts: Extension or iterable of extensions, each including the leading
            dot (e.g. ``".png"``). Matching is case-insensitive on both sides.
        limit: Stop once this many matches are collected; ``None`` or ``0``
            means no limit.

    Returns:
        List of matching file paths (directory joined with file name), in
        ``os.walk`` order.
    """
    # A bare string would otherwise be treated as a sequence of characters,
    # turning the membership test into a substring test.
    if isinstance(exts, str):
        exts = (exts,)
    wanted = {ext.lower() for ext in exts}

    paths = []
    for root, _, files in os.walk(src_path):
        for fname in files:
            if os.path.splitext(fname)[1].lower() not in wanted:
                continue
            paths.append(os.path.join(root, fname))
            if limit and len(paths) >= limit:
                return paths
    return paths
def read_image(img_path, dims=(4,)):
    """
    Load an image, keeping it only if its channel count is in ``dims``.

    Args:
        img_path: Path of the image file to decode with
            ``torchvision.io.read_image``.
        dims: Accepted sizes of the first (channel) dimension of the decoded
            tensor. The default ``(4,)`` keeps only 4-channel images
            (presumably RGBA; TODO confirm how palette/tRNS PNGs decode).

    Returns:
        The decoded tensor, or ``None`` if the file could not be decoded or its
        channel count is not in ``dims``.
    """
    try:
        img = torchvision.io.read_image(img_path)
    except Exception:
        # Best-effort: skip unreadable or corrupt files instead of aborting the
        # scan. Unlike a bare ``except:``, this lets KeyboardInterrupt and
        # SystemExit propagate, which matters inside pool workers.
        return None
    return img if img.shape[0] in dims else None
def check_image(img_path, dst_path=DST_PATH, threshold=MIN_STD):
    """
    Check whether the fully transparent area of an image hides content.

    The image is loaded with ``read_image``, so files that fail to decode or
    are not 4-channel are skipped. Pixels with alpha exactly 0 form the hidden
    "background". Its deviation is the standard deviation of those pixels' RGB
    values, computed per channel and then averaged. A flat fill of any single
    colour scores 0, while hidden detail scores higher. Opaque pixels are
    excluded, so they do not dilute or inflate the score.

    Args:
        img_path: Path of the image to check.
        dst_path: Directory to write the RGB layers (alpha dropped) of flagged
            images to; falsy disables saving. An existing file with the same
            basename is not overwritten, so same-named images from different
            sub-folders are only saved once.
        threshold: Flag the image when the deviation (pixel values scaled to
            0-1) is strictly greater than this.

    Returns:
        ``(img_path, deviation)`` for flagged images, otherwise ``None``.
    """
    # load image & verify alpha
    raw = read_image(img_path)
    if raw is None:
        return None

    # RGB scaled to [0, 1], plus a boolean (H, W) mask of fully transparent
    # pixels. Any alpha > 0 counts as visible ("hard edges").
    image = raw[:3] / 255.0
    hidden = raw[3] == 0

    # (3, N) RGB values of the transparent pixels only. The sample std needs
    # at least 2 samples; with fewer there is no hidden background to judge.
    background = image[:, hidden]
    if background.shape[1] < 2:
        return None
    bgdev = background.std(dim=1).mean()

    if bgdev > threshold:
        if dst_path:
            path = os.path.join(dst_path, os.path.basename(img_path))
            if not os.path.isfile(path):
                torchvision.utils.save_image(image, path)
        return (img_path, float(bgdev))
    return None
if __name__ == "__main__":
    import csv  # only needed for writing the log file

    print("Parsing paths")
    paths = get_target_files(SRC_PATH)
    # DST_PATH may be None ("'None' to disable"); os.makedirs(None) would raise.
    if DST_PATH:
        os.makedirs(DST_PATH, exist_ok=True)
    print(f"Checking {len(paths)} images")

    # For debugging without worker processes:
    #   found = [x for x in map(check_image, paths) if x]
    # The context manager cleans up the workers even if a task raises. imap is
    # fully consumed inside the block, before __exit__ terminates the pool.
    with Pool(4) as pool:
        results = pool.imap(check_image, paths)
        found = [x for x in tqdm(results, total=len(paths)) if x]
    print(f"Found {len(found)} images.")

    if LOG_PATH:
        # csv.writer quotes paths containing commas or quotes. newline="" is
        # required by the csv module; "\n" keeps the original line endings.
        with open(LOG_PATH, "w", encoding="utf-8", newline="") as f:
            writer = csv.writer(f, lineterminator="\n")
            writer.writerow(["name", "conf"])
            for name, conf in found:
                writer.writerow([name, round(conf, 4)])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment