Skip to content

Instantly share code, notes, and snippets.

@zhaoweizhong
Created May 4, 2021 19:04
Show Gist options
  • Save zhaoweizhong/9dab05e5e404fd61793503c5d0ee81fd to your computer and use it in GitHub Desktop.
Save zhaoweizhong/9dab05e5e404fd61793503c5d0ee81fd to your computer and use it in GitHub Desktop.
Transform GTSDB Dataset Annotations to YOLO Format
import argparse
import json
import os

from rich.progress import track
def load_txt(file_name):
    """Read a text file and return its lines without newline characters.

    Args:
        file_name: Path of the text file to read (for this script, the
            semicolon-separated ground-truth annotation file).

    Returns:
        list[str]: The file's lines in order, with newline characters
        removed. Blank lines are kept as empty strings.
    """
    # ``with`` guarantees the handle is closed even if reading fails; the
    # previous version opened the file and never closed it.
    with open(file_name, 'r') as file:
        return [line.replace('\n', '') for line in file]
def parse(data, count=900, train_ratio=0.7, img_width=1360, img_height=800,
          image_dir='../gtsdb/images/', label_dir='labels', list_dir='yolo'):
    """Convert semicolon-separated box annotations into YOLO label files.

    Each element of *data* is one annotation line of the form
    ``<filename>;<xmin>;<ymin>;<xmax>;<ymax>;<class_id>``. The first five
    characters of ``<filename>`` are used as the numeric image id and as
    the stem of both the image name (``<stem>.jpg``) and the label file
    (``<label_dir>/<stem>.txt``).

    For every annotation, one line ``"<class_id> <x_center> <y_center>
    <width> <height>"`` is written to the image's label file, with all
    coordinates divided by the image size. Each annotated image is listed
    once, prefixed by *image_dir*: in ``<list_dir>/train.txt`` if its id is
    below ``int(count * train_ratio)``, otherwise in ``<list_dir>/test.txt``.
    Images with no annotation line in *data* are not listed.

    Args:
        data: Iterable of annotation strings (e.g. from ``load_txt``).
            Blank lines are skipped.
        count: Total number of images; together with *train_ratio* it sets
            the train/test id threshold.
        train_ratio: Fraction of ``count`` whose ids go to the train split.
        img_width: Image width in pixels used to normalise x values
            (assumed identical for every image).
        img_height: Image height in pixels used to normalise y values
            (assumed identical for every image).
        image_dir: Prefix written before each image name in the split
            lists. NOTE(review): presumably relative to where the training
            tool reads the lists; confirm for your layout.
        label_dir: Directory that receives one ``.txt`` label file per image.
        list_dir: Directory that receives ``train.txt`` and ``test.txt``.

    Raises:
        ValueError / IndexError: if a non-blank line does not have the
            expected six ``;``-separated fields with integer coordinates.
    """
    # Ids below this threshold go to the training split (630 with defaults).
    count_train = int(count * train_ratio)

    # Create output folders up front instead of failing on the first open().
    os.makedirs(label_dir, exist_ok=True)
    os.makedirs(list_dir, exist_ok=True)

    result_train = []
    result_test = []
    # Image names already listed in a split. A set makes the duplicate check
    # O(1); the previous list scan was O(n) per annotation (O(n^2) overall).
    # The split depends only on the image id, so one set covers both lists.
    seen = set()

    for annotation in track(data):
        if not annotation.strip():
            # Tolerate blank lines (e.g. a trailing empty line in the file).
            continue
        fields = annotation.split(';')
        stem = fields[0][:5]
        img_id = int(stem)
        img_name = stem + '.jpg'
        xmin, ymin, xmax, ymax = (int(value) for value in fields[1:5])
        class_id = fields[5]

        # Box centre and size, normalised by the image dimensions.
        x_center = ((xmax - xmin) / 2 + xmin) / img_width
        y_center = ((ymax - ymin) / 2 + ymin) / img_height
        width = (xmax - xmin) / img_width
        height = (ymax - ymin) / img_height

        first_box = img_name not in seen
        if first_box:
            seen.add(img_name)
            split = result_train if img_id < count_train else result_test
            split.append(os.path.join(image_dir, img_name))

        # Truncate the label file on the image's first box in this run and
        # append afterwards, so re-running the script does not duplicate
        # boxes left over from a previous run (the old code always appended).
        mode = 'w' if first_box else 'a'
        with open(os.path.join(label_dir, stem + '.txt'), mode) as f:
            f.write(f'{class_id} {x_center} {y_center} {width} {height}\n')

    for list_name, paths in (('train.txt', result_train), ('test.txt', result_test)):
        with open(os.path.join(list_dir, list_name), 'w') as f:
            f.writelines(path + '\n' for path in paths)
if __name__ == '__main__':
    # Command-line entry point: read the annotation file named by
    # -f/--file_name (default 'gt.txt') and convert it with parse().
    cli = argparse.ArgumentParser()
    cli.add_argument('-f', '--file_name', type=str, default='gt.txt')
    options = cli.parse_args()
    parse(load_txt(options.file_name))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment