Last active
August 3, 2022 03:35
-
-
Save ghwn/a371528e2b5d07177ed0f26e808df4c4 to your computer and use it in GitHub Desktop.
SVHN Dataset Loader
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import operator
import os

import cv2
import h5py
class SvhnDataset:
    """Random-access reader for an SVHN directory with a ``digitStruct.mat``.

    The directory is expected to contain the image files together with the
    HDF5-based ``digitStruct.mat`` annotation file.  Indexing the dataset with
    an integer returns ``(image, label)``; indexing with a slice returns a
    list of such pairs.

    The instance keeps the annotation file open.  Call :meth:`close` (or use
    the instance as a context manager) to release it deterministically.
    """

    # Output size passed to ``cv2.resize`` as ``(width, height)``.
    IMAGE_SIZE = (64, 64)
    # Label sequences shorter than this are right-padded with PAD_LABEL.
    MAX_DIGITS = 5
    # Class id appended as padding to short label sequences.
    PAD_LABEL = 10

    def __init__(self, data_dir):
        """Open ``<data_dir>/digitStruct.mat`` read-only.

        Args:
            data_dir: Directory holding the images and ``digitStruct.mat``.
        """
        self.data_dir = data_dir
        # Explicit read-only mode: the loader never writes, and older h5py
        # versions default to a writable mode when none is given.
        self.file = h5py.File(os.path.join(data_dir, "digitStruct.mat"), "r")
        try:
            self.digit_struct_name = self.file["digitStruct"]["name"]
            self.digit_struct_bbox = self.file["digitStruct"]["bbox"]
        except Exception:
            # Do not leak the handle when the file lacks the expected layout.
            self.file.close()
            raise

    def get_name(self, index):
        """Return the image file name stored for sample ``index``."""
        name_ref = self.digit_struct_name[index].item()
        name_ds = self.file[name_ref]
        # The name is stored as one character code per dataset element.
        return "".join(chr(name_ds[i].item()) for i in range(name_ds.size))

    def get_bbox(self, index, merge=False):
        """Return the digit bounding boxes of sample ``index``.

        Args:
            index: Sample position.
            merge: When false, return one dict per digit.  When true, return a
                single dict describing the box enclosing all digits.

        Returns:
            With ``merge=False`` a list of dicts whose keys are those of the
            stored bbox group and whose values are ints.  With ``merge=True``
            one dict with keys ``height``, ``label``, ``left``, ``top`` and
            ``width``, where ``label`` is the digits concatenated as a string.

        Raises:
            TypeError: If a stored value is neither a number nor an HDF5
                object reference.
            ValueError: If ``merge`` is true and the sample has no boxes.
        """
        bbox_ref = self.digit_struct_bbox[index].item()
        bbox_group = self.file[bbox_ref]
        bboxes = []
        for i in range(bbox_group["label"].size):
            bbox = {}
            for key in bbox_group.keys():
                raw = bbox_group[key][i].item()
                if isinstance(raw, (int, float)):
                    # Value stored inline.
                    value = int(raw)
                elif isinstance(raw, h5py.h5r.Reference):
                    # Value stored behind an object reference.
                    value = int(self.file[raw][0].item())
                else:
                    raise TypeError(
                        "unsupported bbox value type for {!r}: {}".format(
                            key, type(raw).__name__
                        )
                    )
                # SVHN stores digit 0 as label 10.  Only the label is
                # remapped; a coordinate or size that happens to equal 10
                # must be kept as is.
                if key == "label" and value == 10:
                    value = 0
                bbox[key] = value
            bboxes.append(bbox)
        if merge:
            return self._merge_bboxes(bboxes)
        return bboxes

    @staticmethod
    def _merge_bboxes(bboxes):
        """Return the box enclosing all ``bboxes`` with the labels joined."""
        if not bboxes:
            raise ValueError("cannot merge an empty list of bounding boxes")
        left = min(box["left"] for box in bboxes)
        top = min(box["top"] for box in bboxes)
        right = max(box["left"] + box["width"] for box in bboxes)
        bottom = max(box["top"] + box["height"] for box in bboxes)
        return {
            "height": bottom - top,
            "label": "".join(str(box["label"]) for box in bboxes),
            "left": left,
            "top": top,
            "width": right - left,
        }

    def _load_item(self, index):
        """Load, crop and resize sample ``index``; return ``(image, label)``."""
        bbox = self.get_bbox(index, merge=True)
        label = [int(char) for char in bbox["label"]]
        # Right-pad short sequences.  Longer sequences are returned unchanged
        # (a negative repeat count yields an empty list).
        label.extend([self.PAD_LABEL] * (self.MAX_DIGITS - len(label)))

        image_path = os.path.join(self.data_dir, self.get_name(index))
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # imread signals a missing or undecodable file by returning None.
            raise OSError("could not read image: {}".format(image_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Clamp to the image origin: negative slice bounds would wrap around
        # and produce an empty or wrong crop.  Bounds beyond the far edges
        # are clipped by the slicing itself.
        left = max(bbox["left"], 0)
        top = max(bbox["top"], 0)
        right = max(bbox["left"] + bbox["width"], 0)
        bottom = max(bbox["top"] + bbox["height"], 0)
        image = image[top:bottom, left:right]

        # INTER_AREA when shrinking, INTER_LINEAR otherwise.
        area = image.shape[0] * image.shape[1]
        target_area = self.IMAGE_SIZE[0] * self.IMAGE_SIZE[1]
        interpolation = cv2.INTER_AREA if area > target_area else cv2.INTER_LINEAR
        image = cv2.resize(image, self.IMAGE_SIZE, interpolation=interpolation)
        return image, label

    def __getitem__(self, key):
        """Return one ``(image, label)`` pair, or a list of pairs for a slice.

        ``image`` is the RGB crop around all digits resized to
        ``IMAGE_SIZE``; ``label`` is a list of digit ints padded with
        ``PAD_LABEL`` up to ``MAX_DIGITS`` entries.

        Raises:
            TypeError: If ``key`` is neither integer-like nor a slice.
            IndexError: If an integer ``key`` is out of range.
            OSError: If the image file cannot be read.
        """
        if isinstance(key, slice):
            return [self._load_item(i) for i in range(*key.indices(len(self)))]
        try:
            # Accept any integer-like index (e.g. NumPy integers).
            index = operator.index(key)
        except TypeError:
            raise TypeError(
                "SvhnDataset indices must be integers or slices, not {}".format(
                    type(key).__name__
                )
            ) from None
        length = len(self)
        if index < 0:
            index += length
        if not 0 <= index < length:
            raise IndexError("SvhnDataset index out of range")
        return self._load_item(index)

    def __len__(self):
        """Return the number of samples listed in the annotation file."""
        return self.digit_struct_name.size

    def close(self):
        """Close the annotation file; safe to call more than once."""
        handle = getattr(self, "file", None)
        if handle is not None:
            handle.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        # ``file`` is missing when __init__ failed before opening it, so go
        # through close(), which tolerates that.
        self.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment