import torch
from torch.utils.data import DistributedSampler


class InfiniteConcatDistributedSampler(DistributedSampler):
    def __init__(self, *args, **kwargs):
        """
        Args:
            global_batch_size: since the infinite index stream wraps around,
                the same image could otherwise appear twice in one batch;
                we effectively apply drop_last here in the sampler.
            deterministic: we always start the generator with seed 0, and
                to restart from a certain iteration we manually drain the
                iterator. This lets the dataloader resume properly.
        """
        self.ratio = kwargs.pop('ratio', None)
        # Snapshot args/kwargs now: the sub-samplers below are built from this
        # snapshot, so they still receive global_batch_size/deterministic.
        _args = args[:]
        _kwargs = kwargs.copy()
        self.global_batch_size = kwargs.pop('global_batch_size', -1)
        self.deterministic = kwargs.pop('deterministic', False)
        assert self.global_batch_size != -1, 'have to know global_batch_size'
        assert self.deterministic, 'non-deterministic mode is not supported'
        super().__init__(*args, **kwargs)
        assert type(self.dataset) is torch.utils.data.ConcatDataset
        # Sampling ratio across the concatenated datasets; uniform by default.
        if self.ratio is None:
            self.ratio = torch.tensor([1] * len(self.dataset.datasets)).float()
        else:
            self.ratio = torch.tensor(self.ratio).float()
        # One infinite sampler per constituent dataset, built by swapping the
        # dataset argument (positional or keyword) into the saved snapshot.
        # (InfiniteDistributedSampler is not defined in this gist; a hedged
        # sketch is appended after the class.)
        self.sub_samplers = []
        for ds in self.dataset.datasets:
            if 'dataset' in _kwargs:
                _kwargs['dataset'] = ds
            else:
                _args = [ds] + list(_args[1:])
            self.sub_samplers.append(InfiniteDistributedSampler(*_args, **_kwargs))
    # Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/data/samplers/distributed_sampler.py#L12
    def __iter__(self):
        if self.deterministic:
            assert self.global_batch_size != -1
            for _i, idx in enumerate(self._infinite_indices()):
                # "epoch" is really an iteration counter here: skip the
                # indices already consumed before the resume point.
                if _i >= self.epoch * self.global_batch_size // self.num_replicas:
                    yield idx
        else:
            yield from self._infinite_indices()
    def set_epoch(self, epoch):
        # Don't propagate the epoch to the sub-samplers; otherwise each
        # sub-sampler would drain its own stream as well.
        # [_.set_epoch(epoch) for _ in self.sub_samplers]
        super().set_epoch(epoch)
    def _infinite_indices(self):
        g = torch.Generator()
        if self.deterministic:
            g.manual_seed(0)
        else:
            g.manual_seed(self.epoch)
        iterators = [iter(_) for _ in self.sub_samplers]
        while True:
            # Select a dataset according to the sampling ratio.
            dataset_idx = torch.multinomial(self.ratio, 1, generator=g).item()
            # Yield one per-replica batch entirely from the chosen dataset.
            for _i in range(self.global_batch_size // self.num_replicas):
                idx = next(iterators[dataset_idx])
                # Map the sub-sampler's local index into the ConcatDataset's
                # flat index space (counting back from the cumulative end of
                # the chosen dataset's slot).
                idx = self.dataset.cumulative_sizes[dataset_idx] - idx - 1
                yield idx
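

# ---------------------------------------------------------------------------
# The gist uses InfiniteDistributedSampler but never defines it. Below is a
# minimal sketch of what it plausibly looks like, modeled on detectron2's
# TrainingSampler (linked above): each replica takes every num_replicas-th
# index from an endlessly reshuffled stream. This is an assumption, not the
# author's actual class; the real one presumably also accepts the
# global_batch_size/deterministic kwargs that _kwargs passes through, so the
# sketch accepts and discards them.
import itertools


class InfiniteDistributedSampler(DistributedSampler):
    def __init__(self, *args, **kwargs):
        # _kwargs above still carries these custom keys; accept and drop
        # them here (assumed behavior).
        kwargs.pop('global_batch_size', None)
        kwargs.pop('deterministic', None)
        super().__init__(*args, **kwargs)

    def __iter__(self):
        # Every replica walks the same infinite stream, offset by its rank,
        # so the replicas partition each shuffle without overlap.
        yield from itertools.islice(
            self._infinite_indices(), self.rank, None, self.num_replicas)

    def _infinite_indices(self):
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:
            if self.shuffle:
                yield from torch.randperm(len(self.dataset), generator=g).tolist()
            else:
                yield from torch.arange(len(self.dataset)).tolist()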
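

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not from the gist): the toy datasets, 2-GPU
# world size, and 2:1 sampling ratio are made up for illustration.
# global_batch_size should equal per-replica batch size times world size, so
# each yielded run of global_batch_size // num_replicas indices fills exactly
# one per-replica batch drawn from a single dataset.
if __name__ == '__main__':
    from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

    ds_a = TensorDataset(torch.arange(100))  # toy dataset A
    ds_b = TensorDataset(torch.arange(50))   # toy dataset B
    concat = ConcatDataset([ds_a, ds_b])

    world_size, rank = 2, 0                  # normally from torch.distributed
    per_gpu_batch = 8
    sampler = InfiniteConcatDistributedSampler(
        concat,
        num_replicas=world_size,
        rank=rank,
        ratio=[2, 1],                        # draw from A twice as often as B
        global_batch_size=per_gpu_batch * world_size,
        deterministic=True,
    )

    # Resuming: "epoch" is an iteration counter, so set_epoch(1000) drains
    # the first 1000 iterations' worth of per-replica indices before yielding.
    sampler.set_epoch(1000)

    # The resulting loader is infinite; break out of any loop manually.
    loader = DataLoader(concat, batch_size=per_gpu_batch, sampler=sampler)
    batch = next(iter(loader))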