@santurini
Created June 1, 2023 17:09
Tutorial to set up Data and Model Parallel training with FastMoE.

One of the main reasons Mixture-of-Experts models are gaining so much attention is their high degree of parallelization, which allows the number of parameters to be scaled up massively. Usually this requires a lot of complex code and deep knowledge of distributed systems, but we can get it almost for free with the FastMoE library.

First of all, we need to define our experts and specify, through the expert_dp_comm attribute, which type of gradient reduction we want to use (a minimal sketch of what this tagging amounts to follows the list):

  • dp: gradients are reduced across the data-parallel group only, which means they are not synchronized within the model-parallel group.
  • world: gradients are synchronized across all workers, regardless of their model- or data-parallel group. This is extremely useful for shared layers like the gate.
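Under the hood, both options boil down to a tag attached to each parameter: mark_parallel_comm sets a dp_comm attribute that DistributedGroupedDataParallel later reads to pick the process group in which that parameter's gradient is all-reduced. The sketch below only illustrates the idea (the module names are hypothetical and the real FastMoE implementation is the reference):

# Sketch only: conceptually, each parameter gets a dp_comm tag that decides
# in which process group its gradient is later all-reduced.
def tag_parameters(module, comm):
    for p in module.parameters():
        setattr(p, "dp_comm", comm)

# tag_parameters(expert_block, "dp")    # expert weights: data-parallel group only
# tag_parameters(gate_block, "world")   # shared layers like the gate: all workers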

Let's define our MoE layer, opting for synchronization across all workers:

import torch
import torch.nn as nn

from fmoe.layers import FMoE
from fmoe.linear import FMoELinear


class _Expert(nn.Module):
    def __init__(self, num_expert, d_inp, d_out, activation, rank=0):
        super().__init__()
        # FMoELinear fuses the weights of `num_expert` local experts in a single layer
        self.fc = FMoELinear(num_expert, d_inp, d_out, bias=True, rank=rank)
        self.activation = activation

    def forward(self, inp, fwd_expert_count):
        # fwd_expert_count tells the fused layer how many tokens go to each expert
        return self.activation(self.fc(inp, fwd_expert_count))


class LinearMoE(FMoE):
    def __init__(
        self,
        num_expert=32,
        d_inp=1024,
        d_out=128,
        activation=torch.nn.ReLU(),
        expert_dp_comm="world",
        expert_rank=0,
        **kwargs
    ):
        # FMoE calls this factory once per local expert, passing d_model as argument
        def one_expert(d_model):
            return _Expert(1, d_inp, d_out, activation, rank=expert_rank)

        super().__init__(num_expert=num_expert, d_model=d_inp, expert=one_expert, **kwargs)
        self.mark_parallel_comm(expert_dp_comm)

    def forward(self, inp: torch.Tensor):
        original_shape = inp.shape
        inp = inp.reshape(-1, self.d_model)
        output = super().forward(inp)
        # d_out may differ from d_inp, so only the leading dimensions are restored
        return output.reshape(*original_shape[:-1], -1)
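Before going distributed, the layer can be smoke-tested on a single process. This is just a sketch, assuming FastMoE's CUDA kernels are built and a GPU is available; world_size defaults to 1 and routing falls back to FastMoE's default NaiveGate:

# Single-process sanity check of the LinearMoE layer defined above
moe = LinearMoE(num_expert=4, d_inp=1024, d_out=1024, top_k=2).cuda()

x = torch.randn(8, 1024).cuda()
y = moe(x)
print(y.shape)  # torch.Size([8, 1024])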

Now let's plug it into the model, also specifying the world_size (number of workers) and remembering that num_expert refers to the number of LOCAL experts:

import os

import torch.nn.functional as F
from fmoe.gates.gshard_gate import GShardGate


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.moe = LinearMoE(
            num_expert=32,            # 32 experts PER worker
            d_inp=16 * 5 * 5,
            d_out=84,
            top_k=2,                  # each token is routed to 2 experts
            gate=GShardGate,
            world_size=2,             # total number of workers
            expert_rank=int(os.environ.get("OMPI_COMM_WORLD_RANK", 0)),
        )

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = self.moe(x)
        return x
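Keep in mind that num_expert is the LOCAL count, so the gate routes over the global pool. As a small illustration of the arithmetic (not code from the gist):

world_size = 2            # workers
num_local_experts = 32    # experts hosted by EACH worker
total_experts = world_size * num_local_experts  # 64 experts visible to the gate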

Now it's time to set up the distributed environment:

import os

import torch
import torch.distributed as dist

rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])

torch.cuda.set_device(local_rank)  # bind this process to its local GPU

# the default env:// init method reads MASTER_ADDR and MASTER_PORT, exported via mpirun below
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
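As a quick sanity check (not part of the original gist), each worker can report its rank before training starts:

# Every worker should print its own rank and the same world size
print(f"[rank {rank}/{world_size}] process group ready on cuda:{torch.cuda.current_device()}")
dist.barrier()  # wait until every worker has initialized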

Now wrap the model with FastMoE's DistributedGroupedDataParallel class and use a DistributedSampler to split the batches across workers:

from fmoe.distributed import DistributedGroupedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# setup distributed model
model = Net().cuda()
ddp_model = DistributedGroupedDataParallel(model)

# setup distributed dataloader: each worker receives a disjoint shard of the dataset
train_sampler = DistributedSampler(dataset=train_ds)
train_dl = DataLoader(dataset=train_ds,
                      batch_size=batch_size,
                      sampler=train_sampler,
                      num_workers=num_workers)

The last mandatory step is to call allreduce_gradients() after loss.backward():

for batch in train_dl:
    # ... forward pass and loss computation ...
    optimizer.zero_grad()
    loss.backward()
    ddp_model.allreduce_gradients()  # synchronize gradients across the appropriate process groups
    optimizer.step()
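For completeness, one epoch might look like the sketch below; the CrossEntropyLoss criterion, the num_epochs variable and the .cuda() transfers are assumptions about the rest of the training script, which the gist does not show:

criterion = nn.CrossEntropyLoss()  # assumed loss function, not specified in the gist

for epoch in range(num_epochs):
    train_sampler.set_epoch(epoch)  # reshuffle so each worker gets a fresh ordering
    for images, labels in train_dl:
        images, labels = images.cuda(), labels.cuda()
        optimizer.zero_grad()
        loss = criterion(ddp_model(images), labels)
        loss.backward()
        ddp_model.allreduce_gradients()
        optimizer.step()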

To launch the training with Open MPI you can use the following command; in this case we are using two nodes with one GPU per node:

mpirun -np 2 \
       -H xxx.xxx.xxx.xxx:1,xxx.xxx.xxx.xxy:1 \
       -x MASTER_ADDR=xxx.xxx.xxx.xxx \
       -x MASTER_PORT=1234 \
       -x PATH=$PATH:/path/to/venv/my_env/bin \
       -bind-to none -map-by slot \
       python train.py --additional_script_args