Parallel execution of PyTorch training across multiple GPUs from a Jupyter notebook
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

# Distributed training setup: each worker process joins the NCCL process group
def init_distributed(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    if rank == 0:
        print("Initializing distributed process group...")
    # Bind this process to its own GPU before creating the NCCL group
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)

def cleanup_distributed():
    dist.destroy_process_group()

def main_worker(rank, world_size):
    init_distributed(rank, world_size)
    # Your model training and fine-tuning code goes here
    cleanup_distributed()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    # Workaround for Jupyter Notebook and interactive environments:
    # start one process per GPU manually instead of using a CLI launcher
    processes = []
    for rank in range(world_size):
        p = mp.Process(target=main_worker, args=(rank, world_size))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()