Skip to content

Instantly share code, notes, and snippets.

@jessecambon
Last active June 23, 2022 19:16
Show Gist options
  • Save jessecambon/87898b36675335d8132207b6dddc1b6d to your computer and use it in GitHub Desktop.
Save jessecambon/87898b36675335d8132207b6dddc1b6d to your computer and use it in GitHub Desktop.
DeepSpeed Reproducible Example
{
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 32,
"hysteresis": 2,
"min_loss_scale": 1
},
"zero_optimization": {
"stage": 3,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true,
"cpu_offload": false
},
"zero_allow_untested_optimizer": true,
"optimizer": {
"type": "AdamW",
"params": {
"lr": 3e-5,
"betas": [0.8, 0.999],
"eps": 1e-8,
"weight_decay": 3e-7
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 3e-5,
"warmup_num_steps": 0
}
},
"steps_per_print": 2000,
"wall_clock_breakdown": false
}
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from pytorch_lightning.strategies import DeepSpeedStrategy
from deepspeed.ops.adam import FusedAdam
class RandomDataset(Dataset):
    """Map-style dataset of random feature vectors.

    Holds ``length`` samples, each a 1-D tensor of ``size`` values drawn
    from ``torch.randn`` once, at construction time.
    """

    def __init__(self, size: int, length: int) -> None:
        """Generate the random samples.

        Args:
            size: Number of features in each sample.
            length: Number of samples in the dataset.
        """
        self.len = length
        # Shape is (length, size): one row per sample.
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        """Return the sample stored at ``index``."""
        return self.data[index]

    def __len__(self) -> int:
        """Return the number of samples requested at construction."""
        return self.len
class BoringModel(LightningModule):
    """Minimal LightningModule: a single ``Linear(32, 2)`` layer.

    Every step uses the plain sum of the layer's outputs as its "loss";
    the value only exists to drive the training loop for the repro.
    """

    def __init__(self):
        """Build the single linear layer (32 inputs, 2 outputs)."""
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        """Compute, log (``train_loss``) and return the summed output."""
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Compute the summed output and log it as ``valid_loss``."""
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        """Compute the summed output and log it as ``test_loss``."""
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        """Return a DeepSpeed ``FusedAdam`` over all model parameters.

        No hyperparameters are passed, so the optimizer's own defaults apply.
        """
        return FusedAdam(self.parameters())
class DataModule(LightningDataModule):
    """Data module that serves one random dataset for every stage."""

    def setup(self, stage=None) -> None:
        """Create the dataloader shared by train, validation and test.

        The dataset has 64 samples of 32 features each, served with a
        batch size of 1. ``stage`` is accepted but not used.
        """
        self._dataloader = DataLoader(RandomDataset(32, 64), batch_size=1)

    def train_dataloader(self):
        """Return the shared dataloader built in ``setup``."""
        return self._dataloader

    def test_dataloader(self):
        """Return the shared dataloader built in ``setup``."""
        return self._dataloader

    def val_dataloader(self):
        """Return the shared dataloader built in ``setup``."""
        return self._dataloader
def main() -> None:
    """Fit ``BoringModel`` for one batch on two GPUs with DeepSpeed.

    The DeepSpeed settings are read from ``deepspeed_config.json``; the
    path is relative, so the script must be run from the directory that
    contains that file.
    """
    model = BoringModel()
    dm = DataModule()
    trainer = Trainer(
        # ``accelerator``/``devices`` is the supported spelling of the
        # deprecated ``gpus=2`` argument and selects the same two GPUs.
        accelerator="gpu",
        devices=2,
        # Keep the run as short as possible: one train batch, one
        # validation batch, no sanity check, a single epoch.
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        precision=16,
        enable_model_summary=False,
        strategy=DeepSpeedStrategy(config="deepspeed_config.json"),
        deterministic=True,
    )
    trainer.fit(model, datamodule=dm)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment