Skip to content

Instantly share code, notes, and snippets.

Created December 17, 2020 09:34
Show Gist options
  • Save escuccim/0de8205b524c7667599f8d0825f58e95 to your computer and use it in GitHub Desktop.
Save escuccim/0de8205b524c7667599f8d0825f58e95 to your computer and use it in GitHub Desktop.
Multi-scale training for PyTorch ImageFolder dataset
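"""Multi-scale training helpers for a PyTorch ImageFolder dataset.

Based on an upstream project whose link is missing from this copy.
"""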
import numpy as np
import torchvision
from torch.utils.data import Sampler, RandomSampler, SequentialSampler
class BatchSampler(object):
    """Group sampler indices into batches of ``[index, image_size]`` pairs.

    Every element of a yielded batch is a two-item list ``[idx, size]``.
    All elements of one batch carry the same ``size``. When
    ``multiscale_step`` is set, a new size is drawn from ``img_sizes``
    after every ``multiscale_step`` full batches.

    Args:
        sampler: A ``torch.utils.data.Sampler`` providing dataset indices.
        batch_size: Number of ``[idx, size]`` pairs per batch.
        drop_last: If ``True``, a trailing batch smaller than
            ``batch_size`` is discarded.
        multiscale_step: Number of full batches between size changes, or
            ``None`` to keep ``default_size`` for the whole iteration.
        img_sizes: Candidate sizes to draw from; required (and non-empty)
            when ``multiscale_step`` is given.

    Raises:
        ValueError: If ``sampler`` is not a ``Sampler``, ``drop_last`` is
            not a ``bool``, ``multiscale_step`` is below 1, or
            ``multiscale_step`` is given without usable ``img_sizes``.
    """

    # Size attached to every index until the first re-draw happens.
    default_size = 416

    def __init__(self, sampler, batch_size, drop_last, multiscale_step=None, img_sizes=None):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        if multiscale_step is not None and multiscale_step < 1:
            raise ValueError("multiscale_step should be > 0, but got "
                             "multiscale_step={}".format(multiscale_step))
        # An empty candidate list is as unusable as a missing one: the draw
        # in __iter__ would fail mid-epoch, so reject it up front.
        if multiscale_step is not None and (img_sizes is None or len(img_sizes) == 0):
            raise ValueError("img_sizes must a list, but got img_sizes={} ".format(img_sizes))
        self.multiscale_step = multiscale_step
        self.img_sizes = img_sizes

    def __iter__(self):
        """Yield lists of ``[idx, size]`` pairs, one list per batch."""
        num_batch = 0
        batch = []
        size = self.default_size
        for idx in self.sampler:
            batch.append([idx, size])
            if len(batch) == self.batch_size:
                yield batch
                num_batch += 1
                batch = []
                # Re-draw only on a batch boundary so that every sample of a
                # batch shares one size.
                if self.multiscale_step and num_batch % self.multiscale_step == 0:
                    # Cast the NumPy scalar to a plain int for consumers.
                    size = int(np.random.choice(self.img_sizes))
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        """Return the number of batches one pass over the sampler yields."""
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        # Ceiling division: the partial trailing batch counts as well.
        return (len(self.sampler) + self.batch_size - 1) // self.batch_size
class MultiscaleDataSet(torchvision.datasets.ImageFolder):
    """ImageFolder dataset whose items can be requested at a given size.

    ``__getitem__`` accepts either a plain index or an ``(index, size)``
    tuple/list. The loaded sample is resized to ``size`` x ``size`` before
    the dataset's transforms run.
    """

    # Side length used when the index arrives without an explicit size.
    default_size = 448

    def __getitem__(self, index):
        """Load one sample, resize it to a square and return it with its label.

        Args:
            index: Either an index into ``self.samples`` or a two-item
                tuple/list ``(index, input_size)``.

        Returns:
            A ``(sample, target)`` pair.
        """
        if isinstance(index, (tuple, list)):
            index, input_size = index
        else:
            # Fall back to the default only when no size was supplied;
            # a requested size must not be overwritten.
            input_size = self.default_size

        path, target = self.samples[index]
        sample = self.loader(path)

        # resize the image
        # NOTE(review): relies on the loader's result exposing
        # ``resize((w, h))`` (presumably a PIL image) - confirm if a custom
        # loader is ever used.
        sample = sample.resize((input_size, input_size))

        # Apply the transforms given to the constructor, as the parent
        # ImageFolder does; overriding __getitem__ would otherwise
        # silently drop them.
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        # return the image and label
        return sample, target
# NOTE(review): the right-hand side of this assignment is missing from this
# copy of the file; the transform pipeline must be restored before use.
transforms =
# create the dataset and loader
# NOTE(review): the constructor arguments are cut off here - TODO restore.
train_dataset = MultiscaleDataSet(
# NOTE(review): the loader construction is cut off here - TODO restore.
train_loader =
# Candidate image sizes. The trailing "),"
# suggests this line closes a call whose opening is missing, presumably the
# BatchSampler passed to the loader - TODO confirm.
img_sizes=[320, 384, 448, 512, 576, 640]),
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment