Last active
April 10, 2023 13:23
-
-
Save e96031413/7057d1a6716816f4718bb2a0dbc7ea55 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import multiprocessing as mp
import os
import random
from shutil import copyfile

import numpy as np
from PIL import Image
class DarkFace2YOLOv5:
    """Convert a DarkFace-style dataset into the YOLOv5 directory layout.

    Input layout (read):
        ``<data_dir>/image/*.png`` and ``<data_dir>/label/*.txt``.
        Each label file has one leading line that is skipped (presumably the
        box count -- confirm against the dataset), followed by one
        ``x_min y_min x_max y_max`` box per line, in pixels.

    Output layout (written):
        ``<data_dir>/images/{train,val}/<name>.png`` and
        ``<data_dir>/labels/{train,val}/<name>.txt``. Images and labels share
        the same split sub-directory because YOLOv5 locates a label by
        replacing ``/images/`` with ``/labels/`` in the image path.

    Every box is written with class id 0; ``class_list`` is stored but not
    otherwise used by this class.
    """

    def __init__(self, data_dir, class_list, train_ratio=0.8, random_seed=None):
        """Store the conversion settings.

        Args:
            data_dir: Dataset root holding the ``image`` and ``label`` folders.
            class_list: Class names (kept for reference only).
            train_ratio: Fraction of samples assigned to the training split.
            random_seed: Seed for the shuffle, or ``None`` for an unseeded one.
        """
        self.data_dir = data_dir
        self.class_list = class_list
        self.train_ratio = train_ratio
        self.random_seed = random_seed

    def split_dataset(self):
        """Pair every image with its label and assign it to train or val.

        Returns:
            A generator of ``(image_path, label_path, is_train)`` tuples in
            shuffled order; the first ``int(n * train_ratio)`` are training
            samples.

        Raises:
            ValueError: If the number of images and labels differ, or if a
                positionally paired image and label have different stems.
        """
        image_dir = os.path.join(self.data_dir, "image")
        label_dir = os.path.join(self.data_dir, "label")
        image_paths = sorted(
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            if name.endswith(".png")
        )
        label_paths = sorted(
            os.path.join(label_dir, name)
            for name in os.listdir(label_dir)
            if name.endswith(".txt")
        )

        # Explicit raise instead of assert: asserts are stripped under -O.
        if len(image_paths) != len(label_paths):
            raise ValueError("影像與標籤的數量不符")

        paired_paths = list(zip(image_paths, label_paths))

        # Equal counts do not guarantee matching files; check stems so that a
        # stray file cannot silently attach the wrong boxes to an image.
        for image_path, label_path in paired_paths:
            image_stem = os.path.splitext(os.path.basename(image_path))[0]
            label_stem = os.path.splitext(os.path.basename(label_path))[0]
            if image_stem != label_stem:
                raise ValueError(
                    "影像與標籤的檔名不符: {} / {}".format(image_path, label_path)
                )

        # A private Random instance yields the same permutation as seeding the
        # module-level generator, without touching global random state.
        if self.random_seed is not None:
            rng = random.Random(self.random_seed)
        else:
            rng = random
        rng.shuffle(paired_paths)

        num_train = int(len(paired_paths) * self.train_ratio)
        return (
            (image_path, label_path, index < num_train)
            for index, (image_path, label_path) in enumerate(paired_paths)
        )

    def convert_annotation(self, label_path, image_size):
        """Read corner-format pixel boxes and return YOLO-normalized rows.

        Args:
            label_path: Path to a label file (first line skipped, then one
                ``x_min y_min x_max y_max`` box per non-blank line).
            image_size: Either a single number used to normalize both axes, or
                a ``(width, height)`` pair used to normalize x and y
                separately.

        Returns:
            A float array of shape ``(n, 5)`` with columns
            ``class_id, x_center, y_center, width, height``. ``n`` is 0 for a
            file without boxes.

        Raises:
            ValueError: If ``image_size`` is neither a scalar nor a pair.
        """
        with open(label_path, "r") as f:
            lines = f.readlines()

        # Skip the header line and ignore blank lines (e.g. a trailing newline).
        boxes = [
            [float(value) for value in line.split()]
            for line in lines[1:]
            if line.strip()
        ]
        if boxes:
            coords = np.array(boxes, dtype=float)
        else:
            # Keep the array two-dimensional so the slicing below still works.
            coords = np.empty((0, 4), dtype=float)

        scale = np.asarray(image_size, dtype=float)
        if scale.size not in (1, 2):
            raise ValueError(
                "image_size must be a scalar or a (width, height) pair"
            )
        # Flatten to shape (1,) or (2,) so it broadcasts across the x/y columns.
        scale = scale.reshape(-1)

        top_left = coords[:, 0:2]
        bottom_right = coords[:, 2:4]
        class_id = np.zeros((coords.shape[0], 1), dtype=float)
        xy_center = (top_left + bottom_right) / 2 / scale
        wh = (bottom_right - top_left) / scale
        return np.hstack((class_id, xy_center, wh))

    def process_image(self, image_path, label_path, is_train):
        """Write the YOLO label for one image and copy the image to its split.

        Args:
            image_path: Source image path.
            label_path: Source label path for that image.
            is_train: ``True`` to place the sample in ``train``, else ``val``.
        """
        # splitext keeps dots that are part of the file name itself.
        image_name = os.path.splitext(os.path.basename(image_path))[0]
        with Image.open(image_path) as img:
            # (width, height): x is normalized by width and y by height.
            image_size = img.size
        annotations = self.convert_annotation(label_path, image_size)

        split = "train" if is_train else "val"

        label_out_dir = os.path.join(self.data_dir, "labels", split)
        os.makedirs(label_out_dir, exist_ok=True)
        with open(os.path.join(label_out_dir, image_name + ".txt"), "w") as f:
            f.writelines(
                "{:d} {:.6f} {:.6f} {:.6f} {:.6f}\n".format(
                    int(row[0]), row[1], row[2], row[3], row[4]
                )
                for row in annotations
            )

        dest_dir = os.path.join(self.data_dir, "images", split)
        os.makedirs(dest_dir, exist_ok=True)
        dest_path = os.path.join(dest_dir, os.path.basename(image_path))
        copyfile(image_path, dest_path)

    def convert_dataset(self):
        """Convert the whole dataset using a pool of worker processes."""
        path_gen = self.split_dataset()
        with mp.Pool() as pool:
            # starmap blocks until every sample has been processed.
            pool.starmap(self.process_image, path_gen)
        print("標註已轉換為 YOLOv5 格式")
if __name__ == "__main__":
    # Replace this placeholder with the path to your actual dataset.
    dataset_root = "your_data_directory"
    # This example assumes a single class: face.
    labels = ["face"]
    seed = 42

    # Convert the dataset's annotations to the YOLOv5 format.
    DarkFace2YOLOv5(dataset_root, labels, random_seed=seed).convert_dataset()
    print("資料集已成功轉換為 YOLOv5 格式")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment