Skip to content

Instantly share code, notes, and snippets.

@dvsseed
Last active March 18, 2020 08:52
Show Gist options
  • Save dvsseed/e34a8241cd5059e83bad0abb16609ede to your computer and use it in GitHub Desktop.
Save dvsseed/e34a8241cd5059e83bad0abb16609ede to your computer and use it in GitHub Desktop.
Generate the train and val image-list files (drug_train.txt, drug_val.txt) and the per-image YOLO label files for darknet.
import xml.etree.ElementTree as ET
import pickle
import os
from os import listdir, getcwd
from os.path import join
# Class names in label-id order: a name's position in this list is the class
# id written to the label files. Edit this list for your dataset (the drug
# dataset has 12 classes).
classes = [
    "Binin-U",
    "Concor5mg",
    "Depyretin",
    "Diphenidol",
    "ligilin",
    "Lopedin",
    "Madopar125mg",
    "Madopar250mg",
    "Medicon-A",
    "Requip025",
    "spironolactone",
    "Treceton",
]
def convert(size, box):
dw = 1. / size[0]
dh = 1. / size[1]
x = (box[0] + box[1]) / 2.0
y = (box[2] + box[3]) / 2.0
w = box[1] - box[0]
h = box[3] - box[2]
x = x * dw
w = w * dw
y = y * dh
h = h * dh
return (x, y, w, h)
def convert_annotation(image_id):
# 這裡改.xml資料夾的路徑
in_file = open('/home/user/darknet/drug_detect/Annotations/%s.xml' % (image_id))
# 這裡改產生每張圖片對應.txt的路徑
out_file = open('/home/user/darknet/drug_detect/labels/%s.txt' % (image_id), 'w')
tree = ET.parse(in_file)
root = tree.getroot()
size = root.find('size')
w = int(size.find('width').text)
h = int(size.find('height').text)
for obj in root.iter('object'):
cls = obj.find('name').text
if cls not in classes :
continue
cls_id = classes.index(cls)
xmlbox = obj.find('bndbox')
b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
bb = convert((w, h), b)
out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
# 這裡改train.txt文件的路徑
image_ids_train = open('/home/user/darknet/drug_detect/train.txt').read().strip().split()
# 這裡改val.txt文件的路徑
image_ids_val = open('/home/user/darknet/drug_detect/val.txt').read().strip().split()
list_file_train = open('drug_train.txt', 'w')
list_file_val = open('drug_val.txt', 'w')
for image_id in image_ids_train:
# 這裡改樣本圖片資料夾的路徑
list_file_train.write('/home/user/darknet/drug_detect/JPEGImages/%s.jpg\n' % (image_id))
convert_annotation(image_id)
list_file_train.close()
for image_id in image_ids_val:
# 這裡改樣本圖片資料夾的路徑
list_file_val.write('/home/user/darknet/drug_detect/JPEGImages/%s.jpg\n' % (image_id))
convert_annotation(image_id)
list_file_val.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment