

@ruvnet
Last active June 4, 2024 17:23
A PyTorch implementation of a Mixture of Experts (MoE) model resembling the Mixtral 8x7B architecture, with detailed inline comments. This model combines transformer layers with an MoE layer consisting of 8 experts, aiming for high efficiency by activating only 2 experts per token. It's configured with dimensions reflecting the operational effic…
"""
This model integrates the MoE concept within a Transformer architecture. Each token's
representation is processed by a subset of experts, determined by the gating mechanism.
This architecture allows for efficient and specialized handling of different aspects of the
data, aiming for the adaptability and efficiency noted in the Mixtral 8x7B model's design
philosophy. The model activates only a fraction of the available experts for each token,
significantly reducing the computational resources needed compared to activating all experts
for all tokens.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# Define the Expert class: a two-layer feed-forward network.
# Note: Mixtral's experts use a SwiGLU feed-forward block; ReLU is a simplification here.
class Expert(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(Expert, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)
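# Define the Gating Network class: maps each token to a probability distribution over experts
class GatingNetwork(nn.Module):
    def __init__(self, input_dim, num_experts):
        super(GatingNetwork, self).__init__()
        self.gate = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        # x: (batch, seq_len, input_dim) -> scores: (batch, seq_len, num_experts)
        return F.softmax(self.gate(x), dim=-1)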
# Define the Mixture of Experts Layer class
class MoELayer(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_experts):
        super(MoELayer, self).__init__()
        self.experts = nn.ModuleList([Expert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)])
        self.gate = GatingNetwork(input_dim, num_experts)

    def forward(self, x, num_experts_per_tok):
        # x: (batch, seq_len, input_dim)
        gating_scores = self.gate(x)  # (batch, seq_len, num_experts)
        _, topk_indices = gating_scores.topk(num_experts_per_tok, dim=-1, sorted=False)
        # Create a mask to zero out the contributions of non-top-k experts
        mask = torch.zeros_like(gating_scores).scatter_(-1, topk_indices, 1.0)
        # Use the mask to retain only the top-k gating scores
        gating_scores = gating_scores * mask
        # Renormalize so the selected experts' weights sum to 1 for each token
        gating_scores = F.normalize(gating_scores, p=1, dim=-1)
        # Note: for readability, every expert runs on every token here and the
        # non-selected experts are weighted by 0, so this dense version does not
        # save compute; a sparse version would run each expert only on its routed tokens.
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=2)  # (batch, seq_len, num_experts, output_dim)
        output = torch.einsum('bte,bteo->bto', gating_scores, expert_outputs)
        return output
# Define the overall Transformer model with integrated MoE
class TransformerWithMoE(nn.Module):
    def __init__(self, num_layers, dim, head_dim, hidden_dim, n_heads, num_experts, vocab_size, num_experts_per_tok):
        super(TransformerWithMoE, self).__init__()
        self.num_experts_per_tok = num_experts_per_tok
        self.embedding = nn.Embedding(vocab_size, dim)
        # batch_first=True so the layers accept (batch, seq_len, dim), matching the
        # embedding output; the default expects (seq_len, batch, dim).
        # head_dim is not used: nn.TransformerEncoderLayer derives it as dim // n_heads.
        self.layers = nn.ModuleList([nn.TransformerEncoderLayer(d_model=dim, nhead=n_heads, batch_first=True) for _ in range(num_layers)])
        self.moe_layer = MoELayer(dim, hidden_dim, dim, num_experts)
        self.output_layer = nn.Linear(dim, vocab_size)

    def forward(self, x):
        # x: (batch, seq_len) token ids
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x)
        x = self.moe_layer(x, self.num_experts_per_tok)
        logits = self.output_layer(x)  # (batch, seq_len, vocab_size)
        return logits
# Initialize the model with hyperparameters taken from Mixtral 8x7B
# (the architecture itself is simplified, so the parameter count differs)
model = TransformerWithMoE(
    num_layers=32,  # Number of transformer layers
    dim=4096,  # Model (embedding) dimension
    head_dim=128,  # Dimension of each attention head (unused here; equals dim // n_heads)
    hidden_dim=14336,  # Hidden dimension of each expert's feed-forward network
    n_heads=32,  # Number of attention heads
    num_experts=8,  # Number of experts in the MoE layer
    vocab_size=32000,  # Vocabulary size for the embedding and output layers
    num_experts_per_tok=2  # Number of experts activated per token
)
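The docstring says only a fraction of the experts is activated per token, but `MoELayer.forward` above runs all 8 experts on every token and zeroes out the unused ones. A minimal sketch of sparse dispatch, assuming the same ReLU experts and top-k routing, follows; `SparseMoELayer` is a hypothetical name, not part of the original gist, and the router here has no bias (an assumption). Each expert is called only on the tokens routed to it, and an expert that receives no tokens is skipped.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SparseMoELayer(nn.Module):
    """Hypothetical top-k MoE layer that runs each expert only on its routed tokens."""

    def __init__(self, dim, hidden_dim, num_experts, top_k):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(dim, num_experts, bias=False)  # router (bias-free by assumption)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, dim))
            for _ in range(num_experts)
        ])

    def forward(self, x):
        # x: (batch, seq_len, dim) -> flatten to (num_tokens, dim)
        batch, seq_len, dim = x.shape
        tokens = x.reshape(-1, dim)
        logits = self.gate(tokens)  # (num_tokens, num_experts)
        # Pick the top-k experts per token, then softmax over just those k logits
        topk_logits, topk_idx = logits.topk(self.top_k, dim=-1)
        topk_weights = F.softmax(topk_logits, dim=-1)  # (num_tokens, top_k)
        out = torch.zeros_like(tokens)
        for e, expert in enumerate(self.experts):
            # Token rows and top-k slots where expert e was selected
            token_ids, slot = (topk_idx == e).nonzero(as_tuple=True)
            if token_ids.numel() == 0:
                continue  # no tokens routed here: skip this expert's compute
            weighted = topk_weights[token_ids, slot].unsqueeze(-1) * expert(tokens[token_ids])
            out.index_add_(0, token_ids, weighted)  # accumulate into each token's output
        return out.reshape(batch, seq_len, dim)


torch.manual_seed(0)
layer = SparseMoELayer(dim=16, hidden_dim=32, num_experts=8, top_k=2)
y = layer(torch.randn(2, 5, 16))
print(y.shape)  # torch.Size([2, 5, 16])
```

The per-expert Python loop is the simplest form of dispatch; optimized implementations often batch tokens per expert more aggressively, but the routing arithmetic is the same and the result matches the dense, masked computation in `MoELayer`.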
@prakhyat123

Good work on the clean code. I do have a few queries/recommendations below:

https://mistral.ai/news/mixtral-of-experts/
From the Mixtral Release Blog: "Mixtral is a sparse mixture-of-experts network. It is a decoder-only model where the feedforward block picks from a set of 8 distinct groups of parameters. At every layer, for every token, a router network chooses two of these groups (the “experts”) to process the token and combine their output additively."
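The router described in the quote ("chooses two of these groups ... and combine their output additively") can be checked against the gist's gating code. A small sketch, assuming random router logits for 4 tokens and 8 experts: taking the softmax over all experts, keeping the top 2 and renormalizing (what `MoELayer.forward` does with `mask` and `F.normalize`) yields the same weights as first selecting the top-2 logits and applying softmax to only those two. Both function names are illustrative.

```python
import torch
import torch.nn.functional as F


def gist_style_weights(logits, k):
    # Softmax over all experts, zero out the non-top-k ones, then L1-renormalize
    probs = F.softmax(logits, dim=-1)
    _, idx = probs.topk(k, dim=-1)
    mask = torch.zeros_like(probs).scatter_(-1, idx, 1.0)
    return F.normalize(probs * mask, p=1, dim=-1)


def topk_then_softmax_weights(logits, k):
    # Select the top-k logits first, softmax over just those, scatter back to full width
    topk_logits, idx = logits.topk(k, dim=-1)
    weights = F.softmax(topk_logits, dim=-1)
    return torch.zeros_like(logits).scatter_(-1, idx, weights)


torch.manual_seed(0)
router_logits = torch.randn(4, 8)  # 4 tokens, 8 experts
a = gist_style_weights(router_logits, k=2)
b = topk_then_softmax_weights(router_logits, k=2)
print(torch.allclose(a, b, atol=1e-6))  # True
print((a > 0).sum(dim=-1))  # tensor([2, 2, 2, 2])
```

The two agree because softmax is monotonic (so the top-k set is the same) and renormalizing exp-weights over a subset is exactly a softmax over that subset. The second form avoids materializing the full softmax.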

  1. We need to use a decoder-only Transformer architecture instead of an encoder.

  2. The sparse Mixture of Experts is applied at "every layer" of the Transformer. In the above code, self.moe_layer() is applied only once, after all the Transformer layers have run.
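Both recommendations can be combined in a minimal sketch: a decoder-only stack with a causal attention mask, where an MoE layer replaces the dense feed-forward sublayer inside every block instead of running once at the end. Assumptions, not taken from the gist: pre-norm residual blocks with `nn.LayerNorm` (Mixtral uses RMSNorm), plain `nn.MultiheadAttention` in place of Mixtral's grouped-query attention with rotary embeddings, a compact dense top-k MoE, and toy dimensions. `TopKMoE`, `DecoderBlockWithMoE` and `TinyMoEDecoder` are hypothetical names.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TopKMoE(nn.Module):
    """Compact top-k MoE feed-forward (evaluates all experts densely, for brevity)."""

    def __init__(self, dim, hidden_dim, num_experts, top_k):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, dim))
            for _ in range(num_experts)
        ])

    def forward(self, x):
        logits = self.gate(x)  # (batch, seq_len, num_experts)
        topk_logits, idx = logits.topk(self.top_k, dim=-1)
        # Softmax over the top-k logits, zeros for all other experts
        weights = torch.zeros_like(logits).scatter(-1, idx, F.softmax(topk_logits, dim=-1))
        expert_out = torch.stack([expert(x) for expert in self.experts], dim=-2)  # (B, T, E, D)
        return (weights.unsqueeze(-1) * expert_out).sum(dim=-2)


class DecoderBlockWithMoE(nn.Module):
    def __init__(self, dim, n_heads, hidden_dim, num_experts, top_k):
        super().__init__()
        self.attn_norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.ffn_norm = nn.LayerNorm(dim)
        self.moe = TopKMoE(dim, hidden_dim, num_experts, top_k)

    def forward(self, x):
        seq_len = x.size(1)
        # Causal mask: True marks future positions a token may NOT attend to
        causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1)
        h = self.attn_norm(x)
        attn_out, _ = self.attn(h, h, h, attn_mask=causal, need_weights=False)
        x = x + attn_out
        # The MoE replaces the dense feed-forward sublayer in every block
        return x + self.moe(self.ffn_norm(x))


class TinyMoEDecoder(nn.Module):
    def __init__(self, vocab_size, dim, n_layers, n_heads, hidden_dim, num_experts, top_k):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.blocks = nn.ModuleList([
            DecoderBlockWithMoE(dim, n_heads, hidden_dim, num_experts, top_k)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(dim)
        self.output_layer = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, token_ids):
        x = self.embedding(token_ids)  # (batch, seq_len, dim)
        for block in self.blocks:
            x = block(x)
        return self.output_layer(self.norm(x))  # (batch, seq_len, vocab_size)


torch.manual_seed(0)
decoder = TinyMoEDecoder(vocab_size=100, dim=32, n_layers=2, n_heads=4,
                         hidden_dim=64, num_experts=8, top_k=2)
tokens = torch.randint(0, 100, (2, 10))
logits = decoder(tokens)
print(logits.shape)  # torch.Size([2, 10, 100])
```

Because of the causal mask, changing a later token leaves the logits at earlier positions unchanged, which is the property a decoder-only language model needs for next-token prediction.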
