@ResidentMario
Created June 10, 2020 01:43
# NEW additional imports
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler

# Imports assumed from the rest of the training script, used by this snippet below
import os
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
# NEW init_process method
def init_process(rank, size, backend='gloo'):
    """Initialize the distributed environment."""
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
def get_dataloader(rank, world_size):
    dataset = TrainingDataset()
    # NEW distributed sampler, which partitions the dataset across processes
    sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=False, sampler=sampler)
    return dataloader
def train(rank, num_epochs, world_size):
    # NEW training process init
    init_process(rank, world_size)
    print(f"Rank {rank + 1}/{world_size} training process initialized.\n")

    # NEW only rank 0 downloads the data and model weights; the other
    # processes wait at the barrier until the download has finished
    if rank == 0:
        get_dataloader(rank, world_size)
        get_model()
    dist.barrier()
    print(f"Rank {rank + 1}/{world_size} training process passed data download barrier.\n")

    model = get_model()
    model.cuda(rank)
    model.train()
    # NEW where the magic happens ✨
    model = DistributedDataParallel(model, device_ids=[rank])

    dataloader = get_dataloader(rank, world_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters())

    for epoch in range(1, num_epochs + 1):
        for i, (batch, segmap) in enumerate(dataloader):
            # NEW move the batch to this process's GPU
            batch = batch.cuda(rank)
            segmap = segmap.cuda(rank)

            # ...training code goes here...

        if rank == 0:
            torch.save(model.state_dict(), f'/spell/checkpoints/model_{epoch}.pth')
WORLD_SIZE = torch.cuda.device_count()

if __name__ == "__main__":
    mp.spawn(train, args=(NUM_EPOCHS, WORLD_SIZE), nprocs=WORLD_SIZE, join=True)
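
The gist assumes that TrainingDataset, get_model, and NUM_EPOCHS are defined elsewhere in the full training script. Below is a minimal, hypothetical sketch of stand-ins for those three names, using a toy segmentation-style dataset and a single-layer model, so the snippet above can be run end to end; the real definitions in the original project will differ.

# Hypothetical stand-ins for names the gist references but does not define.
# Illustrative only; the real script provides its own versions.
import torch
from torch import nn
from torch.utils.data import Dataset

NUM_EPOCHS = 20  # assumed value for illustration

class TrainingDataset(Dataset):
    """Toy dataset yielding (image, segmentation map) pairs."""
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        image = torch.randn(3, 64, 64)           # fake 3-channel image
        segmap = torch.randint(0, 2, (64, 64))   # fake binary segmentation map
        return image, segmap

def get_model():
    """Toy segmentation model: one conv layer mapping 3 input channels to 2 classes."""
    return nn.Conv2d(3, 2, kernel_size=3, padding=1)

With stubs like these in place, running the script directly spawns one training process per available GPU via mp.spawn, each pinned to its own device and fed a distinct shard of the dataset by DistributedSampler.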