Skip to content

Instantly share code, notes, and snippets.

@hinablue
Last active April 19, 2024 13:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hinablue/abb2c5454c20ea5636b081bcda1ad938 to your computer and use it in GitHub Desktop.
Save hinablue/abb2c5454c20ea5636b081bcda1ad938 to your computer and use it in GitHub Desktop.
import os
import sys
from PIL import Image
import cv2
import numpy as np
import math
from typing import NamedTuple, Tuple
import argparse
from prettytable import PrettyTable
from ultralytics import YOLO
class BucketManager:
    """Assigns images to resolution buckets and computes resize/crop geometry.

    Two modes are supported by ``select_bucket``:

    * ``no_upscale=False``: the image is scaled (up or down) to the
      predefined bucket whose aspect ratio is closest; ``make_buckets`` (or
      ``set_predefined_resos``) must have been called first.
    * ``no_upscale=True``: the image is only ever shrunk (when its area
      exceeds ``max_reso``) and the bucket is the resized size rounded down to
      a multiple of ``reso_steps``.
    """

    def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
        """Store the bucketing configuration.

        Args:
            no_upscale: When true, images are never enlarged.
            max_reso: ``(width, height)`` whose product is the maximum bucket
                area, or ``None``.
            min_size: Smallest allowed bucket side, or ``None``.
            max_size: Largest allowed bucket side, or ``None``.
            reso_steps: Bucket sides are multiples of this value.

        Raises:
            AssertionError: If ``max_size`` is smaller than either side of
                ``max_reso`` or smaller than ``min_size``.
        """
        if max_size is not None:
            if max_reso is not None:
                assert max_size >= max_reso[0], "the max_size should be larger than the width of max_reso"
                assert max_size >= max_reso[1], "the max_size should be larger than the height of max_reso"
            if min_size is not None:
                assert max_size >= min_size, "the max_size should be larger than the min_size"

        self.no_upscale = no_upscale
        if max_reso is None:
            self.max_reso = None
            self.max_area = None
        else:
            self.max_reso = max_reso
            self.max_area = max_reso[0] * max_reso[1]
        self.min_size = min_size
        self.max_size = max_size
        self.reso_steps = reso_steps

        self.resos = []
        self.reso_to_id = {}
        # During preprocessing each entry is (image_key, image, original size,
        # crop left/top); during training each entry is just image_key.
        self.buckets = []
        # Face-detection models keyed by model path, so that a model is loaded
        # once instead of once per image.
        self._yolo_models = {}

    def add_image(self, reso, image_or_info):
        """Append ``image_or_info`` to the bucket registered for ``reso``.

        Raises:
            KeyError: If ``reso`` has not been registered yet.
        """
        bucket_id = self.reso_to_id[reso]
        self.buckets[bucket_id].append(image_or_info)

    def shuffle(self):
        """Shuffle the contents of every bucket in place."""
        # Imported locally: the module's import block does not provide
        # ``random``, so the original call raised NameError.
        import random

        for bucket in self.buckets:
            random.shuffle(bucket)

    def sort(self):
        """Reorder buckets by resolution and renumber ``reso_to_id``.

        Sorting only exists to make displayed/stored output look tidy.
        """
        sorted_resos = sorted(self.resos)
        sorted_buckets = []
        sorted_reso_to_id = {}
        for new_id, reso in enumerate(sorted_resos):
            sorted_buckets.append(self.buckets[self.reso_to_id[reso]])
            sorted_reso_to_id[reso] = new_id

        self.resos = sorted_resos
        self.buckets = sorted_buckets
        self.reso_to_id = sorted_reso_to_id

    def make_bucket_resolutions(self, max_reso, min_size=256, max_size=1024, divisible=64):
        """Enumerate bucket resolutions whose area does not exceed ``max_reso``.

        Args:
            max_reso: ``(width, height)``; their product is the area budget.
            min_size: Smallest side to generate.
            max_size: Largest side to generate.
            divisible: Every side is a multiple of this value (given that
                ``min_size`` and ``max_size`` are).

        Returns:
            A sorted list of unique ``(width, height)`` tuples containing the
            square bucket plus each width/height pair and its transpose.
        """
        max_width, max_height = max_reso
        max_area = max_width * max_height

        resos = set()

        # Largest square that fits in the area budget.
        side = int(math.sqrt(max_area) // divisible) * divisible
        resos.add((side, side))

        width = min_size
        while width <= max_size:
            # Tallest height (multiple of ``divisible``) that keeps
            # width * height within the budget, capped at max_size.
            height = min(max_size, int((max_area // width) // divisible) * divisible)
            if height >= min_size:
                resos.add((width, height))
                resos.add((height, width))
            width += divisible

        return sorted(resos)

    def make_buckets(self):
        """Generate the predefined resolutions from this manager's settings."""
        resos = self.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps)
        self.set_predefined_resos(resos)

    def set_predefined_resos(self, resos):
        """Store the resolutions (and their aspect ratios) to choose from."""
        self.predefined_resos = resos.copy()
        self.predefined_resos_set = set(resos)
        self.predefined_aspect_ratios = np.array([w / h for w, h in resos])

    def add_if_new_reso(self, reso):
        """Register ``reso`` with a fresh empty bucket unless already known."""
        if reso not in self.reso_to_id:
            bucket_id = len(self.resos)
            self.reso_to_id[reso] = bucket_id
            self.resos.append(reso)
            self.buckets.append([])

    def round_to_steps(self, x):
        """Round ``x`` to the nearest int, then down to a multiple of ``reso_steps``."""
        x = int(x + 0.5)
        return x - x % self.reso_steps

    def select_bucket(self, image_width, image_height):
        """Choose the bucket for an image and the size to resize it to.

        Args:
            image_width: Source image width in pixels.
            image_height: Source image height in pixels.

        Returns:
            ``(reso, resized_size, ar_error)`` where ``reso`` is the bucket
            ``(width, height)``, ``resized_size`` is the size the image should
            be resized to before cropping, and ``ar_error`` is the bucket
            aspect ratio minus the image aspect ratio.
        """
        aspect_ratio = image_width / image_height
        if not self.no_upscale:
            # Scale up or down. An exact resolution match is preferred because
            # the same aspect ratio may already exist (e.g. fine tuning on data
            # preprocessed with no_upscale=True).
            reso = (image_width, image_height)
            if reso not in self.predefined_resos_set:
                # Otherwise take the predefined bucket with the smallest
                # aspect-ratio error.
                ar_errors = self.predefined_aspect_ratios - aspect_ratio
                predefined_bucket_id = np.abs(ar_errors).argmin()
                reso = self.predefined_resos[predefined_bucket_id]

            ar_reso = reso[0] / reso[1]
            if aspect_ratio > ar_reso:
                # Image is wider than the bucket: match the height.
                scale = reso[1] / image_height
            else:
                scale = reso[0] / image_width

            resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5))
        else:
            # Downscale only.
            if image_width * image_height > self.max_area:
                # The image is too large: pick the bucket assuming it is
                # shrunk to max_area while keeping its aspect ratio.
                resized_width = math.sqrt(self.max_area * aspect_ratio)
                resized_height = self.max_area / resized_width
                assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal"

                # Snap either the width or the height to reso_steps and keep
                # whichever choice distorts the aspect ratio less.
                b_width_rounded = self.round_to_steps(resized_width)
                b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio)
                ar_width_rounded = b_width_rounded / b_height_in_wr

                b_height_rounded = self.round_to_steps(resized_height)
                b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio)
                ar_height_rounded = b_width_in_hr / b_height_rounded

                if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio):
                    resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + 0.5))
                else:
                    resized_size = (int(b_height_rounded * aspect_ratio + 0.5), b_height_rounded)
            else:
                # Small enough already: no resize needed.
                resized_size = (image_width, image_height)

            # The bucket is never larger than the image (crop, do not pad).
            bucket_width = resized_size[0] - resized_size[0] % self.reso_steps
            bucket_height = resized_size[1] - resized_size[1] % self.reso_steps
            reso = (bucket_width, bucket_height)

        self.add_if_new_reso(reso)

        ar_error = (reso[0] / reso[1]) - aspect_ratio
        return reso, resized_size, ar_error

    @staticmethod
    def get_crop_ltrb(bucket_reso: Tuple[int, int], image_size: Tuple[int, int]):
        """Return where an aspect-preserved image sits inside a bucket.

        The image is scaled to fit inside ``bucket_reso`` and centred; the
        result is the ``(left, top, right, bottom)`` of that fitted rectangle
        in bucket coordinates. Per the original note this follows Stability
        AI's preprocessing; the right/bottom values exist for flip
        augmentation.
        """
        bucket_ar = bucket_reso[0] / bucket_reso[1]
        image_ar = image_size[0] / image_size[1]
        if bucket_ar > image_ar:
            # Bucket is wider than the image: match the height.
            resized_width = bucket_reso[1] * image_ar
            resized_height = bucket_reso[1]
        else:
            resized_width = bucket_reso[0]
            resized_height = bucket_reso[0] / image_ar
        crop_left = (bucket_reso[0] - resized_width) // 2
        crop_top = (bucket_reso[1] - resized_height) // 2
        crop_right = crop_left + resized_width
        crop_bottom = crop_top + resized_height
        return int(crop_left), int(crop_top), int(crop_right), int(crop_bottom)

    def _face_crop_box(self, center_x, center_y, width, height):
        """Compute a crop box grown symmetrically around a centre point.

        From the centre, each edge is stepped outwards in ``reso_steps`` units
        (scaled per axis by the ``max_reso`` aspect ratio) until it reaches the
        image border. The smallest step count over the four directions is then
        applied in every direction and the result is clamped to the image.

        Returns:
            ``(top, bottom, left, right)`` as ints.
        """

        def count_steps(start, raw_step, limit):
            # Number of grid steps from ``start`` until the border is reached:
            # 0 for a negative step, ``limit`` for a positive one.
            step = round(raw_step)
            if step == 0:
                # A very small ratio could round the step to 0, which would
                # divide by zero below; fall back to a single-pixel step.
                step = 1 if raw_step > 0 else -1
            boundary, count = start, 0
            while True:
                boundary = (boundary + step) // step * step
                count += 1
                if (step > 0 and boundary >= limit) or (step < 0 and boundary <= 0):
                    return count

        # Index 0 scales vertical steps, index 1 scales horizontal steps.
        step_ratio = [1, 1]
        if self.max_reso[0] > self.max_reso[1]:
            step_ratio[0] = 1 / (self.max_reso[0] / self.max_reso[1])
        else:
            step_ratio[1] = 1 / (self.max_reso[1] / self.max_reso[0])

        v_step = self.reso_steps * step_ratio[0]
        h_step = self.reso_steps * step_ratio[1]

        # Vertical growth is bounded by the image height and horizontal growth
        # by its width. (Previously the width bounded every direction.)
        min_counter = min(
            count_steps(center_y, -v_step, height),
            count_steps(center_y, v_step, height),
            count_steps(center_x, -h_step, width),
            count_steps(center_x, h_step, width),
        )

        crop_top = (center_y - v_step * min_counter) // self.reso_steps * self.reso_steps
        crop_bottom = (center_y + v_step * min_counter) // self.reso_steps * self.reso_steps
        crop_left = (center_x - h_step * min_counter) // self.reso_steps * self.reso_steps
        crop_right = (center_x + h_step * min_counter) // self.reso_steps * self.reso_steps

        crop_top = max(crop_top, 0)
        crop_bottom = min(crop_bottom, height)
        crop_left = max(crop_left, 0)
        crop_right = min(crop_right, width)

        return int(crop_top), int(crop_bottom), int(crop_left), int(crop_right)

    def prepare_crop_with_face(self, image: np.ndarray, model_path: str, device: str):
        """Crop ``image`` around the largest detected face.

        Args:
            image: Image array indexed as ``[row, column, ...]``.
            model_path: Path of the YOLO weights passed to ``YOLO(...)``.
            device: Device string forwarded to ``predict``.

        Returns:
            The cropped array, or ``image`` unchanged when no box is detected
            or the computed crop would be empty.
        """
        yolo = self._yolo_models.get(model_path)
        if yolo is None:
            yolo = YOLO(model_path)
            self._yolo_models[model_path] = yolo

        pred = yolo.predict(source=image, conf=0.25, device=device)
        bboxes = pred[0].boxes.xyxy.cpu().numpy()
        if bboxes.size == 0:
            return image

        height, width = image.shape[:2]

        # Pick the box with the largest area (quantised to reso_steps); the
        # first box wins ties.
        face_left, face_top, face_right, face_bottom = max(
            bboxes.tolist(),
            key=lambda box: (box[2] - box[0]) * (box[3] - box[1]) // self.reso_steps * self.reso_steps,
        )

        center_x = int(abs(face_left - face_right) / 2 + face_left)
        center_y = int(abs(face_top - face_bottom) / 2 + face_top)

        crop_top, crop_bottom, crop_left, crop_right = self._face_crop_box(center_x, center_y, width, height)
        if crop_bottom <= crop_top or crop_right <= crop_left:
            # Grid rounding collapsed the box; an empty image is useless
            # downstream, so keep the original.
            return image

        return image[crop_top:crop_bottom, crop_left:crop_right]
