Simple Tutorial to get started with the FastMoE library.

In this tutorial we take a simple model and replace its MLP head with a Mixture of Experts (MoE). The starting model is defined as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120) # from fc1 to fc3 is our MLP to be replaced
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
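
As a quick reference, the baseline assumes 32x32 RGB inputs (as in the classic CIFAR-10 example), which is where the flattened feature size 16 * 5 * 5 = 400 comes from. A minimal shape check:

import torch

net = Net()
x = torch.randn(4, 3, 32, 32)   # a batch of 4 CIFAR-10-sized images
print(net(x).shape)             # torch.Size([4, 10])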

To MoE-fy this model, we first define an expert module that mirrors the original MLP, replacing each nn.Linear with an FMoELinear so that the whole MLP can be instantiated once per expert:

from fmoe.linear import FMoELinear

class _Expert(nn.Module):
    def __init__(self, num_expert, d_inp, d_hidden, d_out, n_classes, activation, rank=0):
        super().__init__()
        # each FMoELinear stores one weight matrix per expert
        self.fc1 = FMoELinear(num_expert, d_inp, d_hidden, bias=True, rank=rank)
        self.fc2 = FMoELinear(num_expert, d_hidden, d_out, bias=True, rank=rank)
        self.fc3 = FMoELinear(num_expert, d_out, n_classes, bias=True, rank=rank)
        self.activation = activation

    def forward(self, inp, fwd_expert_count):
        # fwd_expert_count holds the number of tokens assigned to each expert
        x = self.activation(self.fc1(inp, fwd_expert_count))
        x = self.activation(self.fc2(x, fwd_expert_count))
        x = self.fc3(x, fwd_expert_count)
        return x
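
As a sanity check, the expert can be exercised on its own by feeding it a batch of flattened tokens together with the per-expert token counts. The snippet below is only an illustrative sketch: it assumes FastMoE was built with its CUDA kernels, uses a single expert so that all tokens are routed to it, and assumes fwd_expert_count is a CPU LongTensor as used internally by the library.

import torch

# Sketch only: requires a CUDA build of FastMoE; shapes and devices are assumptions.
expert = _Expert(num_expert=1, d_inp=16 * 5 * 5, d_hidden=120, d_out=84,
                 n_classes=10, activation=torch.nn.ReLU()).cuda()

tokens = torch.randn(6, 16 * 5 * 5, device="cuda")  # 6 flattened inputs
counts = torch.tensor([6], dtype=torch.long)        # all 6 tokens go to the single expert
print(expert(tokens, counts).shape)                 # expected: torch.Size([6, 10])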

Now we can define our custom MoE layer by subclassing the FMoE class:

from fmoe.layers import FMoE

class LinearMoE(FMoE):
    def __init__(
        self,
        num_expert=32,
        d_inp=16 * 5 * 5, 
        d_hidden=120, 
        d_out=84, 
        n_classes=10,
        activation=torch.nn.ReLU(),
        expert_dp_comm="none",
        expert_rank=0,
        **kwargs
    ):
        # FMoE calls this factory once per expert, so each expert is a full copy of the MLP
        def one_expert(d_model):
            return _Expert(1, d_inp, d_hidden, d_out, n_classes, activation, rank=0)

        expert = one_expert
        super().__init__(num_expert=num_expert, d_model=d_inp, expert=expert, **kwargs)
        self.mark_parallel_comm(expert_dp_comm)

    def forward(self, inp: torch.Tensor):
        # flatten everything but the feature dimension, run the MoE, then restore the
        # leading dimensions (the feature size changes from d_inp to n_classes)
        original_shape = inp.shape
        inp = inp.reshape(-1, self.d_model)
        output = super().forward(inp)

        return output.reshape(original_shape[:-1] + (output.shape[-1],))
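
The layer behaves like any other nn.Module. The following smoke test is a sketch under two assumptions: FastMoE's CUDA kernels are available, and FMoE falls back to its default NaiveGate when no gate is passed.

import torch

moe = LinearMoE(num_expert=4, top_k=2).cuda()   # top_k is forwarded to FMoE via **kwargs
x = torch.randn(8, 16 * 5 * 5, device="cuda")   # 8 flattened feature vectors
print(moe(x).shape)                             # expected: torch.Size([8, 10])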

Finally, we can modify the original model by swapping the linear layers for the MoE. Keep in mind that the FMoE class exposes some useful additional parameters, for instance:

  • top_k: the number of experts each token is routed to (typically 1 or 2)
  • gate: the gate class implementing the routing policy; the available gates live in the fmoe.gates module

The modified model then looks like this:
from fmoe.gates.gshard_gate import GShardGate

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.moe = LinearMoE(
              num_expert=32,
              d_inp=16 * 5 * 5, 
              d_hidden=120, 
              d_out=84, 
              n_classes=10,
              top_k=2,
              gate=GShardGate
              )

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = self.moe(x)
        return x
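
The MoE-fied model keeps the same input/output contract as the baseline. A quick check (again assuming a CUDA build of FastMoE, so model and data live on the GPU):

import torch

net = Net().cuda()                              # the MoE-fied version defined above
x = torch.randn(4, 3, 32, 32, device="cuda")
print(net(x).shape)                             # expected: torch.Size([4, 10])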

That's it: these are the only modifications needed to get started with FastMoE. No other changes are required, either in the training loop (as long as we are on a single GPU) or in the rest of the architecture.
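
To make the last point concrete, here is a minimal single-GPU training step; nothing in it is MoE-specific. train_loader is assumed to be a standard DataLoader yielding (images, labels) batches.

import torch.nn as nn
import torch.optim as optim

model = Net().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

for images, labels in train_loader:   # train_loader: assumed, e.g. a CIFAR-10 DataLoader
    images, labels = images.cuda(), labels.cuda()
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()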
