Setup PyTorch 1.0 stable distributed training

The distrib_train function does all the setup required for distributed training:

import torch

def distrib_train(gpu):
    # Nothing to do if no GPU was assigned to this process
    if gpu is None: return gpu
    gpu = int(gpu)
    # Bind this process to its assigned GPU before initializing the process group
    torch.cuda.set_device(gpu)
    # Join the process group; 'env://' reads RANK, WORLD_SIZE, MASTER_ADDR
    # and MASTER_PORT from environment variables
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    return gpu
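
For context, a minimal usage sketch (not part of the original gist): PyTorch 1.0's torch.distributed.launch utility sets the environment variables that init_method='env://' reads and passes a --local_rank argument to each spawned process, which can then be forwarded to distrib_train. The argument parsing below is an assumption about how the function might be wired up.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=None)  # set by the launcher
args = parser.parse_args()

gpu = distrib_train(args.local_rank)  # pins this process to its GPU and joins the group

Launched, for example, as: python -m torch.distributed.launch --nproc_per_node=4 train.py (one process per GPU on the node).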

torch.distributed provides an MPI-like interface for exchanging tensor data across multi-machine networks. It supports a few different backends and initialization methods.
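
As an illustration of the other options (host, port, and ranks below are assumed values, not from the gist), the same call can use the Gloo backend with an explicit TCP rendezvous instead of environment variables:

import torch

torch.distributed.init_process_group(
    backend='gloo',                      # CPU-capable backend
    init_method='tcp://10.0.0.1:23456',  # hypothetical master address and port
    rank=0,                              # this process's rank
    world_size=2,                        # total number of processes
)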

We use the NVIDIA Collective Communications Library (NCCL) as the backend.
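
Once the process group is initialized, collective operations exchange tensors directly between GPUs. A small sketch, assuming the setup above has already run in every participating process:

import torch
import torch.distributed as dist

# Each process contributes a tensor on its own GPU; after all_reduce every
# process holds the element-wise sum across all processes.
t = torch.ones(4, device='cuda')
dist.all_reduce(t, op=dist.ReduceOp.SUM)
# With 4 processes, t is now a tensor of 4s on every GPU.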
