This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# One DistributedSampler per split, both built from the same `seed`.
# The training split is shuffled; the validation split keeps dataset order.
# NOTE(review): per the PyTorch docs, a shuffling DistributedSampler only
# varies its order across epochs if `train_sampler.set_epoch(epoch)` is
# called each epoch — confirm the training loop does this.
train_sampler = DistributedSampler(train_dataset, seed=seed, shuffle=True)
valid_sampler = DistributedSampler(valid_dataset, seed=seed, shuffle=False)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.nn.parallel import DistributedDataParallel as DDP

# Local rank of this process; used below as its CUDA device index.
cur_rank = comm.get_local_rank()
# Move the model to this process's device, then wrap it in DDP.
# broadcast_buffers=False turns off DDP's per-forward buffer broadcast
# (see PyTorch DistributedDataParallel docs).
model = DDP(
    model.to(cur_rank),
    device_ids=[cur_rank],
    broadcast_buffers=False,
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Spawn `main_func` across 4 GPUs on a single machine (this machine is rank 0).
# dist_url="auto": per detectron2's `launch` docs, this picks a free port on
# localhost for the process group. ASCII quotes are required here — the
# original typographic quotes (‘…’) are not valid Python string delimiters.
launch(
    main_func,
    num_gpus_per_machine=4,
    num_machines=1,
    machine_rank=0,
    dist_url="auto",
    args=(),
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # copied from detectron2/detectron2/engine/launch.py | |
| # https://github.com/facebookresearch/detectron2/blob/9246ebc3af1c023cfbdae77e5d976edbcf9a2933/detectron2/engine/launch.py | |
| import logging | |
| from datetime import timedelta | |
| import torch | |
| import torch.distributed as dist | |
| import torch.multiprocessing as mp | |
| from detectron2.utils import comm |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from balanced_sampler import BalancedBatchSchedulerSampler | |
| batch_size = 8 | |
| # dataloader with BalancedBatchSchedulerSampler | |
| dataloader = torch.utils.data.DataLoader(dataset=concat_dataset, | |
| sampler=BalancedBatchSchedulerSampler(dataset=concat_dataset, | |
| batch_size=batch_size), | |
| batch_size=batch_size, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import math | |
| import torch | |
| from torch.utils.data import RandomSampler | |
| from sampler import ImbalancedDatasetSampler | |
| class ExampleImbalancedDatasetSampler(ImbalancedDatasetSampler): | |
| """ | |
| ImbalancedDatasetSampler is taken from: | |
| https://github.com/ufoym/imbalanced-dataset-sampler/blob/master/torchsampler/imbalanced.py |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from multi_task_batch_scheduler import BatchSchedulerSampler | |
| batch_size = 8 | |
| # dataloader with BatchSchedulerSampler | |
| dataloader = torch.utils.data.DataLoader(dataset=concat_dataset, | |
| sampler=BatchSchedulerSampler(dataset=concat_dataset, | |
| batch_size=batch_size), | |
| batch_size=batch_size, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import math | |
| import torch | |
| from torch.utils.data.sampler import RandomSampler | |
| class BatchSchedulerSampler(torch.utils.data.sampler.Sampler): | |
| """ | |
| iterate over tasks and provide a random batch per task in each mini-batch | |
| """ | |
| def __init__(self, dataset, batch_size): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
batch_size = 8

# Plain DataLoader over the concatenated dataset: mini-batches of
# `batch_size` samples, reshuffled on every pass through the loader.
dataloader = torch.utils.data.DataLoader(
    dataset=concat_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Print every mini-batch the loader yields.
for inputs in dataloader:
    print(inputs)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from torch.utils.data.dataset import ConcatDataset | |
| class MyFirstDataset(torch.utils.data.Dataset): | |
| def __init__(self): | |
| # dummy dataset | |
| self.samples = torch.cat((-torch.ones(5), torch.ones(5))) | |
| def __getitem__(self, index): |