A Dataset for yolo2-pytorch, with minimal modification of the existing code.
If you have any better code, please let me know — thanks!
- [2018-04-26] The original code
- [2018-04-27] Simple multi-threading added
import os
import pickle
import threading
import uuid
import xml.etree.ElementTree as ET

import numpy as np
import scipy.sparse
from torch.utils.data import Dataset

from .imdb import ImageDataset
from .imdb import image_resize
from .voc_eval import voc_eval

# from functools import partial
# from utils.yolo import preprocess_train
class VOCDataset(ImageDataset, Dataset):
    """PASCAL VOC detection dataset that bridges yolo2-pytorch's ImageDataset
    with ``torch.utils.data.Dataset``.

    ``__getitem__`` deliberately returns only the integer sample index; the
    DataLoader therefore yields a tensor of indices, and the training loop
    calls :meth:`parse` with that tensor to assemble the actual batch
    (images, boxes, classes) using one thread per sample.

    NOTE(review): this class also uses ``cfg`` (config module) and ``cv2``
    (OpenCV) inside :meth:`fetch_betch_data`; both must be importable in the
    enclosing module — they are not shown in this snippet. TODO confirm.
    """

    def __init__(self, imdb_name, datadir, batch_size, im_processor,
                 processes=3, shuffle=True, dst_size=None, classes=None, n_classes=None):
        """Build the dataset and eagerly load the index/annotation lists.

        Args:
            imdb_name: dataset identifier of the form ``<name>_<year>_<set>``,
                e.g. ``'voc_2007_trainval'`` (split on ``'_'``).
            datadir: directory containing ``VOCdevkit<year>``.
            batch_size: samples per batch (forwarded to ImageDataset).
            im_processor: callable that loads and preprocesses one image.
            processes: worker count forwarded to ImageDataset.
            shuffle: forwarded to ImageDataset.
            dst_size: target image size(s), forwarded to ImageDataset.
            classes: optional custom class-name sequence; defaults to the
                20 standard PASCAL VOC categories.
            n_classes: optional cap — keep only the first ``n_classes`` names.
        """
        ImageDataset.__init__(self, imdb_name, datadir, batch_size,
                              im_processor, processes, shuffle, dst_size)
        Dataset.__init__(self)
        meta = imdb_name.split('_')
        self._year = meta[1]
        self._image_set = meta[2]
        self._devkit_path = os.path.join(datadir, 'VOCdevkit{}'.format(self._year))
        self._data_path = os.path.join(self._devkit_path, 'VOC{}'.format(self._year))
        assert os.path.exists(self._devkit_path), \
            'VOCdevkit path does not exist: {}'.format(self._devkit_path)
        assert os.path.exists(self._data_path), \
            'Path does not exist: {}'.format(self._data_path)
        if classes is None:
            self._classes = ('aeroplane', 'bicycle', 'bird', 'boat',
                             'bottle', 'bus', 'car', 'cat', 'chair',
                             'cow', 'diningtable', 'dog', 'horse',
                             'motorbike', 'person', 'pottedplant',
                             'sheep', 'sofa', 'train', 'tvmonitor')
        else:
            self._classes = classes
        if n_classes is not None:
            self._classes = self._classes[:n_classes]
        self._class_to_ind = dict(list(zip(self.classes, list(range(self.num_classes)))))
        self._image_ext = '.jpg'
        # Random salt makes result-file names unique per run (VOC comp4 style).
        self._salt = str(uuid.uuid4())
        self._comp_id = 'comp4'
        # PASCAL specific config options
        self.config = {'cleanup': True, 'use_salt': True}
        self.load_dataset()
        # self.im_processor = partial(process_im,
        #     image_names=self._image_names, annotations=self._annotations)
        # self.im_processor = preprocess_train

    def __len__(self):
        """Number of images in the selected image set."""
        return len(self._image_indexes)

    def __getitem__(self, index):
        """Return only the index; parse() turns indices into real data."""
        return index

    # ===== Start =====
    def fetch_betch_data(self, ith, x, size_index):
        """Worker for parse(): load sample ``x`` and store it at batch slot ``ith``.

        (Name keeps the original 'betch' misspelling for caller compatibility.)

        Args:
            ith: slot in ``self.batch`` to fill.
            x: dataset index of the sample to load.
            size_index: index into ``cfg.multi_scale_inp_size`` selecting the
                multi-scale training resolution.
        """
        images, gt_boxes, classes, dontcare, origin_im = self._im_processor(
            [self.image_names[x], self.get_annotation(x), self.dst_size], None)
        # multi-scale: rescale boxes, then resize the image to (w, h).
        # NOTE(review): cfg and cv2 come from the enclosing module — TODO confirm.
        w, h = cfg.multi_scale_inp_size[size_index]
        # np.float was removed in NumPy 1.24; the builtin float is equivalent.
        gt_boxes = np.asarray(gt_boxes, dtype=float)
        if len(gt_boxes) > 0:
            gt_boxes[:, 0::2] *= float(w) / images.shape[1]  # x coords scale by width ratio
            gt_boxes[:, 1::2] *= float(h) / images.shape[0]  # y coords scale by height ratio
        images = cv2.resize(images, (w, h))
        self.batch['images'][ith] = images
        self.batch['gt_boxes'][ith] = gt_boxes
        self.batch['gt_classes'][ith] = classes
        self.batch['dontcare'][ith] = dontcare
        self.batch['origin_im'][ith] = origin_im

    def parse(self, index, size_index):
        """Assemble a batch from a tensor of dataset indices, one thread each.

        Args:
            index: torch tensor of sample indices (as produced by DataLoader).
            size_index: multi-scale resolution index, passed to the workers.

        Returns:
            dict with keys 'images' (stacked ndarray), 'gt_boxes',
            'gt_classes', 'dontcare', 'origin_im' (per-sample lists).
        """
        index = index.numpy()
        lenindex = len(index)
        # Pre-size each field with placeholder slots; every thread writes only
        # its own slot, so no locking is needed.
        self.batch = {key: [None] * lenindex
                      for key in ('images', 'gt_boxes', 'gt_classes',
                                  'dontcare', 'origin_im')}
        ths = []
        for ith in range(lenindex):
            ths.append(threading.Thread(target=self.fetch_betch_data,
                                        args=(ith, index[ith], size_index)))
            ths[ith].start()
        for ith in range(lenindex):
            ths[ith].join()
        self.batch['images'] = np.asarray(self.batch['images'])
        return self.batch
    # ===== End =====

    def load_dataset(self):
        """Populate self._image_indexes, self._image_names, self._annotations."""
        self._image_indexes = self._load_image_set_index()
        self._image_names = [self.image_path_from_index(index)
                             for index in self.image_indexes]
        self._annotations = self._load_pascal_annotations()

    def evaluate_detections(self, all_boxes, output_dir=None):
        """
        all_boxes is a list of length number-of-classes.
        Each list element is a list of length number-of-images.
        Each of those list elements is either an empty list []
        or a numpy array of detection.
        all_boxes[class][image] = [] or np.array of shape #dets x 5
        """
        self._write_voc_results_file(all_boxes)
        self._do_python_eval(output_dir)
        if self.config['cleanup']:
            # Remove the per-class result files once evaluation is done.
            for cls in self._classes:
                if cls == '__background__':
                    continue
                filename = self._get_voc_results_file_template().format(cls)
                os.remove(filename)

    # -------------------------------------------------------------
    def image_path_from_index(self, index):
        """
        Construct an image path from the image's "index" identifier.
        """
        image_path = os.path.join(self._data_path, 'JPEGImages', index + self._image_ext)
        assert os.path.exists(image_path), 'Path does not exist: {}'.format(image_path)
        return image_path

    def _load_image_set_index(self):
        """
        Load the indexes listed in this dataset's image set file.
        """
        # Example path to image set file:
        # self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt
        image_set_file = os.path.join(self._data_path, 'ImageSets', 'Main',
                                      self._image_set + '.txt')
        assert os.path.exists(image_set_file), \
            'Path does not exist: {}'.format(image_set_file)
        with open(image_set_file) as f:
            image_index = [x.strip() for x in f.readlines()]
        return image_index

    def _load_pascal_annotations(self):
        """
        Return the database of ground-truth regions of interest.
        This function loads/saves from/to a cache file to speed up
        future calls.
        """
        cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                roidb = pickle.load(fid)
            print('{} gt roidb loaded from {}'.format(self.name, cache_file))
            return roidb
        gt_roidb = [self._annotation_from_index(index)
                    for index in self.image_indexes]
        with open(cache_file, 'wb') as fid:
            pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
        print('wrote gt roidb to {}'.format(cache_file))
        return gt_roidb

    def _annotation_from_index(self, index):
        """
        Load image and bounding boxes info from XML file in the PASCAL VOC
        format.
        """
        filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
        tree = ET.parse(filename)
        objs = tree.findall('object')
        # if not self.config['use_diff']:
        #     # Exclude the samples labeled as difficult
        #     non_diff_objs = [
        #         obj for obj in objs if int(obj.find('difficult').text) == 0]
        #     objs = non_diff_objs
        num_objs = len(objs)
        boxes = np.zeros((num_objs, 4), dtype=np.uint16)
        gt_classes = np.zeros((num_objs), dtype=np.int32)
        overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
        # "Seg" area for pascal is just the box area
        seg_areas = np.zeros((num_objs), dtype=np.float32)
        ishards = np.zeros((num_objs), dtype=np.int32)
        # Load object bounding boxes into a data frame.
        for ix, obj in enumerate(objs):
            bbox = obj.find('bndbox')
            # Make pixel indexes 0-based
            x1 = float(bbox.find('xmin').text) - 1
            y1 = float(bbox.find('ymin').text) - 1
            x2 = float(bbox.find('xmax').text) - 1
            y2 = float(bbox.find('ymax').text) - 1
            diffc = obj.find('difficult')
            difficult = 0 if diffc is None else int(diffc.text)
            ishards[ix] = difficult
            cls = self._class_to_ind[obj.find('name').text.lower().strip()]
            boxes[ix, :] = [x1, y1, x2, y2]
            gt_classes[ix] = cls
            overlaps[ix, cls] = 1.0
            seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)
        overlaps = scipy.sparse.csr_matrix(overlaps)
        return {'boxes': boxes,
                'gt_classes': gt_classes,
                'gt_ishard': ishards,
                'gt_overlaps': overlaps,
                'flipped': False,
                'seg_areas': seg_areas}

    def _get_voc_results_file_template(self):
        """Return a results-file path template with '{:s}' for the class name."""
        # VOCdevkit/results/VOC2007/Main/<comp_id>_det_test_aeroplane.txt
        filename = self._get_comp_id() + '_det_' + self._image_set + '_{:s}.txt'
        filedir = os.path.join(self._devkit_path, 'results', 'VOC' + self._year, 'Main')
        if not os.path.exists(filedir):
            os.makedirs(filedir)
        path = os.path.join(filedir, filename)
        return path

    def _write_voc_results_file(self, all_boxes):
        """Write one VOC-format detection file per class from all_boxes."""
        for cls_ind, cls in enumerate(self.classes):
            if cls == '__background__':
                continue
            print('Writing {} VOC results file'.format(cls))
            filename = self._get_voc_results_file_template().format(cls)
            with open(filename, 'wt') as f:
                for im_ind, index in enumerate(self.image_indexes):
                    dets = all_boxes[cls_ind][im_ind]
                    # BUGFIX: `dets == []` is an elementwise comparison when
                    # dets is an ndarray; test emptiness explicitly instead.
                    if len(dets) == 0:
                        continue
                    # the VOCdevkit expects 1-based indices
                    for k in range(dets.shape[0]):
                        f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
                                format(index, dets[k, -1],
                                       dets[k, 0] + 1, dets[k, 1] + 1,
                                       dets[k, 2] + 1, dets[k, 3] + 1))

    def _do_python_eval(self, output_dir='output'):
        """Run the Python VOC evaluation over all classes and print per-class AP."""
        annopath = os.path.join(
            self._devkit_path,
            'VOC' + self._year,
            'Annotations',
            '{:s}.xml')
        imagesetfile = os.path.join(
            self._devkit_path,
            'VOC' + self._year,
            'ImageSets',
            'Main',
            self._image_set + '.txt')
        cachedir = os.path.join(self._devkit_path, 'annotations_cache')
        aps = []
        # The PASCAL VOC metric changed in 2010
        use_07_metric = int(self._year) < 2010
        print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
        if output_dir is not None and not os.path.isdir(output_dir):
            os.mkdir(output_dir)
        for i, cls in enumerate(self._classes):
            if cls == '__background__':
                continue
            filename = self._get_voc_results_file_template().format(cls)
            rec, prec, ap = voc_eval(
                filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,
                use_07_metric=use_07_metric)
            aps += [ap]
            print(('AP for {} = {:.4f}'.format(cls, ap)))
            if output_dir is not None:
                with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f:
                    pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
        print(('Mean AP = {:.4f}'.format(np.mean(aps))))
        print('~~~~~~~~')
        print('Results:')
        for ap in aps:
            print(('{:.3f}'.format(ap)))
        print(('{:.3f}'.format(np.mean(aps))))
        print('~~~~~~~~')
        print('')
        print('--------------------------------------------------------------')
        print('Results computed with the **unofficial** Python eval code.')
        print('Results should be very close to the official MATLAB eval code.')
        print('Recompute with `./tools/reval.py --matlab ...` for your paper.')
        print('-- Thanks, The Management')
        print('--------------------------------------------------------------')

    def _get_comp_id(self):
        """Return the competition id, salted when config['use_salt'] is set."""
        comp_id = (self._comp_id + '_' + self._salt if self.config['use_salt']
                   else self._comp_id)
        return comp_id
# Example driver: build the dataset, wrap it in a DataLoader, and let
# parse() assemble each batch from the sampled index tensor.
# NOTE(review): cfg, yolo_utils, DataLoader, tqdm, imdb_name, epoch,
# size_index and now_classes_high come from the surrounding training
# script, which is not shown here — confirm against the caller.
imdb = VOCDataset(imdb_name, cfg.DATA_DIR, cfg.batch_size, yolo_utils.preprocess_test, shuffle=False, dst_size=cfg.multi_scale_inp_size, n_classes=now_classes_high)
loader = DataLoader(imdb, batch_size=cfg.batch_size, shuffle=True, num_workers=5)
for iter, iter_data in enumerate(tqdm(loader, desc="Epoch {}".format(epoch))):
    # iter_data is a tensor of dataset indices (see VOCDataset.__getitem__).
    batch = imdb.parse(iter_data, size_index)
    # ...
You can move `imdb.parse` into the DataLoader's `collate_fn` and convert the
numpy arrays to tensors there; then you will get the power of multiprocessing.
Another solution is to move `_im_processor` into `__getitem__`.