"""
This is a test network moe with gradient checkpointing.
There are three blocks, all wrapped with gradient checkpointing.
The first block only contains a router and the "original layer".
The other two blocks contain the "original layer", and two expert layers.
We have two setups, one where the router scores are returned by the block's forward method and passed down to the other blocks.
The second setup is where the router scores are set directly on the last two blocks as attributes. They are then to be used in the forward method of the expert aggregator of the last two blocks.
"""
import pytest
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
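

# A minimal sketch (hypothetical helper, not used by the tests below) of the
# reentrant checkpoint call pattern that both setups rely on. With
# use_reentrant=True the wrapped forward runs under no_grad and is re-run
# during backward, so gradients only flow through tensors passed to
# checkpoint() as arguments (at least one of which must require grad).
def _checkpoint_block_example(block: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Equivalent to block(x), but intermediate activations are discarded and
    # recomputed when gradients are needed.
    return checkpoint(block.__call__, x, use_reentrant=True)
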
class BlockWithRouter(torch.nn.Module):
    def __init__(self, refs=None):
        super().__init__()
        # Blocks that should receive the router scores as an attribute (Setup2).
        self.refs = refs
        self.router = nn.Linear(100, 2)
        self.original = nn.Linear(100, 100)
        self.relu = nn.ReLU()

    def forward(self, x):
        scores = self.router(x)
        # Set the scores as an attribute on the referenced blocks (Setup2).
        if self.refs is not None:
            for ref in self.refs:
                ref.scores = scores
        return self.relu(self.original(x)), scores

class BlockWithExperts(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Filled in by BlockWithRouter when this block is used as a ref (Setup2).
        self.scores = None
        self.original = nn.Linear(100, 100)
        self.relu = nn.ReLU()
        self.expert1 = nn.Linear(100, 100)
        self.expert2 = nn.Linear(100, 100)

    def forward(self, x, scores=None):
        # Prefer scores passed as an argument (Setup1); fall back to the attribute (Setup2).
        scores = scores if scores is not None else self.scores
        # Use the scores to aggregate the expert outputs:
        # scores is (batch, experts), expert_outputs is (experts, batch, features),
        # so the einsum yields a score-weighted sum of shape (batch, features).
        expert1_output = self.expert1(x)
        expert2_output = self.expert2(x)
        expert_outputs = torch.stack([expert1_output, expert2_output])
        result = torch.einsum('be,ebt->bt', scores, expert_outputs)
        return self.relu(self.original(x) + result)

class WithoutGradChkp(torch.nn.Module):
    """
    Baseline without gradient checkpointing; the router scores are passed
    down explicitly, as in Setup1.
    """

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        self.block3 = BlockWithExperts()
        self.block2 = BlockWithExperts()
        self.block1 = BlockWithRouter()

    def forward(self, x):
        x, scores = self.block1(x)
        x = self.block2(x, scores)
        x = self.block3(x, scores)
        return x.softmax(dim=-1)

class Setup1(torch.nn.Module):
    """
    Router scores are returned by the block's forward method and passed down
    to the other blocks as arguments.
    """

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        self.block3 = BlockWithExperts()
        self.block2 = BlockWithExperts()
        self.block1 = BlockWithRouter()

    def forward(self, x):
        x, scores = checkpoint(self.block1.__call__, x, use_reentrant=True)
        x = checkpoint(self.block2.__call__, x, scores, use_reentrant=True)
        x = checkpoint(self.block3.__call__, x, scores, use_reentrant=True)
        return x.softmax(dim=-1)

class Setup2(torch.nn.Module):
    """
    Router scores are set directly on the last two blocks as attributes and
    then used in the expert aggregation of those blocks' forward methods.
    """

    def __init__(self):
        super().__init__()
        self.block3 = BlockWithExperts()
        self.block2 = BlockWithExperts()
        self.block1 = BlockWithRouter(refs=[self.block2, self.block3])

    def forward(self, x):
        x, _ = checkpoint(self.block1.__call__, x, use_reentrant=True)
        x = checkpoint(self.block2.__call__, x, use_reentrant=True)
        x = checkpoint(self.block3.__call__, x, use_reentrant=True)
        return x.softmax(dim=-1)

def randomtrain(model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()
    # Train on random data
    model.train()
    for _ in range(100):
        # requires_grad=True on the input so that the reentrant checkpoints
        # have at least one differentiable input.
        x = torch.randn(10, 100, requires_grad=True)
        y = torch.randint(0, 2, (10,))
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

def test_without_gradchkp():
    """
    Test whether the router weights are updated in the baseline without
    gradient checkpointing.
    """
    model = WithoutGradChkp()
    # Save the router weights to compare them after training
    router_weights = model.block1.router.weight.clone().detach()
    randomtrain(model)
    # Check that the router weights have been updated
    assert not torch.equal(router_weights, model.block1.router.weight), "Router weights have not been updated."

def test_setup1():
    """
    Test whether the router weights are updated when using gradient
    checkpointing with the scores passed as arguments.
    """
    model = Setup1()
    # Save the router weights to compare them after training
    router_weights = model.block1.router.weight.clone().detach()
    randomtrain(model)
    # Check that the router weights have been updated
    assert not torch.equal(router_weights, model.block1.router.weight), "Router weights have not been updated."

def test_setup2():
    """
    Test whether the router weights are updated when using gradient
    checkpointing with the scores set as attributes.
    """
    # Set up training
    model = Setup2()
    # Save the router weights to compare them after training
    router_weights = model.block1.router.weight.clone().detach()
    randomtrain(model)
    # Check that the router weights have been updated
    assert not torch.equal(router_weights, model.block1.router.weight), "Router weights have not been updated."
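

# Minimal runner sketch: execute the three scenarios directly with
# `python <this file>` instead of pytest; a failing check raises AssertionError.
if __name__ == "__main__":
    test_without_gradchkp()
    test_setup1()
    test_setup2()
    print("All checks passed.")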