@ruslanmv
Created May 9, 2024 20:08
Parallel execution in a Jupyter notebook with PyTorch
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

# Distributed training setup
def init_distributed(rank, world_size):
    # Address and port of the rank-0 process that coordinates the group
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    if rank == 0:
        print("Initializing distributed process group...")
    # NCCL backend: one process per GPU
    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)

def cleanup_distributed():
    dist.destroy_process_group()

def main_worker(rank, world_size):
    init_distributed(rank, world_size)
    # Your model training and fine-tuning code goes here
    cleanup_distributed()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    # Workaround for Jupyter Notebook and other interactive environments:
    # launch one worker process per GPU manually instead of using torchrun
    processes = []
    for rank in range(world_size):
        p = mp.Process(target=main_worker, args=(rank, world_size))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
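
As a possible alternative to managing mp.Process objects by hand, the same worker can be launched with torch.multiprocessing.spawn, which starts one process per rank and passes the rank as the first argument. This is a minimal sketch, assuming main_worker(rank, world_size) is defined as in the gist above; note that in a notebook the worker function generally has to live in an importable module (for example a file written out from a cell) so the spawned processes can pickle it.

import torch
import torch.multiprocessing as mp

# Assumes main_worker(rank, world_size) from the gist above is importable here.

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    # spawn() supplies the rank (0..nprocs-1) as the first argument to main_worker
    # and waits for all workers to finish because join=True.
    mp.spawn(main_worker, args=(world_size,), nprocs=world_size, join=True)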