Skip to content

Instantly share code, notes, and snippets.

@ghwn
Last active August 3, 2022 03:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ghwn/a371528e2b5d07177ed0f26e808df4c4 to your computer and use it in GitHub Desktop.
Save ghwn/a371528e2b5d07177ed0f26e808df4c4 to your computer and use it in GitHub Desktop.
SVHN Dataset Loader
import operator
import os

import cv2
import h5py
class SvhnDataset:
    """Index-addressable view over a directory holding ``digitStruct.mat`` and images.

    ``digitStruct.mat`` is opened with h5py; its ``digitStruct/name`` and
    ``digitStruct/bbox`` datasets hold per-sample references to the image
    file name and to the per-digit bounding boxes.

    Indexing with an integer returns ``(image, label)`` where ``image`` is the
    RGB crop around all digits resized to ``IMAGE_SIZE`` and ``label`` is a
    list of digit classes padded with ``PAD_LABEL`` up to ``MAX_DIGITS``
    entries (longer sequences are left unpadded and untruncated). Indexing
    with a slice returns a list of such pairs.

    The instance can be used as a context manager, or closed explicitly with
    :meth:`close`; ``__del__`` remains as a last-resort cleanup.
    """

    # Labels are padded to this many entries.
    MAX_DIGITS = 5
    # Class id appended as padding for "no digit at this position".
    PAD_LABEL = 10
    # Output size passed to cv2.resize as (width, height).
    IMAGE_SIZE = (64, 64)

    def __init__(self, data_dir):
        """Open ``<data_dir>/digitStruct.mat`` read-only and cache its datasets.

        Args:
            data_dir: Directory containing ``digitStruct.mat`` and the images
                it names.
        """
        self.data_dir = data_dir
        # Explicit read-only mode: the loader never writes to the file.
        self.file = h5py.File(os.path.join(data_dir, "digitStruct.mat"), "r")
        self.digit_struct_name = self.file["digitStruct"]["name"]
        self.digit_struct_bbox = self.file["digitStruct"]["bbox"]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        """Close the underlying file; safe to call more than once."""
        file = getattr(self, "file", None)
        if file is not None:
            self.file = None
            file.close()

    def get_name(self, index):
        """Return the image file name stored for sample ``index``.

        The name is stored as a dataset of character codes behind an object
        reference; the codes are read in one call and joined into a string.
        """
        name_ref = self.digit_struct_name[index].item()
        char_codes = self.file[name_ref][()].ravel().tolist()
        return "".join(chr(int(code)) for code in char_codes)

    def get_bbox(self, index, merge=False):
        """Return the bounding boxes stored for sample ``index``.

        Args:
            index: Sample position in the ``bbox`` dataset.
            merge: When true, return a single box enclosing every digit box.

        Returns:
            With ``merge=False`` a list of dicts, one per digit, holding an
            integer for every key of the stored group (the merge step reads
            ``height``, ``label``, ``left``, ``top`` and ``width``). With
            ``merge=True`` one dict with those five keys where ``label`` is
            the concatenated digit string.

        Raises:
            TypeError: If a stored value is neither a number nor an HDF5
                object reference.
            ValueError: If ``merge`` is requested for a sample with no boxes.
        """
        bbox_ref = self.digit_struct_bbox[index].item()
        bbox_group = self.file[bbox_ref]
        bboxes = [{} for _ in range(bbox_group["label"].size)]
        for key in bbox_group:
            # Fetch the dataset handle once per key rather than once per box.
            column = bbox_group[key]
            for i, bbox in enumerate(bboxes):
                value = self._read_int(column[i].item())
                # A stored label of 10 is remapped to 0 (SVHN is documented as
                # encoding the digit 0 as class 10 -- confirm against the
                # dataset description). Only the label is remapped: a
                # coordinate or size that happens to equal 10 must stay 10.
                if key == "label" and value == 10:
                    value = 0
                bbox[key] = value
        if merge:
            return self._merge_bboxes(bboxes)
        return bboxes

    def _read_int(self, raw):
        """Convert one stored bbox entry (number or object reference) to int."""
        # Plain numbers are checked first; references need a second lookup.
        if isinstance(raw, (int, float)):
            return int(raw)
        if isinstance(raw, h5py.h5r.Reference):
            return int(self.file[raw][0].item())
        raise TypeError(
            f"unsupported bbox entry type: {type(raw).__name__}"
        )

    @staticmethod
    def _merge_bboxes(bboxes):
        """Return the smallest box enclosing ``bboxes`` with the joined label."""
        if not bboxes:
            raise ValueError("cannot merge an empty list of bounding boxes")
        left = min(box["left"] for box in bboxes)
        top = min(box["top"] for box in bboxes)
        right = max(box["left"] + box["width"] for box in bboxes)
        bottom = max(box["top"] + box["height"] for box in bboxes)
        return {
            "height": bottom - top,
            "label": "".join(str(box["label"]) for box in bboxes),
            "left": left,
            "top": top,
            "width": right - left,
        }

    def _load_item(self, index):
        """Load, crop and resize sample ``index``; return ``(image, label)``."""
        bbox = self.get_bbox(index, merge=True)
        label = [int(char) for char in bbox["label"]]
        # A non-positive repeat count yields [], so long labels stay as-is.
        label.extend([self.PAD_LABEL] * (self.MAX_DIGITS - len(label)))

        image_path = os.path.join(self.data_dir, self.get_name(index))
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(f"could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Clamp to the image origin: a negative start would be interpreted by
        # numpy as "count from the end" and give a wrong or empty crop.
        left = max(bbox["left"], 0)
        top = max(bbox["top"], 0)
        right = max(bbox["left"] + bbox["width"], left)
        bottom = max(bbox["top"] + bbox["height"], top)
        crop = image[top:bottom, left:right]
        if crop.size == 0:
            raise ValueError(
                f"bounding box of sample {index} lies outside {image_path}"
            )

        target_area = self.IMAGE_SIZE[0] * self.IMAGE_SIZE[1]
        crop_area = crop.shape[0] * crop.shape[1]
        # INTER_AREA when shrinking, INTER_LINEAR otherwise.
        interpolation = (
            cv2.INTER_AREA if crop_area > target_area else cv2.INTER_LINEAR
        )
        crop = cv2.resize(crop, self.IMAGE_SIZE, interpolation=interpolation)
        return crop, label

    def __getitem__(self, key):
        """Return one ``(image, label)`` pair, or a list of them for a slice.

        Raises:
            TypeError: If ``key`` is neither a slice nor integer-like.
            IndexError: If an integer ``key`` is out of range.
        """
        if isinstance(key, slice):
            return [
                self._load_item(index)
                for index in range(*key.indices(len(self)))
            ]
        try:
            # Accept anything integer-like (e.g. numpy integers), not just int.
            index = operator.index(key)
        except TypeError:
            raise TypeError(
                "SvhnDataset indices must be integers or slices, "
                f"not {type(key).__name__}"
            ) from None
        length = len(self)
        if index < 0:
            index += length
        if not 0 <= index < length:
            # An explicit IndexError also lets plain ``for`` iteration stop.
            raise IndexError("SvhnDataset index out of range")
        return self._load_item(index)

    def __len__(self):
        """Return the number of entries in the ``name`` dataset."""
        return int(self.digit_struct_name.size)

    def __del__(self):
        # __init__ may have failed before ``file`` existed; close() copes.
        self.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment