@C1rN09 · Created March 17, 2023 15:24
import os
import subprocess
import sys

import torch

# Per-task SLURM environment: local rank within the node, total node count,
# this node's index, and the allocated node list.
local_rank = os.environ['SLURM_LOCALID']
nnodes = os.environ['SLURM_JOB_NUM_NODES']
node_id = os.environ['SLURM_NODEID']
node_list = os.environ['SLURM_JOB_NODELIST']

# Use the first host in the allocation as the rendezvous master address.
hostname = subprocess.getoutput(
    f'scontrol show hostname "{node_list}" | head -n1')
# Allow the master port to be overridden via the PORT environment variable.
port = os.environ.get('PORT') or '27182'


def run_torchrun():
    # Build the torchrun command, forwarding all remaining CLI arguments
    # (the training script and its options) unchanged.
    cmd = [
        'torchrun',
        '--nnodes', nnodes,
        '--nproc_per_node', str(torch.cuda.device_count()),
        '--node_rank', node_id,
        '--master_addr', hostname,
        '--master_port', port,
        *sys.argv[1:],
    ]
    cmd_text = ' '.join(cmd)
    # Only local rank 0 on each node launches torchrun; torchrun itself
    # spawns one worker process per visible GPU on that node.
    if local_rank == '0':
        curhost = subprocess.getoutput('hostname')
        print(curhost, cmd_text)
        subprocess.check_call(cmd)


if __name__ == '__main__':
    run_torchrun()
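A sketch of how this launcher might be invoked (the file names launch.py and train.py are placeholders, not part of the gist): save the script, then submit one task per node, for example

    srun -N 2 --ntasks-per-node=1 --gres=gpu:8 python launch.py train.py --cfg config.py

On each node, torchrun starts one worker per visible GPU, using the first host in the allocation as the rendezvous master and port 27182 unless PORT is set in the environment. The local-rank guard means extra tasks per node are harmless: only local rank 0 actually execs torchrun.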