""" | |
This is a test network moe with gradient checkpointing. | |
There are three blocks, all wrapped with gradient checkpointing. | |
The first block only contains a router and the "original layer". | |
The other two blocks contain the "original layer", and two expert layers. | |
We have two setups, one where the router scores are returned by the block's forward method and passed down to the other blocks. | |
The second setup is where the router scores are set directly on the last two blocks as attributes. They are then to be used in the forward method of the expert aggregator of the last two blocks. | |
""" | |
import pytest
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
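
# Data flow in both setups:
#   x -> block1 (router + original layer) -> block2 (experts) -> block3 (experts) -> softmax
# Setup1 passes the router scores along with the activations as an explicit argument;
# Setup2 has block1 write them onto block2/block3 as the `scores` attribute instead.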


class BlockWithRouter(torch.nn.Module):
    def __init__(self, refs=None):
        super().__init__()
        self.refs = refs
        self.router = nn.Linear(100, 2)
        self.original = nn.Linear(100, 100)
        self.relu = nn.ReLU()

    def forward(self, x):
        scores = self.router(x)
        # Set the scores as attributes on the referenced blocks (used by Setup2)
        if self.refs is not None:
            for ref in self.refs:
                ref.scores = scores
        return self.relu(self.original(x)), scores


class BlockWithExperts(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scores = None
        self.original = nn.Linear(100, 100)
        self.relu = nn.ReLU()
        self.expert1 = nn.Linear(100, 100)
        self.expert2 = nn.Linear(100, 100)

    def forward(self, x, scores=None):
        # Fall back to the scores set as an attribute (Setup2) if none are passed in
        scores = scores if scores is not None else self.scores
        # Use the scores to aggregate the expert outputs:
        # scores is (batch, experts), expert_outputs is (experts, batch, features),
        # so the einsum yields a score-weighted sum of experts of shape (batch, features).
        expert1_output = self.expert1(x)
        expert2_output = self.expert2(x)
        expert_outputs = torch.stack([expert1_output, expert2_output])
        result = torch.einsum('be,ebt->bt', scores, expert_outputs)
        return self.relu(self.original(x) + result)


class WithoutGradChkp(torch.nn.Module):
    """
    Baseline without gradient checkpointing; the router scores are passed
    explicitly between blocks.
    """

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        self.block3 = BlockWithExperts()
        self.block2 = BlockWithExperts()
        self.block1 = BlockWithRouter()

    def forward(self, x):
        x, scores = self.block1(x)
        x = self.block2(x, scores)
        x = self.block3(x, scores)
        return x.softmax(dim=-1)


class Setup1(torch.nn.Module):
    """
    Router scores are returned by the block's forward method and passed down
    to the other blocks.
    """

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        self.block3 = BlockWithExperts()
        self.block2 = BlockWithExperts()
        self.block1 = BlockWithRouter()

    def forward(self, x):
        x, scores = checkpoint(self.block1.__call__, x, use_reentrant=False)
        x = checkpoint(self.block2.__call__, x, scores, use_reentrant=False)
        x = checkpoint(self.block3.__call__, x, scores, use_reentrant=False)
        return x.softmax(dim=-1)


class Setup2(torch.nn.Module):
    """
    Router scores are set directly on the last two blocks as attributes. They
    are then used in the forward method of the expert aggregation of those blocks.
    """

    def __init__(self):
        super().__init__()
        self.block3 = BlockWithExperts()
        self.block2 = BlockWithExperts()
        self.block1 = BlockWithRouter(refs=[self.block2, self.block3])

    def forward(self, x):
        x, _ = checkpoint(self.block1.__call__, x, use_reentrant=False)
        x = checkpoint(self.block2.__call__, x, use_reentrant=False)
        x = checkpoint(self.block3.__call__, x, use_reentrant=False)
        return x.softmax(dim=-1)


def randomtrain(model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()
    # Train on random data
    model.train()
    for _ in range(100):
        x = torch.randn(10, 100)
        y = torch.randint(0, 2, (10,))
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()


def test_without_gradchkp():
    """
    Test to see if the router weights are updated when training without
    gradient checkpointing (baseline).
    """
    model = WithoutGradChkp()
    # Save the router weights to compare them after training
    router_weights = model.block1.router.weight.clone().detach()
    randomtrain(model)
    assert not torch.equal(router_weights, model.block1.router.weight), "Router weights have not been updated."


def test_setup1():
    """
    Test to see if the router weights are updated when using gradient
    checkpointing with the scores passed as arguments.
    """
    # Set up training
    model = Setup1()
    # Save the router weights to compare them after training
    router_weights = model.block1.router.weight.clone().detach()
    randomtrain(model)
    # Check if the router weights have been updated
    assert not torch.equal(router_weights, model.block1.router.weight), "Router weights have not been updated."


def test_setup2():
    """
    Test to see if the router weights are updated when using gradient
    checkpointing with the scores set as attributes.
    """
    # Set up training
    model = Setup2()
    # Save the router weights to compare them after training
    router_weights = model.block1.router.weight.clone().detach()
    randomtrain(model)
    # Check if the router weights have been updated
    assert not torch.equal(router_weights, model.block1.router.weight), "Router weights have not been updated."
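

# A minimal convenience entry point (an addition, not part of the original tests):
# it simply runs the tests above via pytest when the file is executed directly,
# assuming pytest is installed (it is already imported at the top).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])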