Last active
April 19, 2024 13:50
-
-
Save hinablue/abb2c5454c20ea5636b081bcda1ad938 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import math
import os
import random
import sys
from typing import NamedTuple, Tuple

import cv2
import numpy as np
from PIL import Image
from prettytable import PrettyTable
from ultralytics import YOLO
class BucketManager:
    """Assign images to resolution buckets and compute the matching resize size.

    A bucket is a ``(width, height)`` resolution. ``resos`` lists the known
    buckets, ``reso_to_id`` maps a resolution to its index, and ``buckets``
    holds one list of entries per resolution.
    """

    def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
        """Store the bucketing configuration.

        Args:
            no_upscale: when true, images are only ever downscaled and buckets
                are derived from each image; otherwise predefined buckets
                (see ``make_buckets``) are used and images may be upscaled.
            max_reso: ``(width, height)`` whose product is the maximum bucket
                area, or ``None``.
            min_size: smallest bucket side used by ``make_buckets``.
            max_size: largest bucket side used by ``make_buckets``.
            reso_steps: bucket sides are multiples of this value.
        """
        if max_size is not None:
            if max_reso is not None:
                assert max_size >= max_reso[0], "the max_size should be larger than the width of max_reso"
                assert max_size >= max_reso[1], "the max_size should be larger than the height of max_reso"
            if min_size is not None:
                assert max_size >= min_size, "the max_size should be larger than the min_size"

        self.no_upscale = no_upscale
        if max_reso is None:
            self.max_reso = None
            self.max_area = None
        else:
            self.max_reso = max_reso
            self.max_area = max_reso[0] * max_reso[1]
        self.min_size = min_size
        self.max_size = max_size
        self.reso_steps = reso_steps

        self.resos = []
        self.reso_to_id = {}
        # During preprocessing each entry is (image_key, image, original size,
        # crop left/top); during training it is just image_key.
        self.buckets = []
        # model_path -> loaded YOLO model, so the weights are read only once.
        self._face_models = {}

    def add_image(self, reso, image_or_info):
        """Append ``image_or_info`` to the bucket registered for ``reso``.

        Raises:
            KeyError: if ``reso`` has not been registered yet.
        """
        bucket_id = self.reso_to_id[reso]
        self.buckets[bucket_id].append(image_or_info)

    def shuffle(self):
        """Shuffle the entries of every bucket in place."""
        for bucket in self.buckets:
            random.shuffle(bucket)

    def sort(self):
        """Reorder ``resos``, ``buckets`` and ``reso_to_id`` by ascending resolution.

        Sorting is only cosmetic (nicer display / stored metadata); the buckets
        are moved along with their resolution and the ids are reassigned.
        """
        sorted_resos = self.resos.copy()
        sorted_resos.sort()

        sorted_buckets = []
        sorted_reso_to_id = {}
        for i, reso in enumerate(sorted_resos):
            bucket_id = self.reso_to_id[reso]
            sorted_buckets.append(self.buckets[bucket_id])
            sorted_reso_to_id[reso] = i

        self.resos = sorted_resos
        self.buckets = sorted_buckets
        self.reso_to_id = sorted_reso_to_id

    def make_bucket_resolutions(self, max_reso, min_size=256, max_size=1024, divisible=64):
        """Return the sorted list of ``(width, height)`` buckets within the area of ``max_reso``.

        Widths run from ``min_size`` to ``max_size`` in steps of ``divisible``;
        for each width the tallest height (a multiple of ``divisible``, capped
        at ``max_size``) that keeps the area within ``max_reso`` is used, and
        both orientations are added.
        """
        max_width, max_height = max_reso
        max_area = max_width * max_height

        resos = set()

        # Square bucket closest to the maximum area.
        width = int(math.sqrt(max_area) // divisible) * divisible
        resos.add((width, width))

        width = min_size
        while width <= max_size:
            height = min(max_size, int((max_area // width) // divisible) * divisible)
            if height >= min_size:
                resos.add((width, height))
                resos.add((height, width))
            width += divisible

        resos = list(resos)
        resos.sort()
        return resos

    def make_buckets(self):
        """Build the predefined buckets from the configuration given to ``__init__``."""
        resos = self.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps)
        self.set_predefined_resos(resos)

    def set_predefined_resos(self, resos):
        """Store the resolutions (and their aspect ratios) to choose from in upscale mode."""
        self.predefined_resos = resos.copy()
        self.predefined_resos_set = set(resos)
        self.predefined_aspect_ratios = np.array([w / h for w, h in resos])

    def add_if_new_reso(self, reso):
        """Register ``reso`` with a fresh empty bucket unless it is already known."""
        if reso not in self.reso_to_id:
            bucket_id = len(self.resos)
            self.reso_to_id[reso] = bucket_id
            self.resos.append(reso)
            self.buckets.append([])

    def round_to_steps(self, x):
        """Round ``x`` to the nearest integer, then down to a multiple of ``reso_steps``."""
        x = int(x + 0.5)
        return x - x % self.reso_steps

    def select_bucket(self, image_width, image_height):
        """Choose the bucket for an image and the size it must be resized to.

        Returns:
            ``(reso, resized_size, ar_error)`` where ``reso`` is the bucket
            ``(width, height)``, ``resized_size`` is the size to resize the
            image to before cropping, and ``ar_error`` is the bucket aspect
            ratio minus the image aspect ratio.
        """
        aspect_ratio = image_width / image_height
        if not self.no_upscale:
            # Upscaling and downscaling are both allowed. A bucket with exactly
            # this resolution may exist (e.g. data preprocessed with
            # no_upscale=True), so an exact match takes priority.
            reso = (image_width, image_height)
            if reso in self.predefined_resos_set:
                pass
            else:
                # Otherwise take the bucket with the smallest aspect ratio error.
                ar_errors = self.predefined_aspect_ratios - aspect_ratio
                predefined_bucket_id = np.abs(ar_errors).argmin()
                reso = self.predefined_resos[predefined_bucket_id]

            ar_reso = reso[0] / reso[1]
            if aspect_ratio > ar_reso:
                # Image is wider than the bucket: match the heights.
                scale = reso[1] / image_height
            else:
                scale = reso[0] / image_width

            resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5))
        else:
            # Downscale only.
            if image_width * image_height > self.max_area:
                # The image is too large: pick the bucket assuming it is
                # shrunk to max_area while keeping its aspect ratio.
                resized_width = math.sqrt(self.max_area * aspect_ratio)
                resized_height = self.max_area / resized_width
                assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal"

                # Snap either the width or the height to reso_steps and keep
                # whichever variant has the smaller aspect ratio error.
                b_width_rounded = self.round_to_steps(resized_width)
                b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio)
                ar_width_rounded = b_width_rounded / b_height_in_wr

                b_height_rounded = self.round_to_steps(resized_height)
                b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio)
                ar_height_rounded = b_width_in_hr / b_height_rounded

                if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio):
                    resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + 0.5))
                else:
                    resized_size = (int(b_height_rounded * aspect_ratio + 0.5), b_height_rounded)
            else:
                resized_size = (image_width, image_height)  # no resize needed

            # The bucket is never larger than the image (crop instead of pad).
            bucket_width = resized_size[0] - resized_size[0] % self.reso_steps
            bucket_height = resized_size[1] - resized_size[1] % self.reso_steps
            reso = (bucket_width, bucket_height)

        self.add_if_new_reso(reso)

        ar_error = (reso[0] / reso[1]) - aspect_ratio
        return reso, resized_size, ar_error

    @staticmethod
    def get_crop_ltrb(bucket_reso: Tuple[int, int], image_size: Tuple[int, int]):
        """Return ``(left, top, right, bottom)`` of ``image_size`` fitted inside ``bucket_reso``.

        Crop left/top follow the preprocessing of Stability AI; crop right is
        calculated for flip augmentation.
        """
        bucket_ar = bucket_reso[0] / bucket_reso[1]
        image_ar = image_size[0] / image_size[1]
        if bucket_ar > image_ar:
            # The bucket is wider than the image: match the heights.
            resized_width = bucket_reso[1] * image_ar
            resized_height = bucket_reso[1]
        else:
            resized_width = bucket_reso[0]
            resized_height = bucket_reso[0] / image_ar
        crop_left = (bucket_reso[0] - resized_width) // 2
        crop_top = (bucket_reso[1] - resized_height) // 2
        crop_right = crop_left + resized_width
        crop_bottom = crop_top + resized_height
        return int(crop_left), int(crop_top), int(crop_right), int(crop_bottom)

    def _get_crop_boundaries(self, center_x, center_y, width, height):
        """Return ``(top, bottom, left, right)`` of a crop centred on a point.

        The number of steps from the centre to each image edge is counted and
        the smallest count is applied in all four directions, so the crop
        stays centred. Vertical and horizontal step sizes are scaled so their
        ratio follows the aspect ratio of ``max_reso``.
        """

        def count_steps(start, raw_step, limit):
            # Number of snapped steps from `start` until 0 (negative step) or
            # `limit` (positive step) is reached.
            step = round(raw_step)
            if step == 0:
                # A tiny step ratio would round to 0 and divide by zero below.
                step = 1 if raw_step > 0 else -1
            boundary, count = start, 0
            while True:
                boundary = (boundary + step) // step * step
                count += 1
                if (step > 0 and boundary >= limit) or (step < 0 and boundary <= 0):
                    return count

        step_ratio = [1, 1]  # [vertical, horizontal]
        if self.max_reso[0] > self.max_reso[1]:
            step_ratio[0] = 1 / (self.max_reso[0] / self.max_reso[1])
        else:
            step_ratio[1] = 1 / (self.max_reso[1] / self.max_reso[0])

        vertical_step = self.reso_steps * step_ratio[0]
        horizontal_step = self.reso_steps * step_ratio[1]
        # Top/bottom are bounded by the image height, left/right by its width.
        min_counter = min(
            count_steps(center_y, -vertical_step, height),
            count_steps(center_y, vertical_step, height),
            count_steps(center_x, -horizontal_step, width),
            count_steps(center_x, horizontal_step, width),
        )

        crop_top = (center_y - vertical_step * min_counter) // self.reso_steps * self.reso_steps
        crop_bottom = (center_y + vertical_step * min_counter) // self.reso_steps * self.reso_steps
        crop_left = (center_x - horizontal_step * min_counter) // self.reso_steps * self.reso_steps
        crop_right = (center_x + horizontal_step * min_counter) // self.reso_steps * self.reso_steps

        # Flooring to reso_steps can push a boundary past the image edge.
        crop_top = max(crop_top, 0)
        crop_bottom = min(crop_bottom, height)
        crop_left = max(crop_left, 0)
        crop_right = min(crop_right, width)
        return int(crop_top), int(crop_bottom), int(crop_left), int(crop_right)

    def prepare_crop_with_face(self, image: np.ndarray, model_path: str, device: str):
        """Crop ``image`` around the largest detected face.

        Args:
            image: image array indexed as ``[row, column, ...]``.
            model_path: path of the YOLO face detection weights.
            device: device passed to the YOLO predictor.

        Returns:
            The cropped view of ``image``, or ``image`` itself when no face is
            detected.
        """
        yolo = self._face_models.get(model_path)
        if yolo is None:
            # Load the weights once and reuse them for every later image.
            yolo = self._face_models[model_path] = YOLO(model_path)

        pred = yolo.predict(source=image, conf=0.25, device=device)
        bboxes = pred[0].boxes.xyxy.cpu().numpy()
        if bboxes.size == 0:
            return image

        height, width = image.shape[:2]
        # Pick the face with the largest area (quantised to reso_steps).
        face_left, face_top, face_right, face_bottom = max(
            bboxes.tolist(),
            key=lambda box: (box[2] - box[0]) * (box[3] - box[1]) // self.reso_steps * self.reso_steps,
        )
        center_x = int(abs(face_left - face_right) / 2 + face_left)
        center_y = int(abs(face_top - face_bottom) / 2 + face_top)

        crop_top, crop_bottom, crop_left, crop_right = self._get_crop_boundaries(center_x, center_y, width, height)
        return image[crop_top:crop_bottom, crop_left:crop_right]
