# Created March 5, 2019 20:05
"""AccessMathVOC triplet data loader (anchor / positive / negative region samples)."""
import os
import time
import traceback
import xml.etree.ElementTree as ET

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset

import nnet.cfg as config
from data.transforms import TransformCrop, TransformPad, BinarizerWeightGeneration
from data.utils import filter_by_shape, extend_bbox, image_channel_mean, bbox_to_quad
class AccessMathVOCTriplets(Dataset):
TrainValLectures = ["lecture_01", "lecture_06", "lecture_18", "NM_lecture_01", "NM_lecture_03"]
TestLectures = ["lecture_02", "lecture_07", "lecture_08", "lecture_10", "lecture_15", "NM_lecture_02",
def __init__(self, accessmathVOC_root='/run1/dataset-txt/AccessMathVOC/', data_transform=None, target_transform=None, lectures=None,
gt_type="bbox", imageset="trainval", seed=5, imageset_dir='Main'):
self.accessmathVOC_root = accessmathVOC_root
self.data_transform = data_transform
self.target_transform = target_transform
self.gt_type = gt_type
valid_imagesets = ["trainval", "train", "val", "test"]
self.imageset = "trainval" if imageset.lower() not in valid_imagesets else imageset.lower()
self.imageset_dir = imageset_dir
self.seed = seed
if self.imageset == "trainval":
self.lectures = AccessMathVOCTriplets.TrainValLectures
elif self.imageset in ["train", "val"]:
self.lectures = lectures if lectures is not None else AccessMathVOCTriplets.TrainValLectures
self.lectures = lectures if lectures is not None else AccessMathVOCTriplets.TestLectures
print('finding unique regions...')
self.region_frame_map = {}
max_height = 0
max_width = 0
max_area = 0
for lecture in self.lectures:
imageset_ = "trainval" if self.imageset in ["train", "val", "trainval"] else "test"
imageset_f = os.path.join(self.accessmathVOC_root, lecture,
"ImageSets", self.imageset_dir, "{}.txt".format(imageset_))
frame_ids = self._read_imageset_f(imageset_f)
frame_ids = sorted(frame_ids, key=lambda frame_id: int(frame_id))
for frame_id in frame_ids:
anno_path = self._id_to_anno_path(lecture, frame_id)
bboxes, region_ids = self._anno_from_xml(os.path.join(self.accessmathVOC_root, anno_path))
heights = bboxes[:, 3] - bboxes[:, 1]
widths = bboxes[:, 2] - bboxes[:, 0]
areas = heights * widths
max_height = heights.max() if heights.max() >= max_height else max_height
max_width = widths.max() if widths.max() >= max_width else max_width
max_area = areas.max() if areas.max() >= max_area else max_area
for i, region_id in enumerate(region_ids):
if (lecture, region_id) in self.region_frame_map:
self.region_frame_map[(lecture, region_id)] += [(frame_id, tuple(bboxes[i, :]))]
self.region_frame_map[(lecture, region_id)] = [(frame_id, tuple(bboxes[i, :]))]
self.unique_regions = list(self.region_frame_map.keys())
print('... found')
print('max region height:', max_height, 'width:', max_width, 'area:', max_area)
def __len__(self):
return len(self.unique_regions)
def __getitem__(self, item):
# pick a lecture and region_id
lecture, region_id = self.unique_regions[item]
# corner case: what if the box is somehow invalid (use filter_by_shape and rerun sampling)
all_bboxes = np.empty(shape=(0, 4))
while all_bboxes.shape[0] != 3:
anchor_frame, anchor_bbox, pos_frame, pos_bbox = self._get_positive_samples(lecture, region_id)
neg_frame, neg_bbox = self._get_negative_sample(lecture, region_id, anchor_frame, anchor_bbox)
all_bboxes = filter_by_shape(np.asarray([pos_bbox, anchor_bbox, neg_bbox]))
anchor_gt_path = self._id_to_anno_path(lecture, anchor_frame)
anchor_gt, _ = self._anno_from_xml(os.path.join(self.accessmathVOC_root, anchor_gt_path))
# get the image for all the frames
anchor_img = self._get_img(lecture, anchor_frame)
pos_img = self._get_img(lecture, pos_frame)
neg_img = self._get_img(lecture, neg_frame)
# convert the bbox tuples to np.array
anchor_bbox = np.asarray(anchor_bbox).reshape(1, -1)
pos_bbox = np.asarray(pos_bbox).reshape(1, -1)
neg_bbox = np.asarray(neg_bbox).reshape(1, -1)
triplet_sample = {"anchor_image": anchor_img, "anchor_bbox": anchor_bbox, "anchor_gt": anchor_gt,
"pos_image": pos_img, "pos_bbox": pos_bbox,
"neg_image": neg_img, "neg_bbox": neg_bbox
if self.gt_type == "quad":
all_bboxes = np.concatenate([anchor_bbox, pos_bbox, neg_bbox], axis=0)
all_quads = bbox_to_quad(all_bboxes)
triplet_sample["anchor_bbox"] = all_quads[0, :].reshape(1, -1)
triplet_sample["pos_bbox"] = all_quads[1, :].reshape(1, -1)
triplet_sample["neg_bbox"] = all_quads[2, :].reshape(1, -1)
if self.data_transform is not None:
triplet_sample = self.data_transform(triplet_sample)
if self.target_transform is not None:
triplet_sample = self.target_transform(triplet_sample)
return triplet_sample
def _read_imageset_f(fpath):
with open(fpath, 'r') as f:
lines = f.readlines()
lines = [l.strip() for l in lines if len(l.strip()) > 0]
return lines
def _id_to_image_path(lecture, _id):
return os.path.join(lecture, "JPEGImages", "{}.jpg".format(_id))
def _id_to_binary_image_path(lecture, _id):
return os.path.join(lecture, "binary", "{}.png".format(_id))
def _id_to_anno_path(lecture, _id):
return os.path.join(lecture, "Annotations", "{}.xml".format(_id))
def _anno_from_xml(fpath):
root = ET.parse(fpath).getroot()
bboxes = []
ids = []
for obj in root.iter('object'):
ids += [obj.find('ID').text]
bndbox = obj.find('bndbox')
bbox = []
for pt in ["xmin", "ymin", "xmax", "ymax"]:
coord = int(bndbox.find(pt).text)
bbox += [coord]
bboxes += [bbox]
return np.asarray(bboxes), ids
def _get_positive_samples(self, lecture, region_id):
positive_samples = self.region_frame_map[(lecture, region_id)]
# corner case: what if some region has only one frame of occurence
replacement = len(positive_samples) == 1
s1, s2 = np.random.choice(len(positive_samples), 2, replace=replacement)
anchor_frame, anchor_bbox = positive_samples[s1]
pos_frame, pos_bbox = positive_samples[s2]
return anchor_frame, anchor_bbox, pos_frame, pos_bbox
def _get_negative_sample(self, lecture, region_id, anchor_frame_id, anchor_bbox):
# pick a region with diff id from same lecture to generate neg sample s.t. it is within some distance of anchor
def bbox_distance(bbox1, bbox2):
bb1 = np.asarray(bbox1, dtype='float')
bb1[::2] /= 1920.
bb1[1::2] /= 1080
bb2 = np.asarray(bbox2, dtype='float')
bb2[::2] /= 1920.
bb2[1::2] /= 1080.
return np.linalg.norm(bb1 - bb2)
negative_regions = {k: v for k, v in self.region_frame_map.items() if k[0] == lecture and k[1] != region_id}
valid_negative_regions = {}
th = 0.33
while len(valid_negative_regions) < 1:
for k, v in negative_regions.items():
for (frame_id, bbox) in v:
if bbox_distance(bbox, anchor_bbox) <= th:
if k in valid_negative_regions:
valid_negative_regions[k] += [(frame_id, bbox)]
valid_negative_regions[k] = [(frame_id, bbox)]
th *= 2.
# pick a random negative region_id from valid_negative_region_ids
s3 = np.random.randint(0, len(valid_negative_regions))
# get list of frame ids and bboxes for the random negative region_id
negative_samples = list(valid_negative_regions.values())[s3]
# pick a random frame_id and bbox from negative samples
s4 = np.random.randint(0, len(negative_samples))
neg_frame, neg_bbox = negative_samples[s4]
return neg_frame, neg_bbox
def _get_img(self, lecture, frame_id):
img_path = self._id_to_image_path(lecture, frame_id)
return cv2.imread(os.path.join(self.accessmathVOC_root, img_path))
def detection_collate(batch):
#print("in collate")
#print(np.shape(batch), type(batch[0]["anchor_image"]), np.shape(batch[0]["anchor_image"]))
anchor_images = []
pos_images = []
neg_images = []
anchor_bboxes = []
pos_bboxes = []
neg_bboxes = []
for sample in batch:
anchor_images += [sample["anchor_image"][np.newaxis, :, :, :]]
pos_images += [sample["pos_image"][np.newaxis, :, :, :]]
neg_images += [sample["neg_image"][np.newaxis, :, :, :]]
anchor_bboxes += [sample["anchor_bbox"][np.newaxis, :, :]]
pos_bboxes += [sample["pos_bbox"][np.newaxis, :, :]]
neg_bboxes += [sample["neg_bbox"][np.newaxis, :, :]]
stacked_images = np.concatenate(anchor_images + pos_images + neg_images, axis=0)
stacked_bboxes = np.concatenate(anchor_bboxes + pos_bboxes + neg_bboxes, axis=0)
images = torch.from_numpy(stacked_images).permute(0, 3, 1, 2)
bboxes = torch.from_numpy(stacked_bboxes)
return images, bboxes
if __name__ == "__main__":
    # Example of creation of dataset object with train and val: one train/val lecture
    # is held out for validation, and a few batches are pushed through TextAlign.
    from torch.utils.data import DataLoader
    from data.transforms import AMVOCTripletTransform
    from nnet.layers.text_align_bbox import TextAlign

    lectures = AccessMathVOCTriplets.TrainValLectures
    v = np.random.randint(0, len(lectures))
    # Exclude the held-out lecture from training; the previous slice lectures[v:]
    # kept it, so train and val overlapped.
    train_lectures = lectures[:v] + lectures[v + 1:]
    val_lecture = [lectures[v]]
    amvoc_triplets_train = AccessMathVOCTriplets('/run1/dataset-txt/AccessMathVOC', AMVOCTripletTransform(),
                                                 None, train_lectures, imageset='train')
    amvoc_triplets_val = AccessMathVOCTriplets('/run1/dataset-txt/AccessMathVOC', AMVOCTripletTransform(),
                                               None, val_lecture, imageset='val')

    resample = TextAlign((1, 3, 30, 30), device='cpu', pool_type=0, rescale=1.0)
    dataloader = DataLoader(amvoc_triplets_train, batch_size=5, collate_fn=AccessMathVOCTriplets.detection_collate)
    dataiter = iter(dataloader)
    for i, (images, boxes) in enumerate(dataiter):
        print(images.shape, boxes.shape)
        resampled = resample(images.float(), boxes.float())
        print(resampled[0].shape, resampled[1])
        # Only inspect the first three batches.
        if i == 2:
            break

    # Sanity-check TextAlign on random images with a variable number of ROIs per image.
    test_im = torch.tensor(np.random.random(size=(4, 3, 60, 60)))
    n_rois = [5, 9, 2, 4]
    all_test_coords = []
    for n in n_rois:
        test_coords = torch.tensor(np.random.random(size=(n, 4)))
        test_coords[:, 2] = test_coords[:, 0] + 1.
        test_coords[:, 3] = test_coords[:, 1] + 1.
        test_coords *= 30.
        all_test_coords += [test_coords.float()]
    resampled2 = resample(test_im.float(), all_test_coords)
    print(resampled2[0].shape, resampled2[1])
