InfiniteConcatDistributedSampler, a gist by @ruotianluo (ruotianluo/6e580537f902c465f308cf76fcbcda9e, created February 9, 2024)

import torch
from torch.utils.data import DistributedSampler

# Note: InfiniteDistributedSampler (a per-dataset sampler that yields this
# rank's shuffled indices forever) is assumed to be defined elsewhere in the
# same module; it is not included in the gist.


class InfiniteConcatDistributedSampler(DistributedSampler):
    def __init__(self, *args, **kwargs):
        """
        Args:
            global_batch_size: because the infinite index stream wraps around,
                the same image could otherwise appear twice in one batch.
                Knowing the global batch size lets the sampler emit whole
                batches, effectively applying drop_last here in the sampler.
            deterministic: the index generator is always seeded with 0; to
                resume from a given iteration we simply drain the iterator up
                to that point, so the dataloader can be resumed properly.
            ratio: optional per-dataset sampling weights; defaults to uniform.
        """
        # Pull the extra keyword arguments off before calling DistributedSampler.
        self.ratio = kwargs.pop('ratio', None)
        _args = args[:]
        # The copy is taken before global_batch_size / deterministic are
        # removed, so the sub-samplers receive those keywords as well.
        _kwargs = kwargs.copy()
        self.global_batch_size = kwargs.pop('global_batch_size', -1)
        self.deterministic = kwargs.pop('deterministic', False)
        assert self.global_batch_size != -1, 'have to know global_batch_size'
        assert self.deterministic, 'non-deterministic mode is not supported'
        super().__init__(*args, **kwargs)
        assert type(self.dataset) is torch.utils.data.ConcatDataset
        if self.ratio is None:
            self.ratio = torch.tensor([1] * len(self.dataset.datasets)).float()
        else:
            self.ratio = torch.tensor(self.ratio).float()
        # Build one infinite sampler per sub-dataset, reusing the original
        # positional/keyword arguments with the dataset swapped out.
        self.sub_samplers = []
        for ds in self.dataset.datasets:
            if 'dataset' in _kwargs:
                _kwargs['dataset'] = ds
            else:
                _args = [ds] + list(_args[1:])
            self.sub_samplers.append(InfiniteDistributedSampler(*_args, **_kwargs))

    # Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/data/samplers/distributed_sampler.py#L12
    def __iter__(self):
        if self.deterministic:
            assert self.global_batch_size != -1
            # self.epoch is interpreted as the iteration to resume from: skip
            # the per-rank indices that were already consumed before it.
            for _i, idx in enumerate(self._infinite_indices()):
                if _i >= self.epoch * self.global_batch_size // self.num_replicas:
                    yield idx
        else:
            yield from self._infinite_indices()

    def set_epoch(self, epoch):
        # Deliberately not forwarded to self.sub_samplers: if they were set as
        # well, they would drain indices too and double-count the resume.
        super().set_epoch(epoch)

    def _infinite_indices(self):
        g = torch.Generator()
        if self.deterministic:
            g.manual_seed(0)
        else:
            g.manual_seed(self.epoch)
        iterators = [iter(_) for _ in self.sub_samplers]
        while True:
            # Select which sub-dataset the next batch is drawn from.
            dataset_idx = torch.multinomial(self.ratio, 1, generator=g).item()
            # Yield one per-rank batch worth of indices from that sub-dataset,
            # mapping each local index into the ConcatDataset index range
            # (counted backwards from the end of that dataset's block).
            for _i in range(self.global_batch_size // self.num_replicas):
                idx = next(iterators[dataset_idx])
                idx = self.dataset.cumulative_sizes[dataset_idx] - idx - 1
                yield idx
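
The gist references an InfiniteDistributedSampler that is not included. Below is a minimal, self-contained sketch of how the class might be wired up: the InfiniteDistributedSampler stand-in is an assumption on my part (a DistributedSampler subclass that yields this rank's shuffled indices of a single dataset forever and accepts, but ignores, the extra keywords the constructor forwards), and the toy datasets, sizes, and ratio are invented purely for illustration.

import itertools

import torch
from torch.utils.data import ConcatDataset, DataLoader, DistributedSampler, TensorDataset


class InfiniteDistributedSampler(DistributedSampler):
    """Stand-in (assumed): yields this rank's shuffled indices of one dataset forever."""

    def __init__(self, dataset, *args, global_batch_size=None, deterministic=None, **kwargs):
        # global_batch_size / deterministic are accepted only because the
        # concat sampler forwards them; this sketch ignores them.
        super().__init__(dataset, *args, **kwargs)

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:
            perm = torch.randperm(len(self.dataset), generator=g)
            # Each rank takes a strided slice of the permutation.
            yield from perm[self.rank::self.num_replicas].tolist()


if __name__ == '__main__':
    # Two toy datasets of different sizes (hypothetical data).
    ds_a = TensorDataset(torch.arange(0, 100))
    ds_b = TensorDataset(torch.arange(1000, 1020))
    concat = ConcatDataset([ds_a, ds_b])

    sampler = InfiniteConcatDistributedSampler(
        concat,
        num_replicas=1, rank=0,   # single process here; normally taken from torch.distributed
        global_batch_size=4,
        deterministic=True,
        ratio=[3, 1],             # draw batches from ds_a vs. ds_b at roughly 3:1
    )
    loader = DataLoader(concat, batch_size=4, sampler=sampler)

    # The sampler never terminates, so take only a few batches.
    for (batch,) in itertools.islice(loader, 5):
        print(batch.tolist())

    # To resume from iteration 100, call sampler.set_epoch(100) before
    # iterating; the sampler then skips 100 * global_batch_size // num_replicas
    # indices so the stream continues where it left off.

Because each group of yielded indices comes from a single sub-dataset, setting the DataLoader's batch_size equal to global_batch_size // num_replicas (as above) keeps every actual batch homogeneous; any other batch size would silently mix sub-datasets within a batch.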