Skip to content

Instantly share code, notes, and snippets.

@zhaoweizhong
Created May 3, 2021 21:31
Show Gist options
  • Save zhaoweizhong/60050a0972a9169c3a1825ebf39a6488 to your computer and use it in GitHub Desktop.
Save zhaoweizhong/60050a0972a9169c3a1825ebf39a6488 to your computer and use it in GitHub Desktop.
Transform Tsinghua-Tencent 100K Dataset (ver 2021) Annotations to YOLO Format
import argparse
import copy
import json
import os
def load_json(file_name):
    """Read a JSON file and return the decoded Python object.

    Args:
        file_name: Path of the JSON file to read.

    Returns:
        The object obtained by decoding the file's contents as JSON.

    Raises:
        OSError: If the file cannot be opened or read.
        json.JSONDecodeError: If the contents are not valid JSON.
    """
    # The context manager closes the handle even when decoding fails;
    # the encoding is pinned so the result does not depend on the locale.
    with open(file_name, 'r', encoding='utf-8') as handle:
        return json.load(handle)
def _first_index_map(names):
    """Map each name to the position of its first occurrence in *names*."""
    mapping = {}
    for position, name in enumerate(names):
        # setdefault keeps the first position, matching list.index().
        mapping.setdefault(name, position)
    return mapping


def _yolo_rows(objects, category_ids, image_size):
    """Convert annotated boxes into (class_id, x_center, y_center, w, h) rows.

    Boxes whose category is not in *category_ids* are skipped. All four
    geometric values are divided by *image_size*.
    """
    rows = []
    for box in objects:
        class_id = category_ids.get(box['category'])
        if class_id is None:
            continue
        bbox = box['bbox']
        x_min, x_max = bbox['xmin'], bbox['xmax']
        y_min, y_max = bbox['ymin'], bbox['ymax']
        x_center = ((x_max - x_min) / 2 + x_min) / image_size
        y_center = ((y_max - y_min) / 2 + y_min) / image_size
        width = (x_max - x_min) / image_size
        height = (y_max - y_min) / image_size
        rows.append((class_id, x_center, y_center, width, height))
    return rows


def _write_lines(path, lines):
    """Write every string in *lines* to *path*, each followed by a newline."""
    with open(path, 'w') as out:
        out.writelines(line + '\n' for line in lines)


def parse(data, *, image_size=2048, train_ratio=0.8):
    """Write YOLO label files and train/test image lists for *data*.

    For every image that has at least one object of a known category, a
    label file ``labels/<image stem>.txt`` is written with one
    ``class x_center y_center width height`` line per known object, and the
    image is listed in ``yolo/train.txt`` or ``yolo/test.txt``. The first
    ``int(len(imgs) * train_ratio)`` images (in iteration order, counting
    skipped images too) go to the training list, the rest to the test list.
    The ``labels`` and ``yolo`` directories are created when missing.

    Args:
        data: Dict with an ``'imgs'`` dict; each value has a ``'path'``
            string and an ``'objects'`` list whose items carry ``'category'``
            and a ``'bbox'`` with ``xmin``/``xmax``/``ymin``/``ymax``.
        image_size: Divisor used to normalise both axes; this treats images
            as squares of that side length (NOTE(review): 2048 presumably
            matches the dataset's image size - confirm for other data).
        train_ratio: Fraction of the images assigned to the training list.

    Returns:
        None. All results are written to disk relative to the current
        working directory.
    """
    # Categories: a name's position in this list is its YOLO class id.
    categories = ["pl80", "w9", "p6", "ph4.2", "i8", "w14", "w33", "pa13", "im", "w58", "pl90", "il70", "p5", "pm55", "pl60", "ip", "p11", "pdd", "wc", "i2r", "w30", "pmr", "p23", "pl15", "pm10", "pss", "w1", "p4", "w38", "w50", "w34", "pw3.5", "iz", "w39", "w11", "p1n", "pr70", "pd", "pnl", "pg", "ph5.3", "w66", "il80", "pb", "pbm", "pm5", "w24", "w67", "w49", "pm40", "ph4", "w45", "i4", "w37", "ph2.6", "pl70", "ph5.5", "i14", "i11", "p7", "p29", "pne", "pr60", "pm13", "ph4.5", "p12", "p3", "w40", "pl5", "w13", "pr10", "p14", "i4l", "pr30", "pw4.2", "w16", "p17", "ph3", "i9", "w15", "w35", "pa8", "pt", "pr45", "w17", "pl30", "pcs", "pctl", "pr50", "ph4.4", "pm46", "pm35", "i15", "pa12", "pclr", "i1", "pcd", "pbp", "pcr", "w28", "ps", "pm8", "w18", "w2", "w52", "ph2.9", "ph1.8", "pe", "p20", "w36", "p10", "pn", "pa14", "w54", "ph3.2", "p2", "ph2.5", "w62", "w55", "pw3", "pw4.5", "i12", "ph4.3", "phclr", "i10", "pr5", "i13", "w10", "p26", "w26", "p8", "w5", "w42", "il50", "p13", "pr40", "p25", "w41", "pl20", "ph4.8", "pnlc", "ph3.3", "w29", "ph2.1", "w53", "pm30", "p24", "p21", "pl40", "w27", "pmb", "pc", "i6", "pr20", "p18", "ph3.8", "pm50", "pm25", "i2", "w22", "w47", "w56", "pl120", "ph2.8", "i7", "w12", "pm1.5", "pm2.5", "w32", "pm15", "ph5", "w19", "pw3.2", "pw2.5", "pl10", "il60", "w57", "w48", "w60", "pl100", "pr80", "p16", "pl110", "w59", "w64", "w20", "ph2", "p9", "il100", "w31", "w65", "ph2.4", "pr100", "p19", "ph3.5", "pa10", "pcl", "pl35", "p15", "w7", "pa6", "phcs", "w43", "p28", "w6", "w3", "w25", "pl25", "il110", "p1", "w46", "pn-2", "w51", "w44", "w63", "w23", "pm20", "w8", "pmblr", "w4", "i5", "il90", "w21", "p27", "pl50", "pl65", "w61", "ph2.2", "pm2", "i3", "pa18", "pw4"]
    # Dict lookup replaces the repeated O(n) `in list` / `list.index` scans.
    category_ids = _first_index_map(categories)

    result_train = []
    result_test = []

    # Create the output directories up front so a fresh checkout does not
    # fail with FileNotFoundError on the first write.
    os.makedirs('labels', exist_ok=True)
    os.makedirs('yolo', exist_ok=True)

    # Images and Annotations
    records = data['imgs']
    count_train = int(len(records) * train_ratio)
    for position, record in enumerate(records.values()):
        rows = _yolo_rows(record['objects'], category_ids, image_size)
        if not rows:
            # Images without any known category get no label file and are
            # left out of both lists, but still count towards the split.
            continue

        # Keep only the file name; works with or without a directory part.
        image_name = record['path'].rsplit('/', 1)[-1]
        # splitext drops only the final extension, so names containing
        # extra dots keep their full stem.
        stem = os.path.splitext(image_name)[0]

        target = result_train if position < count_train else result_test
        target.append('../tt100k_2021/images/' + image_name)

        _write_lines(
            os.path.join('labels', stem + '.txt'),
            [' '.join(str(value) for value in row) for row in rows],
        )

    _write_lines(os.path.join('yolo', 'train.txt'), result_train)
    _write_lines(os.path.join('yolo', 'test.txt'), result_test)
if __name__ == '__main__':
    # Script entry point: read the annotation file named on the command
    # line (-f/--file_name, defaulting to annotations_all.json) and
    # convert it.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-f', '--file_name',
        type=str,
        default='annotations_all.json',
    )
    options = cli.parse_args()
    parse(load_json(options.file_name))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment