Skip to content

Instantly share code, notes, and snippets.

@roman-4erkasov
Last active February 1, 2023 22:43
Show Gist options
  • Save roman-4erkasov/8b3edaaa34d5f3594b825d814fa3a11a to your computer and use it in GitHub Desktop.
Save roman-4erkasov/8b3edaaa34d5f3594b825d814fa3a11a to your computer and use it in GitHub Desktop.

Рассмотрим пример

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


class TwoLinLayerNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(10, 10, bias=False)
        self.b = torch.nn.Linear(10, 1, bias=False)

    def forward(self, x):
        a = self.a(x)
        b = self.b(x)
        return (a, b)


def worker(rank):
    dist.init_process_group("gloo", rank=rank, world_size=2)
    torch.cuda.set_device(rank)
    print("init model")
    model = TwoLinLayerNet().cuda()
    print("init ddp")
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    inp = torch.randn(10, 10).cuda()
    print("train")

    for _ in range(20):
        output = ddp_model(inp)
        loss = output[0] + output[1]
        loss.sum().backward()


if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    os.environ["TORCH_CPP_LOG_LEVEL"]="INFO"
    os.environ[
        "TORCH_DISTRIBUTED_DEBUG"
    ] = "DETAIL"  # set to DETAIL for runtime logging.
    mp.spawn(worker, nprocs=2, args=())

Главный процесс

Задается три переменных среды: MASTER_ADDR, MASTER_PORT, TORCH_CPP_LOG_LEVEL и TORCH_DISTRIBUTED_DEBUG

mp.spawn(worker, nprocs=2, args=()) - создает два процесса под управлением интерпретатора python. Вызывает в каждом функцию worker, где rank это индекс созданного процесса. И джойнит их так как аргумент join по умолчанию равен True.

Про spawn в pytorch https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn
Про spawn как один из трех способов мультипроцессинга https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods

Depending on the platform, multiprocessing supports three ways to start a process. These start methods are

  • spawn. The parent process starts a fresh Python interpreter process. The child process will only inherit those resources necessary to run the process object’s run() method. In particular, unnecessary file descriptors and handles from the parent process will not be inherited. Starting a process using this method is rather slow compared to using fork or forkserver. Available on Unix and Windows. The default on Windows and macOS.

  • fork. The parent process uses os.fork() to fork the Python interpreter. The child process, when it begins, is effectively identical to the parent process. All resources of the parent are inherited by the child process. Note that safely forking a multithreaded process is problematic. Available on Unix only. The default on Unix.

  • forkserver. When the program starts and selects the forkserver start method, a server process is started. From then on, whenever a new process is needed, the parent process connects to the server and requests that it fork a new process. The fork server process is single threaded so it is safe for it to use os.fork(). No unnecessary resources are inherited. Available on Unix platforms which support passing file descriptors over Unix pipes.


Подпроцессы: выполнение def worker(rank)

dist.init_process_group("gloo", rank=rank, world_size=2) - добавляет текущий процесс в группу процессов. Проводит инициализирующие действия по данной группе процессов. И блокирует дальнейшее выполнение в текущем процессе пока все процессы группы не буду связаны (This blocks until all processes have joined).

подробнее https://www.baeldung.com/linux/kill-members-process-group

Группа процессов это обект уровня ОС, существует в ОС POSIX-совместимых ОС.

Группа процессов это процессы с одним общим PGID (Process Group ID). Группы процессов позволяют контролировать распредедлие сигналов. Например, достаточно послать один сигнал KILL группе, вместо того чтобы послылать каждому процессу этой группы.

Функция init_process_group определяет как должны комунницировать между собой процессы. Для достижения этого необходимо задать init_process_group одним из двух способов:

  1. "ручной режим". Задать аргументы: store, rank, worlds_size. Где store это общее хранилище для коммуникации между процессами. Это хранилище создается вручную и передается в аргумент store.
  2. "автоматический режим". Задать аргумент init_method (опционально можно задать еще rank и worlds_size). Где init_method это способ автоматического создания общего хранилища для межпроцессной коммуникации.

В аднном примере в init_process_group не задны ни аргумент store, ни аргумент init_method. В таком случае init_method приравнивается значению "env://". Это значит что функция создаст хранилище на основе переменных среды. И как раз ранее были созданы переменные среды MASTER_ADDR и MASTER_PORT. Найдя такие переменные среды фунцкия выберет в качестве хранилища TcpStore, задаст этому хранилищу ip-адрес из MASTER_ADDR и порт из MASTER_PORT. После этого все процессы будут обращаться через tcp по адрессу "tcp://MASTER_ADDR:MASTER_PORT". Поскольку в нашем случае MASTER_ADDR="localhost" а MASTER_PORT="29501", то все процессы будут обращаться по адресу localhost:29501.

ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank]) создаем обертку вокруг модели, эта обертка синхронизирует буфферы и градиенты копий модели обученной на разных GPU и, как следствие, с разными данными.

Если хранилище для межпроцессной коммуникации будет недоступно, то это может привести к заморозке обучения:

https://cloudblogs.microsoft.com/opensource/2021/08/04/introducing-distributed-data-parallel-support-on-pytorch-windows/ If you’re using TcpStore, make sure the network is accessible and the port is in fact available. Otherwise, the training may freeze because the script fails to initialize the TcpStore. The process with rank zero will bind and listen on the port you provided. Other processes will try to connect to that port. You can use network monitoring tools like “netstat” to help debugging the TCP connection issue.

@roman-4erkasov
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment