Tool to convert DetectNet-format labels to YOLOv3-format labels
#!/usr/bin/env python3 | |
import os | |
import sys | |
import glob | |
from PIL import Image | |
# Object category names mapped to the integer class ids written into the
# YOLO label files.  The ids are simply the positions in the tuple below, and
# the dict's insertion order matches id order, so iterating it yields the
# names in id order.
classes = {
    name: idx
    for idx, name in enumerate(
        (
            "Car",
            "Van",
            "Truck",
            "Pedestrian",
            "Person_sitting",
            "Cyclist",
            "Tram",
            "Misc",
            # NOTE(review): DontCare regions get an ordinary class id and are
            # therefore emitted as YOLO boxes; confirm that is intended.
            "DontCare",
        )
    )
}
def convert(detectnet_data, img_width, img_height, class_map=None):
    """Convert DetectNet label lines to YOLO box rows.

    Each input line is whitespace-separated: field 0 is the class name and
    fields 4-7 are ``min_x min_y max_x max_y`` of the bounding box in pixels.
    Each output row is ``[class_idx, center_x, center_y, width, height]`` as
    strings, with the box geometry normalised by the image size.

    Args:
        detectnet_data: Iterable of label lines (trailing newlines allowed).
            Blank / whitespace-only lines are skipped.
        img_width: Image width in pixels.
        img_height: Image height in pixels.
        class_map: Mapping from class name to integer id.  Defaults to the
            module-level ``classes`` table.

    Returns:
        A list with one 5-element list of strings per non-blank input line.

    Raises:
        KeyError: If a class name is missing from ``class_map``.
        IndexError: If a line has fewer than 8 fields.
        ValueError: If a bounding-box field is not a number.
    """
    if class_map is None:
        class_map = classes
    yolo_data = []
    for line in detectnet_data:
        # split() with no separator tolerates repeated spaces, tabs and the
        # trailing newline; split(" ") would yield empty fields and shift the
        # bbox column indices.
        fields = line.split()
        if not fields:
            continue  # blank line, e.g. an empty last line in the file
        class_idx = class_map[fields[0]]
        min_x = float(fields[4])
        min_y = float(fields[5])
        max_x = float(fields[6])
        max_y = float(fields[7])
        width = (max_x - min_x) / img_width
        height = (max_y - min_y) / img_height
        center_x = ((min_x + max_x) / 2) / img_width
        center_y = ((min_y + max_y) / 2) / img_height
        yolo_data.append(
            [str(class_idx), str(center_x), str(center_y), str(width), str(height)]
        )
    return yolo_data
def main():
    """Convert every ``*.txt`` label in LABEL_DIR to YOLO format.

    Usage: script IMAGE_DIR LABEL_DIR

    For each label file ``<name>.txt`` the matching ``<name>.png`` in
    IMAGE_DIR is opened to read its pixel size.  Converted labels are written
    to ``./yolo_label/<name>.txt`` (relative to the current working
    directory), and the class names, one per line in class-id order, to
    ``./yolo_label/classes.txt``.
    """
    # Both directories are required: sys.argv[2] is read below.
    if len(sys.argv) < 3:
        print("need arg image dir and label dir", file=sys.stderr)
        sys.exit(1)
    image_dir = sys.argv[1]
    label_dir = sys.argv[2]
    print("input image dir :", image_dir)
    print("input label dir :", label_dir)
    output_dir = os.path.join(os.getcwd(), "yolo_label")
    print("output dir :", output_dir)
    os.makedirs(output_dir, exist_ok=True)

    # Sorted so files are processed in a deterministic order.
    label_files = sorted(glob.glob(os.path.join(label_dir, "*.txt")))
    print("number of input files :", len(label_files))
    count = 0
    for label_file in label_files:
        print(label_file)
        with open(label_file, mode="r") as lf:
            detectnet_data = lf.readlines()
        label_filename = os.path.splitext(os.path.basename(label_file))[0]
        # Only the size is needed; the context manager releases the file
        # handle, which would otherwise stay open for every processed image.
        with Image.open(os.path.join(image_dir, label_filename + ".png")) as im:
            w, h = im.size
        yolo_data = convert(detectnet_data, w, h)
        with open(os.path.join(output_dir, label_filename + ".txt"), mode="w") as f:
            for data in yolo_data:
                print(data)
                f.write(" ".join(data) + "\n")
        count += 1
    print("number of output files :", count)

    # Line i of classes.txt must be the name of class id i.
    with open(os.path.join(output_dir, "classes.txt"), "w") as f:
        for cl in sorted(classes, key=classes.get):
            f.write(cl + "\n")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment