# Distributed training setup: a per-rank sampler for each split, plus a DDP-wrapped model.
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from detectron2.utils import comm

train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=seed)
valid_sampler = DistributedSampler(valid_dataset, shuffle=False, seed=seed)

cur_rank = comm.get_local_rank()
model = DDP(model.to(cur_rank), device_ids=[cur_rank], broadcast_buffers=False)
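One easy-to-miss detail with a shuffling DistributedSampler: it only reshuffles across epochs if set_epoch() is called before each pass. A minimal epoch loop (sketch; train_loader, num_epochs, and the training step are assumed to exist elsewhere):

for epoch in range(num_epochs):
    train_sampler.set_epoch(epoch)  # re-seeds the shuffle so each epoch draws a new permutation
    for batch in train_loader:
        ...  # usual forward / backward / optimizer step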
launch(
    main_func,
    num_gpus_per_machine=4,
    num_machines=1,
    machine_rank=0,
    dist_url='auto',
    args=(),
)
# Copyright (c) Facebook, Inc. and its affiliates.
# copied from detectron2/detectron2/engine/launch.py
# https://github.com/facebookresearch/detectron2/blob/9246ebc3af1c023cfbdae77e5d976edbcf9a2933/detectron2/engine/launch.py
import logging
from datetime import timedelta
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from detectron2.utils import comm
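The preview stops at the imports. For context, the core pattern inside detectron2's launch() is to spawn one worker process per GPU and have each worker join the process group before running main_func. A simplified sketch of that pattern (not the file's verbatim code; the real implementation also resolves dist_url='auto', sets NCCL timeouts, and computes global ranks across machines):

# Simplified sketch of launch()'s spawn pattern (assumption: single machine,
# NCCL backend; `_worker_sketch` and `launch_sketch` are illustrative names).
def _worker_sketch(local_rank, main_func, world_size, dist_url, args):
    dist.init_process_group(
        backend="nccl",
        init_method=dist_url,
        world_size=world_size,
        rank=local_rank,  # on a single machine, global rank == local rank
        timeout=timedelta(minutes=30),
    )
    comm.synchronize()  # barrier so all workers are up before training starts
    main_func(*args)


def launch_sketch(main_func, num_gpus_per_machine, dist_url, args=()):
    mp.spawn(
        _worker_sketch,
        nprocs=num_gpus_per_machine,
        args=(main_func, num_gpus_per_machine, dist_url, args),
    )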
bomri / balanced_batch_scheduler_dataloader_example.py
Last active January 2, 2020 21:05
Unbalanced data loading for multi-task learning in PyTorch (6)
import torch
from balanced_sampler import BalancedBatchSchedulerSampler

batch_size = 8
# dataloader with BalancedBatchSchedulerSampler
dataloader = torch.utils.data.DataLoader(dataset=concat_dataset,
                                         sampler=BalancedBatchSchedulerSampler(dataset=concat_dataset,
                                                                               batch_size=batch_size),
                                         batch_size=batch_size,
                                         shuffle=False)
for inputs in dataloader:
    print(inputs)
bomri / balanced_sampler.py
Last active May 30, 2023 07:24
Unbalanced data loading for multi-task learning in PyTorch (5)
import math
import torch
from torch.utils.data import RandomSampler
from sampler import ImbalancedDatasetSampler


class ExampleImbalancedDatasetSampler(ImbalancedDatasetSampler):
    """
    ImbalancedDatasetSampler is taken from:
    https://github.com/ufoym/imbalanced-dataset-sampler/blob/master/torchsampler/imbalanced.py
    """
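The preview cuts off at the docstring. Presumably the subclass adapts the base sampler's label lookup to the toy datasets from part (1), where a sample's sign can stand in for a class label. A hedged sketch of that idea (the `_get_label` name follows the linked ufoym implementation; the body and the class name here are assumptions, not the gist's code):

# Hedged sketch, not the gist's verbatim code.
class ExampleImbalancedDatasetSamplerSketch(ImbalancedDatasetSampler):
    def _get_label(self, dataset, idx):
        # treat the sign of the dummy sample as its class label
        return int(dataset[idx].item() > 0)

The rest of balanced_sampler.py presumably mirrors BatchSchedulerSampler from part (3), with these per-dataset imbalanced samplers swapped in for RandomSampler, which would explain the shared imports.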
bomri / batch_scheduler_dataloader_example.py
Last active January 2, 2020 21:06
Unbalanced data loading for multi-task learning in PyTorch (4)
import torch
from multi_task_batch_scheduler import BatchSchedulerSampler

batch_size = 8
# dataloader with BatchSchedulerSampler
dataloader = torch.utils.data.DataLoader(dataset=concat_dataset,
                                         sampler=BatchSchedulerSampler(dataset=concat_dataset,
                                                                       batch_size=batch_size),
                                         batch_size=batch_size,
                                         shuffle=False)
for inputs in dataloader:
    print(inputs)
bomri / multi_task_batch_scheduler.py
Last active May 30, 2023 07:23
Unbalanced data loading for multi-task learning in PyTorch (3)
import math
import torch
from torch.utils.data.sampler import RandomSampler


class BatchSchedulerSampler(torch.utils.data.sampler.Sampler):
    """
    Iterate over tasks and provide a random batch per task in each mini-batch.
    """
    def __init__(self, dataset, batch_size):
        # `dataset` is expected to be a ConcatDataset of per-task datasets
        self.dataset = dataset
        self.batch_size = batch_size
        self.number_of_datasets = len(dataset.datasets)
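The preview ends at the constructor. A hedged reconstruction of the remaining methods (my sketch of the behavior the docstring describes, not necessarily the gist's exact code): keep one RandomSampler per sub-dataset, and build an epoch's index list by taking batch_size indices from each task in turn, shifted into the ConcatDataset's flat index space; smaller tasks restart their samplers so every task appears in every cycle.

    # --- hedged reconstruction of the truncated methods (sketch, see above) ---
    def __len__(self):
        largest = max(len(d) for d in self.dataset.datasets)
        return self.batch_size * math.ceil(largest / self.batch_size) * self.number_of_datasets

    def __iter__(self):
        sampler_iterators = [iter(RandomSampler(d)) for d in self.dataset.datasets]
        # flat-index offset of each sub-dataset inside the ConcatDataset
        offsets = [0] + self.dataset.cumulative_sizes[:-1]
        largest = max(len(d) for d in self.dataset.datasets)
        final_samples = []
        for _ in range(math.ceil(largest / self.batch_size)):
            for i in range(self.number_of_datasets):
                for _ in range(self.batch_size):
                    try:
                        idx = next(sampler_iterators[i])
                    except StopIteration:
                        # smaller tasks wrap around so every cycle covers all tasks
                        sampler_iterators[i] = iter(RandomSampler(self.dataset.datasets[i]))
                        idx = next(sampler_iterators[i])
                    final_samples.append(idx + offsets[i])
        return iter(final_samples)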
bomri / basic_dataloader_example.py
Last active January 2, 2020 21:06
Unbalanced data loading for multi-task learning in PyTorch (2)
import torch

batch_size = 8
# basic dataloader: plain shuffling over the concatenated dataset
dataloader = torch.utils.data.DataLoader(dataset=concat_dataset,
                                         batch_size=batch_size,
                                         shuffle=True)
for inputs in dataloader:
print(inputs)
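Why this is a problem for multi-task learning: with shuffle=True each batch draws uniformly from the pooled samples, so a task's share of a batch is proportional to its dataset size, and a small task can be absent from many batches entirely. A quick sanity check (sketch; assumes the `concat_dataset` built in part (1)):

# Expected per-task share of a uniformly shuffled batch (sketch).
sizes = [len(d) for d in concat_dataset.datasets]
total = sum(sizes)
for i, size in enumerate(sizes):
    print(f"task {i}: expected {size / total:.1%} of each batch")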
bomri / basic_dataset_example.py
Last active January 2, 2020 21:06
Unbalanced data loading for multi-task learning in PyTorch (1)
import torch
from torch.utils.data.dataset import ConcatDataset
class MyFirstDataset(torch.utils.data.Dataset):
    def __init__(self):
        # dummy dataset: five -1 samples followed by five +1 samples
        self.samples = torch.cat((-torch.ones(5), torch.ones(5)))

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)
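The preview ends before the second dataset and the concatenation. A hedged sketch of the missing pieces, so the `concat_dataset` referenced in parts (2), (4), and (6) is defined (the class name follows the snippet's pattern; the sample values and the size imbalance are assumptions for illustration):

class MySecondDataset(torch.utils.data.Dataset):
    def __init__(self):
        # a second, deliberately larger dummy dataset so the two tasks are unbalanced
        self.samples = torch.cat((-2 * torch.ones(50), 2 * torch.ones(50)))

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)


first_dataset = MyFirstDataset()
second_dataset = MySecondDataset()
concat_dataset = ConcatDataset([first_dataset, second_dataset])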