Skip to content

Instantly share code, notes, and snippets.

@escuccim
Created December 17, 2020 09:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save escuccim/0de8205b524c7667599f8d0825f58e95 to your computer and use it in GitHub Desktop.
Save escuccim/0de8205b524c7667599f8d0825f58e95 to your computer and use it in GitHub Desktop.
Multi-scale training for PyTorch ImageFolder dataset
"""Based on https://github.com/CaoWGG/multi-scale-training"""
import numpy as np
import torch
import torchvision
from torch.utils.data import Sampler, RandomSampler, SequentialSampler
class BatchSampler(object):
    """Yield batches of ``[index, image_size]`` pairs for multi-scale training.

    Groups the indices produced by ``sampler`` into batches of ``batch_size``.
    Every element of a batch carries the same image size, so a dataset such as
    ``MultiscaleDataSet`` can resize all images of that batch to a common
    resolution. When ``multiscale_step`` is set, the size is drawn uniformly at
    random from ``img_sizes`` for the first batch and re-drawn after every
    ``multiscale_step`` batches.

    Args:
        sampler: a ``torch.utils.data.Sampler`` producing dataset indices.
        batch_size: number of indices per batch; must be >= 1.
        drop_last: if True, the final incomplete batch is not yielded.
        multiscale_step: number of batches between size changes; ``None``
            disables multi-scale sampling.
        img_sizes: candidate image sizes; required (non-empty) when
            ``multiscale_step`` is set.
        default_size: size attached to every batch when multi-scale sampling
            is disabled.

    Raises:
        ValueError: if ``sampler`` is not a Sampler, ``batch_size`` < 1,
            ``drop_last`` is not a bool, ``multiscale_step`` < 1, or
            ``img_sizes`` is missing/empty while ``multiscale_step`` is set.
    """

    def __init__(self, sampler, batch_size, drop_last, multiscale_step=None,
                 img_sizes=None, default_size=416):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        # batch_size == 0 would never fill a batch and would make __len__
        # divide by zero.
        if batch_size < 1:
            raise ValueError("batch_size should be >= 1, but got "
                             "batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        if multiscale_step is not None and multiscale_step < 1:
            raise ValueError("multiscale_step should be > 0, but got "
                             "multiscale_step={}".format(multiscale_step))
        # len() rather than truthiness so NumPy arrays are accepted too.
        if multiscale_step is not None and (img_sizes is None
                                            or len(img_sizes) == 0):
            raise ValueError("img_sizes must be a non-empty list, but got "
                             "img_sizes={} ".format(img_sizes))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.multiscale_step = multiscale_step
        self.img_sizes = img_sizes
        self.default_size = default_size

    def _draw_size(self):
        """Return a random entry of ``img_sizes`` as a built-in int."""
        # np.random.choice returns a NumPy scalar; convert it so downstream
        # consumers (e.g. PIL's resize) receive a plain Python int.
        return int(np.random.choice(self.img_sizes))

    def __iter__(self):
        multiscale = self.multiscale_step is not None
        # Start from a size taken from img_sizes. A fixed default here would
        # make the first `multiscale_step` batches use a size the caller never
        # requested.
        size = self._draw_size() if multiscale else self.default_size
        num_batch = 0
        batch = []
        for idx in self.sampler:
            batch.append([idx, size])
            if len(batch) == self.batch_size:
                yield batch
                num_batch += 1
                batch = []
                if multiscale and num_batch % self.multiscale_step == 0:
                    size = self._draw_size()
        if batch and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        # Ceiling division: a trailing partial batch counts as one batch.
        return (len(self.sampler) + self.batch_size - 1) // self.batch_size
class MultiscaleDataSet(torchvision.datasets.ImageFolder):
    """ImageFolder dataset whose items can be resized to a per-batch size.

    Meant to be used with ``BatchSampler``, which passes indices of the form
    ``[index, size]``. Each image is resized to ``size x size`` and then
    ``transform`` / ``target_transform`` are applied, as the base
    ``ImageFolder.__getitem__`` does. A bare integer index falls back to
    ``default_size``.
    """

    # Square side length used when __getitem__ receives a bare index.
    default_size = 448

    def __getitem__(self, index):
        """Load, resize and transform one sample.

        Args:
            index: a dataset index, or an ``(index, input_size)`` pair.

        Returns:
            ``(sample, target)`` after ``transform`` / ``target_transform``.
        """
        if isinstance(index, (tuple, list)):
            index, input_size = index
        else:
            input_size = self.default_size
        path, target = self.samples[index]
        sample = self.loader(path)
        # NOTE(review): assumes self.loader returns a PIL image (the
        # ImageFolder default loader does); .resize((w, h)) is PIL's API.
        # int() guards against NumPy scalar sizes.
        sample = sample.resize((int(input_size), int(input_size)))
        # Apply the constructor's transforms. The previous override skipped
        # them, silently ignoring transform= / target_transform=. Transforms
        # that resize or crop to a fixed size would undo the multi-scale
        # resize above.
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target
# Example usage. The per-image transform runs *after* MultiscaleDataSet has
# resized the image to the batch's size, so it must not resize or crop to a
# fixed resolution. ToTensor lets the default collate function stack each
# batch into a single tensor.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])
# Example value; adjust to the available memory.
batch_size = 32
# create the dataset and loader
train_dataset = MultiscaleDataSet(
    root="data/train",
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    # batch_sampler is mutually exclusive with the DataLoader's own
    # batch_size/shuffle/sampler/drop_last; shuffling comes from RandomSampler.
    batch_sampler=BatchSampler(
        RandomSampler(train_dataset),
        batch_size=batch_size,
        multiscale_step=1,
        drop_last=True,
        img_sizes=[320, 384, 448, 512, 576, 640],
    ),
    num_workers=7,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment