Skip to content

Instantly share code, notes, and snippets.

@jinyu121
Last active January 24, 2019 10:42
Show Gist options
  • Save jinyu121/02a505bb2547006ceb3fdbba3bee556c to your computer and use it in GitHub Desktop.
Save jinyu121/02a505bb2547006ceb3fdbba3bee556c to your computer and use it in GitHub Desktop.
A `torch.utils.data.Dataset` for `pytorch-yolo`

A Dataset for yolo2-pytorch

Minimal code modification.

If you have a better implementation, please let me know. Thanks!

Update

  • [20180426] The original code
  • [20180427] Simple multi thread
import os
import pickle
import threading
import uuid
import xml.etree.ElementTree as ET

import numpy as np
import scipy.sparse
from torch.utils.data import Dataset

from .imdb import ImageDataset
from .imdb import image_resize
from .voc_eval import voc_eval

# from functools import partial
# from utils.yolo import preprocess_train
class VOCDataset(ImageDataset, Dataset):
    """PASCAL VOC detection dataset usable with a ``torch.utils.data.DataLoader``.

    Combines the project's ``ImageDataset`` (annotation caching, VOC
    evaluation) with ``torch.utils.data.Dataset`` so that a ``DataLoader``
    can drive index shuffling/batching, while :meth:`parse` assembles the
    actual image batch with one thread per sample.

    NOTE(review): this class references the module-level names ``cfg`` and
    ``cv2`` (in :meth:`fetch_betch_data`) which are not imported in this
    file — they must be supplied by the surrounding project; verify.
    """

    def __init__(self, imdb_name, datadir, batch_size, im_processor,
                 processes=3, shuffle=True, dst_size=None, classes=None,
                 n_classes=None):
        """Set up VOC paths, the class list and PASCAL-specific options.

        Args:
            imdb_name: dataset id of the form ``voc_<year>_<imageset>``,
                e.g. ``voc_2007_trainval`` (split on ``_``).
            datadir: directory containing ``VOCdevkit<year>``.
            batch_size: forwarded to ``ImageDataset``.
            im_processor: per-sample preprocessing callable; see
                :meth:`fetch_betch_data` for how it is invoked.
            processes: worker count forwarded to ``ImageDataset``.
            shuffle: forwarded to ``ImageDataset``.
            dst_size: target size forwarded to ``ImageDataset``.
            classes: optional class-name tuple; defaults to the 20
                standard VOC classes.
            n_classes: optionally truncate the class list to its first
                ``n_classes`` entries.
        """
        ImageDataset.__init__(self, imdb_name, datadir, batch_size,
                              im_processor, processes, shuffle, dst_size)
        Dataset.__init__(self)

        meta = imdb_name.split('_')
        self._year = meta[1]
        self._image_set = meta[2]
        self._devkit_path = os.path.join(datadir, 'VOCdevkit{}'.format(self._year))
        self._data_path = os.path.join(self._devkit_path, 'VOC{}'.format(self._year))
        assert os.path.exists(self._devkit_path), \
            'VOCdevkit path does not exist: {}'.format(self._devkit_path)
        assert os.path.exists(self._data_path), \
            'Path does not exist: {}'.format(self._data_path)

        if classes is None:
            self._classes = ('aeroplane', 'bicycle', 'bird', 'boat',
                             'bottle', 'bus', 'car', 'cat', 'chair',
                             'cow', 'diningtable', 'dog', 'horse',
                             'motorbike', 'person', 'pottedplant',
                             'sheep', 'sofa', 'train', 'tvmonitor')
        else:
            self._classes = classes
        if n_classes is not None:
            self._classes = self._classes[:n_classes]
        # dict(zip(...)) directly; the original's list() wrappers were redundant.
        self._class_to_ind = dict(zip(self.classes, range(self.num_classes)))

        self._image_ext = '.jpg'
        self._salt = str(uuid.uuid4())
        self._comp_id = 'comp4'
        # PASCAL specific config options
        self.config = {'cleanup': True, 'use_salt': True}

        self.load_dataset()
        # self.im_processor = partial(process_im,
        #     image_names=self._image_names, annotations=self._annotations)
        # self.im_processor = preprocess_train

    def __len__(self):
        """Number of images in the image set."""
        return len(self._image_indexes)

    def __getitem__(self, index):
        """Return the index itself; actual loading happens in :meth:`parse`."""
        return index

    # ===== Start =====
    # NOTE(review): the name keeps the original typo ("betch") so existing
    # callers are not broken.
    def fetch_betch_data(self, ith, x, size_index):
        """Load, preprocess and multi-scale-resize one sample into ``self.batch``.

        Args:
            ith: slot of ``self.batch`` to fill.
            x: dataset index of the sample to load.
            size_index: index into ``cfg.multi_scale_inp_size``.
        """
        images, gt_boxes, classes, dontcare, origin_im = self._im_processor(
            [self.image_names[x], self.get_annotation(x), self.dst_size], None)

        # Multi-scale training: rescale boxes and image to the chosen size.
        w, h = cfg.multi_scale_inp_size[size_index]
        # np.float was deprecated and removed in NumPy 1.24; use builtin float.
        gt_boxes = np.asarray(gt_boxes, dtype=float)
        if len(gt_boxes) > 0:
            gt_boxes[:, 0::2] *= float(w) / images.shape[1]
            gt_boxes[:, 1::2] *= float(h) / images.shape[0]
        images = cv2.resize(images, (w, h))

        self.batch['images'][ith] = images
        self.batch['gt_boxes'][ith] = gt_boxes
        self.batch['gt_classes'][ith] = classes
        self.batch['dontcare'][ith] = dontcare
        self.batch['origin_im'][ith] = origin_im

    def parse(self, index, size_index):
        """Assemble a batch for the given tensor of dataset indexes.

        One thread per sample fills its slot of ``self.batch``; all threads
        are joined before the images are stacked into a single ndarray.

        Args:
            index: 1-D torch tensor of dataset indexes, as collated by the
                DataLoader from :meth:`__getitem__`.
            size_index: multi-scale size index forwarded to
                :meth:`fetch_betch_data`.

        Returns:
            dict with keys ``images`` (ndarray), ``gt_boxes``,
            ``gt_classes``, ``dontcare`` and ``origin_im``.
        """
        index = index.numpy()
        lenindex = len(index)
        # The original pre-sized slots with ``[list()] * n`` — n references
        # to ONE shared list. Harmless only because every slot is re-assigned;
        # ``None`` placeholders make the intent explicit and remove the trap.
        self.batch = {key: [None] * lenindex
                      for key in ('images', 'gt_boxes', 'gt_classes',
                                  'dontcare', 'origin_im')}

        ths = []
        for ith in range(lenindex):
            th = threading.Thread(target=self.fetch_betch_data,
                                  args=(ith, index[ith], size_index))
            th.start()
            ths.append(th)
        for th in ths:
            th.join()

        self.batch['images'] = np.asarray(self.batch['images'])
        return self.batch
    # ===== End =====

    def load_dataset(self):
        """Populate ``_image_indexes``, ``_image_names`` and ``_annotations``."""
        self._image_indexes = self._load_image_set_index()
        self._image_names = [self.image_path_from_index(index)
                             for index in self.image_indexes]
        self._annotations = self._load_pascal_annotations()

    def evaluate_detections(self, all_boxes, output_dir=None):
        """Write per-class result files, run the VOC eval, then clean up.

        all_boxes is a list of length number-of-classes.
        Each list element is a list of length number-of-images.
        Each of those list elements is either an empty list []
        or a numpy array of detection.

        all_boxes[class][image] = [] or np.array of shape #dets x 5
        """
        self._write_voc_results_file(all_boxes)
        self._do_python_eval(output_dir)
        if self.config['cleanup']:
            for cls in self._classes:
                if cls == '__background__':
                    continue
                filename = self._get_voc_results_file_template().format(cls)
                os.remove(filename)

    # -------------------------------------------------------------
    def image_path_from_index(self, index):
        """
        Construct an image path from the image's "index" identifier.
        """
        image_path = os.path.join(self._data_path, 'JPEGImages',
                                  index + self._image_ext)
        assert os.path.exists(image_path), \
            'Path does not exist: {}'.format(image_path)
        return image_path

    def _load_image_set_index(self):
        """
        Load the indexes listed in this dataset's image set file.
        """
        # Example path to image set file:
        # self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt
        image_set_file = os.path.join(self._data_path, 'ImageSets', 'Main',
                                      self._image_set + '.txt')
        assert os.path.exists(image_set_file), \
            'Path does not exist: {}'.format(image_set_file)
        with open(image_set_file) as f:
            image_index = [x.strip() for x in f.readlines()]
        return image_index

    def _load_pascal_annotations(self):
        """
        Return the database of ground-truth regions of interest.

        This function loads/saves from/to a cache file to speed up
        future calls.
        """
        cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_file):
            # NOTE: pickle cache — delete the file after changing annotations.
            with open(cache_file, 'rb') as fid:
                roidb = pickle.load(fid)
            print('{} gt roidb loaded from {}'.format(self.name, cache_file))
            return roidb

        gt_roidb = [self._annotation_from_index(index)
                    for index in self.image_indexes]
        with open(cache_file, 'wb') as fid:
            pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
        print('wrote gt roidb to {}'.format(cache_file))
        return gt_roidb

    def _annotation_from_index(self, index):
        """
        Load image and bounding boxes info from XML file in the PASCAL VOC
        format.
        """
        filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
        tree = ET.parse(filename)
        objs = tree.findall('object')
        # if not self.config['use_diff']:
        #     # Exclude the samples labeled as difficult
        #     non_diff_objs = [
        #         obj for obj in objs if int(obj.find('difficult').text) == 0]
        #     objs = non_diff_objs
        num_objs = len(objs)

        boxes = np.zeros((num_objs, 4), dtype=np.uint16)
        gt_classes = np.zeros((num_objs), dtype=np.int32)
        overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
        # "Seg" area for pascal is just the box area
        seg_areas = np.zeros((num_objs), dtype=np.float32)
        ishards = np.zeros((num_objs), dtype=np.int32)

        # Load object bounding boxes into a data frame.
        for ix, obj in enumerate(objs):
            bbox = obj.find('bndbox')
            # Make pixel indexes 0-based
            x1 = float(bbox.find('xmin').text) - 1
            y1 = float(bbox.find('ymin').text) - 1
            x2 = float(bbox.find('xmax').text) - 1
            y2 = float(bbox.find('ymax').text) - 1

            # Missing <difficult> tags are treated as "not difficult".
            diffc = obj.find('difficult')
            difficult = 0 if diffc is None else int(diffc.text)
            ishards[ix] = difficult

            cls = self._class_to_ind[obj.find('name').text.lower().strip()]
            boxes[ix, :] = [x1, y1, x2, y2]
            gt_classes[ix] = cls
            overlaps[ix, cls] = 1.0
            seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)

        overlaps = scipy.sparse.csr_matrix(overlaps)
        return {'boxes': boxes,
                'gt_classes': gt_classes,
                'gt_ishard': ishards,
                'gt_overlaps': overlaps,
                'flipped': False,
                'seg_areas': seg_areas}

    def _get_voc_results_file_template(self):
        """Return a path template with a ``{:s}`` placeholder for the class."""
        # VOCdevkit/results/VOC2007/Main/<comp_id>_det_test_aeroplane.txt
        filename = self._get_comp_id() + '_det_' + self._image_set + '_{:s}.txt'
        filedir = os.path.join(self._devkit_path, 'results',
                               'VOC' + self._year, 'Main')
        if not os.path.exists(filedir):
            os.makedirs(filedir)
        path = os.path.join(filedir, filename)
        return path

    def _write_voc_results_file(self, all_boxes):
        """Dump detections to one VOCdevkit-format text file per class."""
        for cls_ind, cls in enumerate(self.classes):
            if cls == '__background__':
                continue
            print('Writing {} VOC results file'.format(cls))
            filename = self._get_voc_results_file_template().format(cls)
            with open(filename, 'wt') as f:
                for im_ind, index in enumerate(self.image_indexes):
                    dets = all_boxes[cls_ind][im_ind]
                    # ``dets == []`` compared a numpy array with a list
                    # (deprecated elementwise semantics); test length instead.
                    if len(dets) == 0:
                        continue
                    # the VOCdevkit expects 1-based indices
                    for k in range(dets.shape[0]):
                        f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
                                format(index, dets[k, -1],
                                       dets[k, 0] + 1, dets[k, 1] + 1,
                                       dets[k, 2] + 1, dets[k, 3] + 1))

    def _do_python_eval(self, output_dir='output'):
        """Run the (unofficial) Python VOC evaluation and print per-class APs.

        Args:
            output_dir: directory for per-class PR-curve pickles, or ``None``
                to skip writing them.
        """
        annopath = os.path.join(
            self._devkit_path,
            'VOC' + self._year,
            'Annotations',
            '{:s}.xml')
        imagesetfile = os.path.join(
            self._devkit_path,
            'VOC' + self._year,
            'ImageSets',
            'Main',
            self._image_set + '.txt')
        cachedir = os.path.join(self._devkit_path, 'annotations_cache')
        aps = []
        # The PASCAL VOC metric changed in 2010
        use_07_metric = int(self._year) < 2010
        print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
        if output_dir is not None and not os.path.isdir(output_dir):
            os.mkdir(output_dir)
        for i, cls in enumerate(self._classes):
            if cls == '__background__':
                continue
            filename = self._get_voc_results_file_template().format(cls)
            rec, prec, ap = voc_eval(
                filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,
                use_07_metric=use_07_metric)
            aps += [ap]
            print(('AP for {} = {:.4f}'.format(cls, ap)))
            if output_dir is not None:
                with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f:
                    pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
        print(('Mean AP = {:.4f}'.format(np.mean(aps))))
        print('~~~~~~~~')
        print('Results:')
        for ap in aps:
            print(('{:.3f}'.format(ap)))
        print(('{:.3f}'.format(np.mean(aps))))
        print('~~~~~~~~')
        print('')
        print('--------------------------------------------------------------')
        print('Results computed with the **unofficial** Python eval code.')
        print('Results should be very close to the official MATLAB eval code.')
        print('Recompute with `./tools/reval.py --matlab ...` for your paper.')
        print('-- Thanks, The Management')
        print('--------------------------------------------------------------')

    def _get_comp_id(self):
        """Return the competition id, salted with a UUID when configured."""
        comp_id = (self._comp_id + '_' + self._salt if self.config['use_salt']
                   else self._comp_id)
        return comp_id
# Usage example: let a DataLoader shuffle/batch the indexes, then build the
# actual batch with VOCDataset.parse(). Relies on names (imdb_name, cfg,
# yolo_utils, DataLoader, tqdm, epoch, size_index) defined by the caller.
imdb = VOCDataset(imdb_name, cfg.DATA_DIR, cfg.batch_size,
                  yolo_utils.preprocess_test, shuffle=False,
                  dst_size=cfg.multi_scale_inp_size,
                  n_classes=now_classes_high)
loader = DataLoader(imdb, batch_size=cfg.batch_size, shuffle=True, num_workers=5)
for step, iter_data in enumerate(tqdm(loader, desc="Epoch {}".format(epoch))):
    batch = imdb.parse(iter_data, size_index)
    # ...
@longcw
Copy link

longcw commented Apr 26, 2018

You can move imdb.parse to collate_fn and convert numpy array to tensor. Then you will get the power of multiprocessing.
Another solution is to move `_im_processor` into `__getitem__`.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment