Skip to content

Instantly share code, notes, and snippets.

@Chris-hughes10
Last active December 28, 2021 09:42
Show Gist options
  • Save Chris-hughes10/e88e261f912108bdbd3142cdf24df36c to your computer and use it in GitHub Desktop.
Save Chris-hughes10/e88e261f912108bdbd3142cdf24df36c to your computer and use it in GitHub Desktop.
Effdet_blog_dataset
from torch.utils.data import Dataset
class EfficientDetDataset(Dataset):
    """Map-style dataset that turns adaptor records into EfficientDet samples.

    Each item is fetched with ``dataset_adaptor.get_image_and_labels_by_idx``,
    passed through ``transforms``, and returned as
    ``(image, target, image_id)``. ``target`` is a dict holding the boxes
    (columns reordered from xyxy to yxyx), the labels, the image id, the
    post-transform image size ``(height, width)`` and a unit image scale.

    Args:
        dataset_adaptor: Object exposing ``__len__`` and
            ``get_image_and_labels_by_idx(index)``. The latter returns
            ``(image, pascal_bboxes, class_labels, image_id)``.
        transforms: Callable invoked as
            ``transforms(image=..., bboxes=..., labels=...)`` that returns a
            dict with the same keys. When omitted (``None``), a fresh
            ``get_valid_transforms()`` pipeline is built for this instance.
    """

    def __init__(self, dataset_adaptor, transforms=None):
        self.ds = dataset_adaptor
        # Build the default lazily. A default of ``get_valid_transforms()``
        # in the signature would be evaluated once, at class-definition
        # time, and that single pipeline object shared by every instance.
        self.transforms = (
            get_valid_transforms() if transforms is None else transforms
        )

    def __getitem__(self, index):
        """Return ``(image, target, image_id)`` for the record at ``index``."""
        image, pascal_bboxes, class_labels, image_id = (
            self.ds.get_image_and_labels_by_idx(index)
        )
        sample = self.transforms(
            image=np.array(image, dtype=np.float32),
            bboxes=pascal_bboxes,
            labels=class_labels,
        )
        image = sample["image"]
        labels = sample["labels"]

        bboxes = np.array(sample["bboxes"])
        if bboxes.size == 0:
            # An image with no boxes (or whose boxes were all dropped by the
            # transform) produces a 1-D empty array, and the 2-D column
            # indexing below would raise IndexError. Give it a (0, 4) shape.
            bboxes = np.zeros((0, 4), dtype=np.float32)
        # Convert xyxy -> yxyx by swapping the x and y columns. The
        # fancy-indexed right-hand side is a copy, so this in-place
        # assignment does not read already-overwritten values.
        bboxes[:, [0, 1, 2, 3]] = bboxes[:, [1, 0, 3, 2]]

        # NOTE(review): assumes the transform returns a channels-first image
        # (e.g. via ToTensorV2), so shape is (C, H, W). TODO confirm.
        _, new_h, new_w = image.shape

        target = {
            "bboxes": torch.as_tensor(bboxes, dtype=torch.float32),
            "labels": torch.as_tensor(labels),
            "image_id": torch.tensor([image_id]),
            "img_size": (new_h, new_w),
            "img_scale": torch.tensor([1.0]),
        }
        return image, target, image_id

    def __len__(self):
        """Return the number of records in the wrapped adaptor."""
        return len(self.ds)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment