Skip to content

Instantly share code, notes, and snippets.

@giacaglia
Created December 8, 2019 01:07
Show Gist options
  • Save giacaglia/a61f8fc02a3c1d2bb4393bbbc0c7be8d to your computer and use it in GitHub Desktop.
Save giacaglia/a61f8fc02a3c1d2bb4393bbbc0c7be8d to your computer and use it in GitHub Desktop.
def main(argv=None):
    """Parse cluster-layout options and launch one training process per GPU.

    Sets ``args.world_size`` to ``gpus * nodes``, exports the rendezvous
    address/port through the ``MASTER_ADDR`` / ``MASTER_PORT`` environment
    variables, then calls ``mp.spawn(train, nprocs=args.gpus, args=(args,))``
    so each spawned process runs ``train(process_index, args)``.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as the original did.

    Raises:
        SystemExit: via ``parser.error`` when the node/GPU layout is invalid.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N',
                        help='number of nodes taking part in training')
    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of gpus per node')
    parser.add_argument('-nr', '--nr', default=0, type=int,
                        help='rank of this node among all nodes (0-based)')
    parser.add_argument('--epochs', default=2, type=int, metavar='N',
                        help='number of total epochs to run')
    # Previously hard-coded; the defaults keep the original behaviour.
    parser.add_argument('--master-addr', default='10.57.23.164',
                        help='address of the rank-0 node used for rendezvous')
    parser.add_argument('--master-port', default=8888, type=int,
                        help='TCP port on the rank-0 node used for rendezvous')
    args = parser.parse_args(argv)

    # Reject nonsensical layouts here, instead of letting them surface later
    # inside the spawned worker processes.
    if args.nodes < 1:
        parser.error('--nodes must be at least 1')
    if args.gpus < 1:
        parser.error('--gpus must be at least 1')
    if not 0 <= args.nr < args.nodes:
        parser.error('--nr must be in the range [0, nodes)')

    # Total number of processes across all nodes (one per GPU).
    args.world_size = args.gpus * args.nodes
    os.environ['MASTER_ADDR'] = args.master_addr
    # Environment variables must be strings.
    os.environ['MASTER_PORT'] = str(args.master_port)
    # NOTE(review): `train` is defined elsewhere in this file; it presumably
    # derives its global rank as nr * gpus + process_index — confirm there.
    mp.spawn(train, nprocs=args.gpus, args=(args,))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment