One of the main reasons Mixture of Experts (MoE) models are gaining so much attention is their high degree of parallelization, which allows the number of parameters to scale exponentially. Usually this requires a lot of complex code and deep knowledge of distributed systems, but we can get it for free with the FastMoE library.
First of all, we need to define our experts and specify, through the expert_dp_comm attribute, which type of gradient reduction we would like to use, choosing from:
- dp: gradients are reduced across the data-parallel group, which means that they are not synchronized within the model-parallel group.
- world: gradients are synchronized across all workers, regardless of their model-parallel or data-parallel group. This is extremely useful for shared layers like the gate.
Let's define our MoE layer by opting for the synchronization across all workers:
from fmoe.layers import FMoE
from fmoe.linear import FMoELinear
class _Expert(nn.Module):
    """One expert: a FastMoE linear layer followed by an activation.

    Args:
        num_expert: Forwarded to ``FMoELinear`` as its first argument.
        d_inp: Input feature size handed to ``FMoELinear``.
        d_out: Output feature size handed to ``FMoELinear``.
        activation: Callable applied to the linear layer's output.
        rank: Forwarded to ``FMoELinear`` as its ``rank`` keyword.
    """

    def __init__(self, num_expert, d_inp, d_out, activation, rank=0):
        super().__init__()
        # The linear layer is always built with a bias term.
        linear = FMoELinear(num_expert, d_inp, d_out, bias=True, rank=rank)
        self.fc = linear
        self.activation = activation

    def forward(self, inp, fwd_expert_count):
        """Run ``inp`` through the linear layer, then through the activation.

        Args:
            inp: Input handed to the linear layer.
            fwd_expert_count: Passed through unchanged to the linear layer.

        Returns:
            The activated output of the linear layer.
        """
        projected = self.fc(inp, fwd_expert_count)
        return self.activation(projected)
class LinearMoE(FMoE):
    """Mixture-of-Experts layer whose experts are single ``_Expert`` blocks.

    Args:
        num_expert: Number of experts, forwarded to ``FMoE``.
        d_inp: Size of the last dimension of the input; also used as the
            ``d_model`` of the underlying ``FMoE`` layer.
        d_out: Output feature size of every expert.
        activation: Callable applied inside every expert.
        expert_dp_comm: Communication tag handed to ``mark_parallel_comm``.
        expert_rank: ``rank`` value forwarded to every expert.
        **kwargs: Any further keyword arguments are forwarded verbatim to
            ``FMoE.__init__``.
    """

    def __init__(
        self,
        num_expert=32,
        d_inp=1024,
        d_out=128,
        activation=torch.nn.ReLU(),
        expert_dp_comm="world",
        expert_rank=0,
        **kwargs
    ):
        def one_expert(d_model):
            # The argument is ignored: the sizes are captured from the
            # enclosing scope instead.
            # NOTE(review): presumably FMoE calls this factory with its
            # d_model once per expert - confirm against the FastMoE version
            # in use.
            return _Expert(1, d_inp, d_out, activation, rank=expert_rank)

        # The model dimension of the MoE layer is the expert input size.
        # (Previously this referenced an undefined name `d_model`, which
        # raised NameError on construction.)
        super().__init__(
            num_expert=num_expert, d_model=d_inp, expert=one_expert, **kwargs
        )
        self.mark_parallel_comm(expert_dp_comm)

    def forward(self, inp: torch.Tensor):
        """Apply the MoE layer over the last dimension of ``inp``.

        The input is flattened to 2-D ``(-1, d_model)`` before being handed to
        ``FMoE.forward``; the leading dimensions are restored afterwards.

        Args:
            inp: Tensor whose last dimension has size ``d_model``.

        Returns:
            The MoE output with the leading dimensions of ``inp`` and the
            feature size produced by the parent ``forward``.
        """
        leading_dims = inp.shape[:-1]
        flat = inp.reshape(-1, self.d_model)
        output = super().forward(flat)
        # Only the leading dimensions are restored: the feature dimension may
        # differ from the input's when d_out != d_inp, in which case reshaping
        # back to the full input shape would fail.
        return output.reshape(*leading_dims, output.shape[-1])