def process_images(args, bucket_manager):
    """Resize and crop every image under ``args.src_dir`` into its bucket.

    The source tree is walked recursively and mirrored under
    ``args.dest_dir``. For each image the optional face crop is applied, a
    bucket is selected, the image is resized, centre-cropped to the bucket
    resolution and written out under the same file name.

    Args:
        args: Parsed command-line namespace. Reads ``src_dir``, ``dest_dir``,
            ``debug``, ``dry_run``, ``override``, ``face_detect``,
            ``face_detect_model`` and ``device``.
        bucket_manager: The ``BucketManager`` used for bucket selection; every
            processed image is also registered in it via ``add_image``.

    With ``dry_run`` nothing is written (no directories, no images); with
    ``debug`` or ``dry_run`` summary tables are printed at the end. Errors on
    one file are printed and do not stop the run.
    """
    src_directory, dst_directory = args.src_dir, args.dest_dir
    show_report = args.debug or args.dry_run
    write_output = not args.dry_run
    image_suffixes = ('.jpg', '.jpeg', '.png', '.webp')

    if write_output:
        os.makedirs(dst_directory, exist_ok=True)

    bucket_images = {}
    if show_report:
        t = PrettyTable(['File', 'Reso', 'Original', 'Resized', 'Crop L/T', 'Crop R/B'])

    for root, _dirs, files in os.walk(src_directory):
        # Keep the original directory structure.
        rel_path = os.path.relpath(root, src_directory)
        save_dir = os.path.join(dst_directory, rel_path)
        if write_output:
            os.makedirs(save_dir, exist_ok=True)

        for file in files:
            # Require a real extension (with the dot) so that names such as
            # "notjpg" are not mistaken for images.
            if not file.lower().endswith(image_suffixes):
                continue

            file_path = os.path.join(root, file)
            save_path = os.path.join(save_dir, file)
            try:
                if os.path.exists(save_path):
                    if args.override:
                        print(f"Overriding {file}")
                    else:
                        print(f"Skipping {file}")
                        continue

                image = cv2.imread(file_path)
                if image is None:
                    # cv2.imread signals failure by returning None.
                    raise ValueError(f"unable to read image {file_path}")
                orig_height, orig_width = image.shape[:2]
                height, width = orig_height, orig_width

                if args.face_detect:
                    print(f"Detecting faces in {file}")
                    # Crop around the detected face (centres the face, which
                    # may change the bucket that gets selected).
                    image = bucket_manager.prepare_crop_with_face(image, args.face_detect_model, args.device)
                    height, width = image.shape[:2]

                # Pick the bucket and resize.
                reso, resized_size, _ = bucket_manager.select_bucket(width, height)
                image = cv2.resize(image, resized_size, interpolation=cv2.INTER_LANCZOS4)

                # Centre-crop the resized image to the bucket resolution.
                # get_crop_ltrb is deliberately not used here: it returns the
                # letterbox position of the image inside the bucket, and
                # slicing with it produced off-centre output that did not
                # match the bucket size.
                crop_left = max((resized_size[0] - reso[0]) // 2, 0)
                crop_top = max((resized_size[1] - reso[1]) // 2, 0)
                crop_right = crop_left + reso[0]
                crop_bottom = crop_top + reso[1]

                if show_report:
                    t.add_row([file, reso, (orig_width, orig_height), resized_size, (crop_left, crop_top), (crop_right, crop_bottom)])
                    bkey = 'x'.join(str(r) for r in reso)
                    bucket_images.setdefault(bkey, []).append(file)

                if args.dry_run:
                    continue

                # Crop the image.
                image = image[crop_top:crop_bottom, crop_left:crop_right].copy()

                # NOTE(review): this keeps every processed image in memory for
                # the whole run; nothing visible here reads the buckets back.
                bucket_manager.add_image(reso, (file, image, (width, height), (crop_left, crop_top)))

                # cv2.imwrite reports failure through its return value.
                if not cv2.imwrite(save_path, image):
                    raise OSError(f"unable to write image {save_path}")
                print(f"Processed, cropped and saved {file} to {save_dir}")
            except Exception as e:
                # Best effort: report the failure and move on to the next file.
                print(f"Error processing {file}: {e}")

    if show_report:
        print(t)

        t = PrettyTable(['Buckets', 'Image Count', 'List'])
        t.align['List'] = 'l'
        for k, v in bucket_images.items():
            t.add_row([k, len(v), ', '.join(v)])
        print(t)
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the bucketing script.

    Returns:
        An ``argparse.ArgumentParser`` with the source/destination options,
        the bucket sizing options, the face-detection options and the
        ``--dry_run`` / ``--debug`` / ``--override`` switches.
    """

    def text(help_text, default=None):
        # Keyword arguments for a string-valued option.
        return {"type": str, "default": default, "help": help_text}

    def number(help_text, default):
        # Keyword arguments for an integer-valued option.
        return {"type": int, "default": default, "help": help_text}

    def switch(help_text):
        # Keyword arguments for a boolean flag that defaults to False.
        return {"action": "store_true", "help": help_text}

    # Declared in the order they should appear in --help.
    option_table = (
        ("--src_dir", text("Path to source directory")),
        ("--dest_dir", text("Path to destination directory")),
        ("--disable_no_upscale", switch("Turn off no upscale, no_upscale default is True")),
        ("--max_reso_width", number("Max resolution width, default: 512", 512)),
        ("--max_reso_height", number("Max resolution height, default: 512", 512)),
        ("--min_size", number("Minimum size, default: 256", 256)),
        ("--max_size", number("Maximum size, default: 1024", 1024)),
        ("--reso_steps", number("Resolution steps, default: 64", 64)),
        ("--face_detect", switch("Detect faces in the image and crop to the face")),
        (
            "--face_detect_model",
            text(
                "Path to the face detection model file, download from: https://github.com/akanametov/yolov8-face/releases/download/v0.0.0/yolov8n-face.pt",
                "./yolov8n-face.pt",
            ),
        ),
        ("--device", text("YOLO device, default: cuda", "cuda")),
        ("--dry_run", switch("Print out the image info without saving")),
        ("--debug", switch("Print out the image info")),
        ("--override", switch("Override existing file in the destination directory")),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli
if __name__ == '__main__':
    # Script entry point: parse the options, validate them, build the bucket
    # manager and process the source tree.
    parser = setup_parser()
    args = parser.parse_args()

    # Both directories are mandatory; show the usage text when either is missing.
    if args.src_dir is None or args.dest_dir is None:
        parser.print_help(sys.stderr)
        # sys.exit rather than the bare exit(): the latter is injected by the
        # site module and is not guaranteed to exist (e.g. under python -S).
        sys.exit(1)

    if args.face_detect:
        if args.face_detect_model is None:
            print("Face detection model file is required.\n")
            sys.exit(1)
        if not os.path.exists(args.face_detect_model):
            print(f"Face detection model file {args.face_detect_model} not found.\n")
            sys.exit(1)

    bucket_manager = BucketManager(
        no_upscale=not args.disable_no_upscale,
        max_reso=(args.max_reso_width, args.max_reso_height),
        min_size=args.min_size,
        max_size=args.max_size,
        reso_steps=args.reso_steps,
    )

    # The predefined bucket list is only generated when upscaling is enabled.
    if args.disable_no_upscale:
        bucket_manager.make_buckets()

    process_images(args, bucket_manager)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment