@dlibenzi
Created May 20, 2020 14:13
import time

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# Shared executor; run() serializes the wrapped callable so only one
# spawned process executes it at a time.
SERIAL_EXEC = xmp.MpSerialExecutor()

def _mp_fn(_):
  def _serial_fn():
    print(f'rank {xm.get_ordinal()} start at {time.time()}')
    time.sleep(5)
    print(f'rank {xm.get_ordinal()} done at {time.time()}')

  SERIAL_EXEC.run(_serial_fn)

xmp.spawn(_mp_fn, nprocs=8, start_method='fork')
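
# Expected behavior (assuming MpSerialExecutor serializes callers as intended):
# the eight ranks print their start/done pairs one after another, roughly five
# seconds apart rather than overlapping, since run() admits one process into
# _serial_fn at a time.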