Now let's plug it into the model, also specifying the world_size (number of workers) and remembering that num_expert refers to the number of LOCAL experts (i.e. experts per worker):
from fmoe.gates.gshard_gate import GShardGate
class Net(nn.Module):
    """Small convolutional network whose head is a ``LinearMoE`` layer.

    Two conv + max-pool stages feed a flattened ``16 * 5 * 5`` feature vector
    into the MoE layer.
    NOTE(review): the 16 * 5 * 5 flatten size presumably corresponds to
    32x32 RGB inputs - confirm against the dataset in use.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.moe = LinearMoE(
            num_expert=32,
            d_inp=16 * 5 * 5,
            # NOTE(review): d_hidden and n_classes are not parameters of
            # LinearMoE, so they are forwarded through **kwargs to
            # FMoE.__init__ - confirm the installed FastMoE accepts them,
            # otherwise drop them.
            d_hidden=120,
            d_out=84,
            n_classes=10,
            top_k=2,
            gate=GShardGate,
            world_size=2,
            # os.environ.get returns a str when the variable is set and the
            # int default otherwise; normalise to int in both cases.
            expert_rank=int(os.environ.get("OMPI_COMM_WORLD_RANK", 0)),
        )

    def forward(self, x):
        """Run the two conv/pool stages, flatten, and apply the MoE layer.

        Args:
            x: Batch of 3-channel images.

        Returns:
            The output of the MoE layer for the flattened features.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = self.moe(x)
        return x
Now it's time to set up the distributed environment:
import torch
import torch.distributed as dist

# Rank information is read from the environment variables that OpenMPI's
# mpirun exports for every process it launches.
rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])

# Bind this process to its own GPU before any CUDA/NCCL work. Without this,
# every process on a multi-GPU node would use the default device (GPU 0).
torch.cuda.set_device(local_rank)

# No init_method is given, so the rendezvous address comes from the
# environment (MASTER_ADDR / MASTER_PORT must be set by the launcher).
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
Now wrap the model with the DistributedGroupedDataParallel class and use a DistributedSampler to split the batches across multiple devices:
# Build the network, move it to the GPU, and wrap it so that gradients can be
# reduced per communication group.
model = Net()
model = model.cuda()
ddp_model = DistributedGroupedDataParallel(model)

# The sampler decides which samples of train_ds this process sees; the
# dataloader simply draws its batches from it.
train_sampler = DistributedSampler(train_ds)
train_dl = DataLoader(
    train_ds,
    batch_size=batch_size,
    sampler=train_sampler,
    num_workers=num_workers,
)
The last mandatory step is to call allreduce_gradients() after loss.backward():
# Training loop sketch: one optimisation step per batch from the distributed
# dataloader. The forward pass and loss computation are elided (". . .").
for batch in train_dl:
. . .
# Clear the gradients left over from the previous step.
optimizer.zero_grad()
loss.backward()
# Must run after backward() and before step(): reduces the gradients held by
# the DistributedGroupedDataParallel wrapper.
ddp_model.allreduce_gradients()
optimizer.step()
To launch the training with OpenMPI you can use the following command; in this case we are using two nodes with one GPU per node:
# Launch 2 processes, one per host (":1" = one slot on each host).
# MASTER_ADDR / MASTER_PORT and PATH are exported to every process with -x.
# Each line ends with a backslash so the shell treats this as ONE command.
mpirun -np 2 \
    -H xxx.xxx.xxx.xxx:1,xxx.xxx.xxx.xxy:1 \
    -x MASTER_ADDR=xxx.xxx.xxx.xxx \
    -x MASTER_PORT=1234 \
    -x PATH=$PATH:/path/to/venv/my_env/bin \
    -bind-to none -map-by slot \
    python train.py --additional_script_args