Skip to content

Instantly share code, notes, and snippets.

@jinyu121
Created March 29, 2017 11:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jinyu121/8f3cd4dd1adaed4108cbe754aecfb794 to your computer and use it in GitHub Desktop.
Save jinyu121/8f3cd4dd1adaed4108cbe754aecfb794 to your computer and use it in GitHub Desktop.
实验室数据集
# -*- coding: utf-8 -*-
import os
import re
import random
import lxml.etree as ElementTree
import dicttoxml
from xml.dom.minidom import parseString
from collections import OrderedDict
def dict_to_elem(dictionary):
    """Build an ``Item`` element holding one child element per dictionary entry.

    Each key becomes a child tag after stripping surrounding whitespace and
    replacing inner spaces with underscores; the matching value is assigned
    unchanged as that child's text.

    Returns the newly created ``Item`` element.
    """
    root = ElementTree.Element('Item')
    for raw_name, value in dictionary.items():
        tag = raw_name.strip().replace(' ', '_')
        # SubElement creates the child and attaches it to ``root`` in one step.
        child = ElementTree.SubElement(root, tag)
        child.text = value
    return root
if "__main__" == __name__:
    # Convert a ground-truth text file into per-image annotation XML files
    # (laid out under a VOCdevkit/VOC2007 directory) and then write shuffled
    # train/val/test image-set lists plus a combined trainval list.
    BASE_DIR = "/home/haoyu/VOCdevkit/VOC2007/"
    # Bounding-box coordinates are clamped into this range before writing.
    MIN_X = 1
    MIN_Y = 1
    MAX_X = 640
    MAX_Y = 480
    ground_truth_file = os.path.join(BASE_DIR, "Groundtruth", "groundtruth.txt")
    annotation_dir = os.path.join(BASE_DIR, "Annotations")
    image_set_dir = os.path.join(BASE_DIR, "ImageSets", "Main")
    # A usable line starts with "<digits>:" and may contain any number of
    # "[a,b,c,d]" groups, which are read as x1, y1, x2, y2.
    pattern_ith = re.compile(r"^(\d+):")
    pattern_roi = re.compile(r"\[(\d+),(\d+),(\d+),(\d+)\]")
    filename_jar = list()
    # Fraction of the samples assigned to each image set.
    sets = {
        'train': 0.8,
        'val': 0.1,
        'test': 0.1,
    }

    with open(ground_truth_file, 'r') as f:
        # Create the output directories only once the input is known to
        # exist, so a wrong BASE_DIR does not leave stray folders behind.
        os.makedirs(annotation_dir, exist_ok=True)
        os.makedirs(image_set_dir, exist_ok=True)
        for line in f:
            print(line)
            ith = pattern_ith.match(line)
            if ith is None:
                # Blank or malformed line without a leading "<digits>:" id;
                # skip it instead of crashing on ith.group().
                continue
            frame_id = ith.group(1)
            filename_jar.append(frame_id)
            rois = pattern_roi.findall(line)
            AnnotationFile = os.path.join(
                annotation_dir, "{}_{}".format("rgb", frame_id) + ".xml")
            data = {
                'folder': "VOC2007",
                'filename': "{}_{}".format("rgb", frame_id) + ".png",
                'size': {
                    'width': 640,
                    'height': 480,
                    'depth': 3,
                },
                'segmented': 0
            }
            xml = dicttoxml.dicttoxml(OrderedDict(data), attr_type=False, custom_root='annotation')
            dom = parseString(xml)
            for roi in rois:
                # Clamp the top-left corner upwards and the bottom-right
                # corner downwards into the allowed coordinate range.
                x1 = max(int(roi[0]), MIN_X)
                y1 = max(int(roi[1]), MIN_Y)
                x2 = min(int(roi[2]), MAX_X)
                y2 = min(int(roi[3]), MAX_Y)
                obj = {
                    'name': 'person',
                    'pose': 'Left',
                    'truncated': 1,
                    'difficult': 0,
                    'bndbox': {
                        'xmin': x1,
                        'ymin': y1,
                        'xmax': x2,
                        'ymax': y2,
                    }
                }
                assert MIN_X <= x1 <= MAX_X, "{}".format(x1)
                assert MIN_Y <= y1 <= MAX_Y, "{}".format(y1)
                assert MIN_X <= x2 <= MAX_X, "{}".format(x2)
                assert MIN_Y <= y2 <= MAX_Y, "{}".format(y2)
                # Use str.format here: applying "%" to a template without
                # %-specifiers raises TypeError instead of reporting the box.
                assert x1 < x2, "{} {}".format(x1, x2)
                assert y1 < y2, "{} {}".format(y1, y2)
                xml_obj = parseString(dicttoxml.dicttoxml(OrderedDict(obj), attr_type=False, custom_root='object'))
                # Copy the <object> node into the annotation document.
                x = dom.importNode(xml_obj.childNodes[0], True)
                dom.childNodes[0].appendChild(x)
            with open(AnnotationFile, "w") as anno:
                print(dom.toprettyxml(), file=anno)

    # Split the dataset
    total = len(filename_jar)
    random.shuffle(filename_jar)
    sets_counter = 0
    set_names = list(sets)
    for set_name in set_names:
        if set_name == set_names[-1]:
            # The last set takes everything that is left, so samples lost to
            # integer truncation in the earlier sets are not silently dropped.
            tot = total - sets_counter
        else:
            tot = int(total * sets[set_name])
        with open(os.path.join(image_set_dir, set_name + ".txt"), 'w') as st:
            for frame_id in filename_jar[sets_counter:sets_counter + tot]:
                print("{}_{}".format("rgb", frame_id), file=st)
        sets_counter += tot

    # trainval.txt is the concatenation of train.txt and val.txt.
    with open(os.path.join(image_set_dir, "trainval.txt"), 'w') as train_val:
        for set_name in ["train", "val"]:
            with open(os.path.join(image_set_dir, set_name + ".txt"), 'r') as part:
                for line in part:
                    print(line, end="", file=train_val)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment