Multi-scale training for PyTorch ImageFolder dataset
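The BatchSampler below pairs every sample index with a target image size, drawing a new size from img_sizes every multiscale_step batches; the MultiscaleDataSet subclass of torchvision.datasets.ImageFolder unpacks that pair in __getitem__ and resizes each image accordingly. Every batch is therefore internally uniform while the resolution varies from batch to batch.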
"""Based on https://github.com/CaoWGG/multi-scale-training""" | |
from torch.utils.data import Sampler,RandomSampler,SequentialSampler | |
import numpy as np | |


class BatchSampler(object):
    """Wraps a sampler to yield batches of (index, image_size) pairs,
    drawing a new size from img_sizes every `multiscale_step` batches."""

    def __init__(self, sampler, batch_size, drop_last, multiscale_step=None, img_sizes=None):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        if multiscale_step is not None and multiscale_step < 1:
            raise ValueError("multiscale_step should be > 0, but got "
                             "multiscale_step={}".format(multiscale_step))
        if multiscale_step is not None and img_sizes is None:
            raise ValueError("img_sizes must be a list, but got img_sizes={}".format(img_sizes))
        self.multiscale_step = multiscale_step
        self.img_sizes = img_sizes

    def __iter__(self):
        num_batch = 0
        batch = []
        size = 416  # initial size, used until the first rescale
        for idx in self.sampler:
            batch.append([idx, size])
            if len(batch) == self.batch_size:
                yield batch
                num_batch += 1
                batch = []
                # draw a new size every `multiscale_step` batches
                if self.multiscale_step and num_batch % self.multiscale_step == 0:
                    size = np.random.choice(self.img_sizes)
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size


class MultiscaleDataSet(torchvision.datasets.ImageFolder):
    """ImageFolder dataset that accepts (index, input_size) pairs
    and resizes each image to the requested size."""

    def __getitem__(self, index):
        if isinstance(index, (tuple, list)):
            index, input_size = index
        else:
            # default image size when a plain index is given
            input_size = 448
        path, target = self.samples[index]
        sample = self.loader(path)
        # resize the image to the size chosen by the batch sampler
        sample = sample.resize((input_size, input_size))
        # apply the usual ImageFolder transforms (the original gist skipped
        # this, which breaks the default collate on PIL images)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        # return the image and label
        return sample, target


# ToTensor is a minimal placeholder transform; the resize is handled
# by the dataset itself, so it should not be repeated here
transform = transforms.ToTensor()
batch_size = 32  # arbitrary example value

# create the dataset and loader
train_dataset = MultiscaleDataSet(
    root="data/train",
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_sampler=BatchSampler(RandomSampler(train_dataset),
                               batch_size=batch_size,
                               multiscale_step=1,
                               drop_last=True,
                               img_sizes=[320, 384, 448, 512, 576, 640]),
    num_workers=7,
)
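
A minimal usage sketch, assuming "data/train" exists in the standard ImageFolder layout (one subdirectory per class); the loop and the shapes in the comment are illustrative and not part of the original gist:

for images, labels in train_loader:
    # all images in one batch share a single size, so the default collate
    # works; the spatial dimensions change between batches,
    # e.g. [32, 3, 512, 512] for one batch, [32, 3, 320, 320] for the next
    print(images.shape)

Because the size is chosen per batch rather than per image, no padding or custom collate_fn is needed.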