# Imports carried over from the base (single-GPU) version of this script
import os

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

# NEW additional imports
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
# NEW init_process method
def init_process(rank, size, backend='gloo'):
    """Initialize the distributed environment."""
    # 'gloo' also works on GPU, but 'nccl' is the recommended backend for CUDA training
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
def get_dataloader(rank, world_size):
    dataset = TrainingDataset()  # defined in the base (single-GPU) script
    # NEW distributed data sampler: gives each rank a distinct shard of the dataset
    sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size)
    # shuffle must be False here; shuffling is delegated to the sampler instead
    dataloader = DataLoader(dataset, batch_size=8, shuffle=False, sampler=sampler)
    return dataloader
def train(rank, num_epochs, world_size):
    # NEW training process init
    init_process(rank, world_size)
    print(f"Rank {rank + 1}/{world_size} training process initialized.\n")

    # NEW: rank 0 downloads the data and model weights alone, then the barrier
    # releases the other ranks, which read from the now-warm local cache
    if rank == 0:
        get_dataloader(rank, world_size)
        get_model()
    dist.barrier()
    print(f"Rank {rank + 1}/{world_size} training process passed data download barrier.\n")

    model = get_model()  # defined in the base (single-GPU) script
    model.cuda(rank)
    model.train()
    # NEW where the magic happens ✨
    model = DistributedDataParallel(model, device_ids=[rank])

    dataloader = get_dataloader(rank, world_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters())

    for epoch in range(1, num_epochs + 1):
        # NEW: reseed the sampler so each epoch shuffles the shards differently
        dataloader.sampler.set_epoch(epoch)
        for i, (batch, segmap) in enumerate(dataloader):
            # NEW: move each batch onto this process's specific GPU
            batch = batch.cuda(rank)
            segmap = segmap.cuda(rank)
            # ...training code goes here...
        if rank == 0:
            # only rank 0 checkpoints; note the DDP wrapper prefixes keys with 'module.'
            torch.save(model.state_dict(), f'/spell/checkpoints/model_{epoch}.pth')
WORLD_SIZE = torch.cuda.device_count()

if __name__ == "__main__":
    # NUM_EPOCHS is defined in the base (single-GPU) script
    mp.spawn(train, args=(NUM_EPOCHS, WORLD_SIZE), nprocs=WORLD_SIZE, join=True)
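
# The snippet above assumes TrainingDataset, get_model, and NUM_EPOCHS are
# carried over from the base (single-GPU) script. To smoke-test this file
# standalone, minimal hypothetical stand-ins like the following, pasted above
# train(), would do; the real definitions are whatever your original script uses.
NUM_EPOCHS = 10  # hypothetical value, not from the original gist

class TrainingDataset(torch.utils.data.Dataset):
    """Hypothetical stand-in: random images paired with random segmentation maps."""
    def __len__(self):
        return 256

    def __getitem__(self, idx):
        image = torch.randn(3, 32, 32)            # 3-channel 32x32 input
        segmap = torch.randint(0, 10, (32, 32))   # per-pixel class labels in [0, 10)
        return image, segmap

def get_model():
    # Hypothetical stand-in: one conv layer emitting 10 per-pixel class scores,
    # shaped to pair with nn.CrossEntropyLoss on the segmaps above
    return nn.Conv2d(3, 10, kernel_size=3, padding=1)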