Created December 16, 2019 03:51
-
-
Save AkashiSN/7301be858287ecb9cb4f98c8b8491e10 to your computer and use it in GitHub Desktop.
Tool to convert DetectNet-format label files to YOLOv3-format label files.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import os | |
import sys | |
import glob | |
from PIL import Image | |
# Class names as they appear in field 0 of each DetectNet label line, listed
# in id order: a name's position in this tuple is the integer class id that is
# written into the converted YOLO label rows. The dict preserves this order, so
# iterating it yields the names from id 0 upward.
classes = {
    name: idx
    for idx, name in enumerate((
        'Car',
        'Van',
        'Truck',
        'Pedestrian',
        'Person_sitting',
        'Cyclist',
        'Tram',
        'Misc',
        'DontCare',
    ))
}
def convert(detectnet_data, img_width, img_height, class_map=None):
    """Convert DetectNet label lines into YOLO label rows.

    Each input line is whitespace-separated. Field 0 is the class name and
    fields 4-7 are the box corners ``min_x min_y max_x max_y`` in pixels.
    Each output row is ``[class_id, center_x, center_y, width, height]``
    as strings. The box values are normalized by the image width/height.

    Args:
        detectnet_data: Iterable of label lines (e.g. from ``readlines()``).
            Blank or whitespace-only lines are skipped.
        img_width: Image width in pixels, used to normalize x values.
        img_height: Image height in pixels, used to normalize y values.
        class_map: Mapping from class name to integer id. Defaults to the
            module-level ``classes`` table.

    Returns:
        A list with one 5-element list of strings per non-blank input line.

    Raises:
        KeyError: If a line's class name is not in ``class_map``.
        IndexError: If a non-blank line has fewer than 8 fields.
        ValueError: If a coordinate field is not a valid number.
    """
    if class_map is None:
        class_map = classes
    yolo_data = []
    for data in detectnet_data:
        # split() with no argument tolerates runs of spaces/tabs and drops the
        # trailing newline; split(" ") would create empty fields and shift the
        # column indices used below.
        splits = data.split()
        if not splits:
            # e.g. a trailing empty line; previously this raised KeyError('').
            continue
        class_idx = class_map[splits[0]]
        min_x = float(splits[4])
        min_y = float(splits[5])
        max_x = float(splits[6])
        max_y = float(splits[7])
        width = (max_x - min_x) / img_width
        height = (max_y - min_y) / img_height
        center_x = ((min_x + max_x) / 2) / img_width
        center_y = ((min_y + max_y) / 2) / img_height
        yolo_data.append([str(class_idx), str(center_x), str(center_y), str(width), str(height)])
    return yolo_data
if __name__ == "__main__":
    # Usage: <script> IMAGE_DIR LABEL_DIR
    # Converts every LABEL_DIR/*.txt file to YOLO format. The matching
    # IMAGE_DIR/<name>.png is opened only to read its pixel size. Results are
    # written to ./yolo_label/<name>.txt, plus a classes.txt name listing.
    if len(sys.argv) < 3:
        # Both positional arguments are required. Checking only for "no args"
        # let a single argument through to an IndexError on sys.argv[2].
        print("need arg image dir and label dir", file=sys.stderr)
        sys.exit(1)
    image_dir = sys.argv[1]
    label_dir = sys.argv[2]
    print("input image dir :", image_dir)
    print("input label dir :", label_dir)

    # Output always goes under the current working directory.
    output_dir = os.path.join(os.getcwd(), "yolo_label")
    print("output dir :", output_dir)
    os.makedirs(output_dir, exist_ok=True)

    label_files = glob.glob(os.path.join(label_dir, "*.txt"))
    print("number of input files :", len(label_files))

    count = 0
    for label_file in label_files:
        print(label_file)
        with open(label_file, mode='r') as lf:
            detectnet_data = lf.readlines()

        label_filename = os.path.splitext(os.path.basename(label_file))[0]
        # Only the image size is needed. The context manager releases the file
        # handle that Image.open keeps open; previously it leaked one per image.
        with Image.open(os.path.join(image_dir, label_filename + ".png")) as im:
            w, h = im.size
        yolo_data = convert(detectnet_data, w, h)

        with open(os.path.join(output_dir, label_filename + ".txt"), mode='w') as f:
            for data in yolo_data:
                print(data)
                f.write(' '.join(data) + '\n')
        count += 1
    print("number of output files :", count)

    # Order names by class id so that line i of classes.txt names class id i,
    # independent of the dict's insertion order.
    with open(os.path.join(output_dir, "classes.txt"), 'w') as f:
        for cl in sorted(classes, key=classes.get):
            f.write(cl)
            f.write('\n')
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment