Access Math VOC Triplet Data Loader

import os
import xml.etree.ElementTree as ET

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset

from data.utils import filter_by_shape, bbox_to_quad


class AccessMathVOCTriplets(Dataset):
    """Triplet dataset over AccessMath lecture frames stored in Pascal VOC layout.

    Each item is an (anchor, positive, negative) triple: the anchor and the
    positive show the same handwritten region in two different frames, while
    the negative is a nearby but different region from the same lecture.
    """
    TrainValLectures = ["lecture_01", "lecture_06", "lecture_18", "NM_lecture_01", "NM_lecture_03"]
    TestLectures = ["lecture_02", "lecture_07", "lecture_08", "lecture_10", "lecture_15", "NM_lecture_02",
                    "NM_lecture_05"]

    def __init__(self, accessmathVOC_root='/run1/dataset-txt/AccessMathVOC/', data_transform=None,
                 target_transform=None, lectures=None, gt_type="bbox", imageset="trainval", seed=5,
                 imageset_dir='Main'):
        self.accessmathVOC_root = accessmathVOC_root
        self.data_transform = data_transform
        self.target_transform = target_transform
        self.gt_type = gt_type
        valid_imagesets = ["trainval", "train", "val", "test"]
        # Fall back to "trainval" for an unrecognized imageset name.
        self.imageset = "trainval" if imageset.lower() not in valid_imagesets else imageset.lower()
        self.imageset_dir = imageset_dir
        self.seed = seed
        if self.imageset == "trainval":
            self.lectures = AccessMathVOCTriplets.TrainValLectures
        elif self.imageset in ["train", "val"]:
            self.lectures = lectures if lectures is not None else AccessMathVOCTriplets.TrainValLectures
        else:
            self.lectures = lectures if lectures is not None else AccessMathVOCTriplets.TestLectures

        print('finding unique regions...')
        # Maps (lecture, region_id) -> list of (frame_id, bbox) occurrences.
        self.region_frame_map = {}
        max_height = 0
        max_width = 0
        max_area = 0
        for lecture in self.lectures:
            imageset_ = "trainval" if self.imageset in ["train", "val", "trainval"] else "test"
            imageset_f = os.path.join(self.accessmathVOC_root, lecture,
                                      "ImageSets", self.imageset_dir, "{}.txt".format(imageset_))
            frame_ids = self._read_imageset_f(imageset_f)
            frame_ids = sorted(frame_ids, key=lambda frame_id: int(frame_id))
            for frame_id in frame_ids:
                anno_path = self._id_to_anno_path(lecture, frame_id)
                bboxes, region_ids = self._anno_from_xml(os.path.join(self.accessmathVOC_root, anno_path))
                if len(region_ids) == 0:
                    continue  # guard against frames with no annotated regions
                heights = bboxes[:, 3] - bboxes[:, 1]
                widths = bboxes[:, 2] - bboxes[:, 0]
                areas = heights * widths
                max_height = max(max_height, heights.max())
                max_width = max(max_width, widths.max())
                max_area = max(max_area, areas.max())
                for i, region_id in enumerate(region_ids):
                    if (lecture, region_id) in self.region_frame_map:
                        self.region_frame_map[(lecture, region_id)] += [(frame_id, tuple(bboxes[i, :]))]
                    else:
                        self.region_frame_map[(lecture, region_id)] = [(frame_id, tuple(bboxes[i, :]))]
        self.unique_regions = list(self.region_frame_map.keys())
        print('... found')
        print('max region height:', max_height, 'width:', max_width, 'area:', max_area)
    def __len__(self):
        return len(self.unique_regions)

    def __getitem__(self, item):
        # Pick the lecture and region_id for this index.
        lecture, region_id = self.unique_regions[item]
        # Corner case: a sampled box may be invalid; filter_by_shape drops such
        # boxes, so resample until all three survive.
        all_bboxes = np.empty(shape=(0, 4))
        while all_bboxes.shape[0] != 3:
            anchor_frame, anchor_bbox, pos_frame, pos_bbox = self._get_positive_samples(lecture, region_id)
            neg_frame, neg_bbox = self._get_negative_sample(lecture, region_id, anchor_frame, anchor_bbox)
            all_bboxes = filter_by_shape(np.asarray([pos_bbox, anchor_bbox, neg_bbox]))
        anchor_gt_path = self._id_to_anno_path(lecture, anchor_frame)
        anchor_gt, _ = self._anno_from_xml(os.path.join(self.accessmathVOC_root, anchor_gt_path))
        # Load the image for each of the three frames.
        anchor_img = self._get_img(lecture, anchor_frame)
        pos_img = self._get_img(lecture, pos_frame)
        neg_img = self._get_img(lecture, neg_frame)
        # Convert the bbox tuples to (1, 4) np.arrays.
        anchor_bbox = np.asarray(anchor_bbox).reshape(1, -1)
        pos_bbox = np.asarray(pos_bbox).reshape(1, -1)
        neg_bbox = np.asarray(neg_bbox).reshape(1, -1)
        triplet_sample = {"anchor_image": anchor_img, "anchor_bbox": anchor_bbox, "anchor_gt": anchor_gt,
                          "pos_image": pos_img, "pos_bbox": pos_bbox,
                          "neg_image": neg_img, "neg_bbox": neg_bbox
                          }
        if self.gt_type == "quad":
            # Convert axis-aligned boxes to 4-point quads.
            all_bboxes = np.concatenate([anchor_bbox, pos_bbox, neg_bbox], axis=0)
            all_quads = bbox_to_quad(all_bboxes)
            triplet_sample["anchor_bbox"] = all_quads[0, :].reshape(1, -1)
            triplet_sample["pos_bbox"] = all_quads[1, :].reshape(1, -1)
            triplet_sample["neg_bbox"] = all_quads[2, :].reshape(1, -1)
        if self.data_transform is not None:
            triplet_sample = self.data_transform(triplet_sample)
        if self.target_transform is not None:
            triplet_sample = self.target_transform(triplet_sample)
        return triplet_sample
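
    # A minimal sketch of the raw (un-transformed) sample layout returned above,
    # assuming 1920x1080 lecture frames (shapes are illustrative only):
    #
    #   sample = dataset[0]
    #   sample["anchor_image"]  # (1080, 1920, 3) BGR uint8 array from cv2.imread
    #   sample["anchor_bbox"]   # (1, 4) [xmin, ymin, xmax, ymax], or (1, 8) for quads
    #   sample["anchor_gt"]     # (k, 4): all k annotated boxes in the anchor frame
    #   # "pos_image"/"pos_bbox" and "neg_image"/"neg_bbox" follow the same layout.
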
    @staticmethod
    def _read_imageset_f(fpath):
        # Read the non-empty frame ids, one per line.
        with open(fpath, 'r') as f:
            lines = f.readlines()
        lines = [l.strip() for l in lines if len(l.strip()) > 0]
        return lines

    @staticmethod
    def _id_to_image_path(lecture, _id):
        return os.path.join(lecture, "JPEGImages", "{}.jpg".format(_id))

    @staticmethod
    def _id_to_binary_image_path(lecture, _id):
        return os.path.join(lecture, "binary", "{}.png".format(_id))

    @staticmethod
    def _id_to_anno_path(lecture, _id):
        return os.path.join(lecture, "Annotations", "{}.xml".format(_id))

    @staticmethod
    def _anno_from_xml(fpath):
        # Collect the bounding box and region ID of every <object> element.
        root = ET.parse(fpath).getroot()
        bboxes = []
        ids = []
        for obj in root.iter('object'):
            ids += [obj.find('ID').text]
            bndbox = obj.find('bndbox')
            bbox = []
            for pt in ["xmin", "ymin", "xmax", "ymax"]:
                coord = int(bndbox.find(pt).text)
                bbox += [coord]
            bboxes += [bbox]
        return np.asarray(bboxes), ids
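
    # A minimal sketch of the annotation file _anno_from_xml expects: a VOC-style
    # XML whose <object> entries carry a persistent region <ID>. The literal
    # values below are assumptions for illustration, not taken from the dataset:
    #
    #   <annotation>
    #     <object>
    #       <ID>17</ID>
    #       <bndbox>
    #         <xmin>120</xmin><ymin>340</ymin><xmax>410</xmax><ymax>455</ymax>
    #       </bndbox>
    #     </object>
    #   </annotation>
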
    def _get_positive_samples(self, lecture, region_id):
        positive_samples = self.region_frame_map[(lecture, region_id)]
        # Corner case: if the region occurs in only one frame, sample with
        # replacement, so anchor and positive come from the same frame.
        replacement = len(positive_samples) == 1
        s1, s2 = np.random.choice(len(positive_samples), 2, replace=replacement)
        anchor_frame, anchor_bbox = positive_samples[s1]
        pos_frame, pos_bbox = positive_samples[s2]
        return anchor_frame, anchor_bbox, pos_frame, pos_bbox
    def _get_negative_sample(self, lecture, region_id, anchor_frame_id, anchor_bbox):
        # Pick a region with a different id from the same lecture, preferring
        # negatives that lie within some distance of the anchor.
        def bbox_distance(bbox1, bbox2):
            # L2 distance between boxes normalized by the frame size
            # (1920x1080 is assumed here).
            bb1 = np.asarray(bbox1, dtype='float')
            bb1[::2] /= 1920.
            bb1[1::2] /= 1080.
            bb2 = np.asarray(bbox2, dtype='float')
            bb2[::2] /= 1920.
            bb2[1::2] /= 1080.
            return np.linalg.norm(bb1 - bb2)
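
        # Illustrative example of bbox_distance, assuming the 1920x1080
        # normalization above: boxes (0, 0, 192, 108) and (192, 108, 384, 216)
        # normalize to (0, 0, 0.1, 0.1) and (0.1, 0.1, 0.2, 0.2), giving a
        # distance of sqrt(4 * 0.1**2) = 0.2, under the initial 0.33 threshold.
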
        negative_regions = {k: v for k, v in self.region_frame_map.items() if k[0] == lecture and k[1] != region_id}
        valid_negative_regions = {}
        th = 0.33
        # Keep doubling the distance threshold until at least one negative qualifies.
        while len(valid_negative_regions) < 1:
            for k, v in negative_regions.items():
                for (frame_id, bbox) in v:
                    if bbox_distance(bbox, anchor_bbox) <= th:
                        if k in valid_negative_regions:
                            valid_negative_regions[k] += [(frame_id, bbox)]
                        else:
                            valid_negative_regions[k] = [(frame_id, bbox)]
            th *= 2.
        # Pick a random negative region_id from the valid candidates.
        s3 = np.random.randint(0, len(valid_negative_regions))
        # Get the list of (frame_id, bbox) pairs for that region.
        negative_samples = list(valid_negative_regions.values())[s3]
        # Pick a random frame and bbox from the negative samples.
        s4 = np.random.randint(0, len(negative_samples))
        neg_frame, neg_bbox = negative_samples[s4]
        return neg_frame, neg_bbox
    def _get_img(self, lecture, frame_id):
        img_path = self._id_to_image_path(lecture, frame_id)
        return cv2.imread(os.path.join(self.accessmathVOC_root, img_path))

    @staticmethod
    def detection_collate(batch):
        anchor_images = []
        pos_images = []
        neg_images = []
        anchor_bboxes = []
        pos_bboxes = []
        neg_bboxes = []
        for sample in batch:
            anchor_images += [sample["anchor_image"][np.newaxis, :, :, :]]
            pos_images += [sample["pos_image"][np.newaxis, :, :, :]]
            neg_images += [sample["neg_image"][np.newaxis, :, :, :]]
            anchor_bboxes += [sample["anchor_bbox"][np.newaxis, :, :]]
            pos_bboxes += [sample["pos_bbox"][np.newaxis, :, :]]
            neg_bboxes += [sample["neg_bbox"][np.newaxis, :, :]]
        # Stack along the batch axis as [anchors | positives | negatives].
        stacked_images = np.concatenate(anchor_images + pos_images + neg_images, axis=0)
        stacked_bboxes = np.concatenate(anchor_bboxes + pos_bboxes + neg_bboxes, axis=0)
        images = torch.from_numpy(stacked_images).permute(0, 3, 1, 2)  # NHWC -> NCHW
        bboxes = torch.from_numpy(stacked_bboxes)
        return images, bboxes
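
    # A minimal sketch of detection_collate's output for a batch of B samples
    # with HxW images (the dtype follows whatever the data transform emits):
    #
    #   images, bboxes = AccessMathVOCTriplets.detection_collate(batch)
    #   images.shape  # (3*B, 3, H, W), ordered [anchors | positives | negatives]
    #   bboxes.shape  # (3*B, 1, 4), or (3*B, 1, 8) when gt_type == "quad"
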
if __name__ == "__main__":
    """
    Example of creating dataset objects with a train/val split.
    """
    from torch.utils.data import DataLoader
    from data.transforms import AMVOCTripletTransform
    from nnet.layers.text_align_bbox import TextAlign

    lectures = AccessMathVOCTriplets.TrainValLectures
    # Hold out one random lecture for validation and train on the rest.
    v = np.random.randint(0, len(lectures))
    train_lectures = lectures[:v] + lectures[v + 1:]
    val_lecture = [lectures[v]]
    amvoc_triplets_train = AccessMathVOCTriplets('/run1/dataset-txt/AccessMathVOC', AMVOCTripletTransform(),
                                                 None, train_lectures, imageset='train')
    amvoc_triplets_val = AccessMathVOCTriplets('/run1/dataset-txt/AccessMathVOC', AMVOCTripletTransform(),
                                               None, val_lecture, imageset='val')
    resample = TextAlign((1, 3, 30, 30), device='cpu', pool_type=0, rescale=1.0)
    dataloader = DataLoader(amvoc_triplets_train, batch_size=5, collate_fn=AccessMathVOCTriplets.detection_collate)
    dataiter = iter(dataloader)
    for i, (images, boxes) in enumerate(dataiter):
        print(images.shape, boxes.shape)
        resampled = resample(images.float(), boxes.float())
        print(resampled[0].shape, resampled[1])
        input()  # pause so the shapes can be inspected
        if i == 2:
            break

    # Sanity-check TextAlign on random images with per-image ROI lists.
    test_im = torch.tensor(np.random.random(size=(4, 3, 60, 60)))
    n_rois = [5, 9, 2, 4]
    all_test_coords = []
    for n in n_rois:
        test_coords = torch.tensor(np.random.random(size=(n, 4)))
        test_coords[:, 2] = test_coords[:, 0] + 1.
        test_coords[:, 3] = test_coords[:, 1] + 1.
        test_coords *= 30.
        all_test_coords += [test_coords.float()]
    resampled2 = resample(test_im.float(), all_test_coords)
    print(resampled2[0].shape, resampled2[1])