# Extensions (with the dot) of the files that are treated as images.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.webp')


def _center_crop_ltrb(bucket_reso, image_size):
    """Return (left, top, right, bottom) of the centred bucket-sized window inside image_size."""
    crop_left = max(0, (image_size[0] - bucket_reso[0]) // 2)
    crop_top = max(0, (image_size[1] - bucket_reso[1]) // 2)
    crop_right = min(image_size[0], crop_left + bucket_reso[0])
    crop_bottom = min(image_size[1], crop_top + bucket_reso[1])
    return crop_left, crop_top, crop_right, crop_bottom


def process_images(args, bucket_manager):
    """Resize and crop every image under ``args.src_dir`` into its bucket resolution.

    The directory layout below ``args.src_dir`` is mirrored under
    ``args.dest_dir``. With ``args.dry_run`` nothing is written (no
    directories, no images); with ``args.debug`` or ``args.dry_run`` a
    per-file table and a per-bucket summary are printed.

    Args:
        args: parsed command line namespace (``src_dir``, ``dest_dir``,
            ``debug``, ``dry_run``, ``override``, ``face_detect``,
            ``face_detect_model``, ``device``).
        bucket_manager: ``BucketManager`` used to choose buckets.
    """
    src_directory, dst_directory = args.src_dir, args.dest_dir
    show_table = args.debug or args.dry_run

    if not args.dry_run:
        os.makedirs(dst_directory, exist_ok=True)

    bucket_images = {}

    if show_table:
        t = PrettyTable(['File', 'Reso', 'Original', 'Resized', 'Crop L/T', 'Crop R/B'])

    for root, _dirs, files in os.walk(src_directory):
        # Keep the original directory structure.
        rel_path = os.path.relpath(root, src_directory)
        save_dir = os.path.join(dst_directory, rel_path)
        if not args.dry_run:
            os.makedirs(save_dir, exist_ok=True)

        for file in files:
            if not file.lower().endswith(_IMAGE_EXTENSIONS):
                continue

            file_path = os.path.join(root, file)
            # Best effort per file: one broken image must not abort the run.
            try:
                if os.path.exists(os.path.join(save_dir, file)):
                    if args.override:
                        print(f"Overriding {file}")
                    else:
                        print(f"Skipping {file}")
                        continue

                image = cv2.imread(file_path)
                if image is None:
                    # cv2.imread signals an unreadable file by returning None.
                    print(f"Error processing {file}: unable to read image")
                    continue

                orig_height, orig_width = image.shape[:2]
                height, width = orig_height, orig_width

                if args.face_detect:
                    print(f"Detecting faces in {file}")
                    # Crop around the detected face (centres the face, which
                    # can change the bucket the image ends up in).
                    image = bucket_manager.prepare_crop_with_face(image, args.face_detect_model, args.device)
                    height, width = image.shape[:2]

                # Choose the bucket and the size to resize to.
                reso, resized_size, _ = bucket_manager.select_bucket(width, height)
                # Centre crop so the saved image has exactly the bucket size.
                crop_left, crop_top, crop_right, crop_bottom = _center_crop_ltrb(reso, resized_size)

                if show_table:
                    t.add_row([file, reso, (orig_width, orig_height), resized_size, (crop_left, crop_top), (crop_right, crop_bottom)])
                    bkey = 'x'.join([str(r) for r in reso])
                    bucket_images.setdefault(bkey, []).append(file)

                if args.dry_run:
                    continue

                image = cv2.resize(image, resized_size, interpolation=cv2.INTER_LANCZOS4)
                image = image[crop_top:crop_bottom, crop_left:crop_right].copy()

                # NOTE(review): this keeps every processed image in memory for
                # the whole run; nothing in this script reads it back.
                bucket_manager.add_image(reso, (file, image, (width, height), (crop_left, crop_top)))

                cv2.imwrite(os.path.join(save_dir, file), image)
                print(f"Processed, cropped and saved {file} to {save_dir}")
            except Exception as e:
                print(f"Error processing {file}: {e}")

    if show_table:
        print(t)

        t = PrettyTable(['Buckets', 'Image Count', 'List'])
        t.align['List'] = 'l'
        for k, v in bucket_images.items():
            t.add_row([k, len(v), ', '.join(v)])
        print(t)
def setup_parser() -> argparse.ArgumentParser:
    """Build the command line parser for the bucket/crop script.

    Returns:
        A parser exposing the source/destination directories, the bucket
        configuration, the face detection settings and the run-mode switches.
    """
    # (flag, keyword arguments for add_argument), in the order they are shown.
    option_table = (
        ("--src_dir", {"type": str, "default": None, "help": "Path to source directory"}),
        ("--dest_dir", {"type": str, "default": None, "help": "Path to destination directory"}),
        (
            "--disable_no_upscale",
            {"action": "store_true", "help": "Turn off no upscale, no_upscale default is True"},
        ),
        ("--max_reso_width", {"type": int, "default": 512, "help": "Max resolution width, default: 512"}),
        ("--max_reso_height", {"type": int, "default": 512, "help": "Max resolution height, default: 512"}),
        ("--min_size", {"type": int, "default": 256, "help": "Minimum size, default: 256"}),
        ("--max_size", {"type": int, "default": 1024, "help": "Maximum size, default: 1024"}),
        ("--reso_steps", {"type": int, "default": 64, "help": "Resolution steps, default: 64"}),
        (
            "--face_detect",
            {"action": "store_true", "help": "Detect faces in the image and crop to the face"},
        ),
        (
            "--face_detect_model",
            {
                "type": str,
                "default": "./yolov8n-face.pt",
                "help": "Path to the face detection model file, download from: https://github.com/akanametov/yolov8-face/releases/download/v0.0.0/yolov8n-face.pt",
            },
        ),
        ("--device", {"type": str, "default": "cuda", "help": "YOLO device, default: cuda"}),
        ("--dry_run", {"action": "store_true", "help": "Print out the image info without saving"}),
        ("--debug", {"action": "store_true", "help": "Print out the image info"}),
        (
            "--override",
            {"action": "store_true", "help": "Override existing file in the destination directory"},
        ),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli
if __name__ == '__main__':
    # Script entry point: validate the arguments, build the bucket manager
    # and process the source directory.
    parser = setup_parser()
    args = parser.parse_args()

    # Both directories are mandatory; show the usage when either is missing.
    if args.src_dir is None or args.dest_dir is None:
        parser.print_help(sys.stderr)
        sys.exit(1)

    # sys.exit is used instead of the site-provided exit(), which is not
    # guaranteed to exist (e.g. with `python -S`).
    if args.face_detect and args.face_detect_model is None:
        print("Face detection model file is required.\n", file=sys.stderr)
        sys.exit(1)

    if args.face_detect and not os.path.exists(args.face_detect_model):
        print(f"Face detection model file {args.face_detect_model} not found.\n", file=sys.stderr)
        sys.exit(1)

    bucket_manager = BucketManager(
        no_upscale=not args.disable_no_upscale,
        max_reso=(args.max_reso_width, args.max_reso_height),
        min_size=args.min_size,
        max_size=args.max_size,
        reso_steps=args.reso_steps,
    )

    # Predefined buckets are only consulted when upscaling is allowed.
    if args.disable_no_upscale:
        bucket_manager.make_buckets()

    process_images(args, bucket_manager)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment