@tchaton
Created July 20, 2021 15:16
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from fairscale.nn import checkpoint_wrapper, auto_wrap, wrap

class MyModel(pl.LightningModule):
    ...

    def configure_sharded_model(self):
        # Modules created inside the sharded-model context are sharded across
        # processes as soon as they are wrapped with ``wrap`` or ``auto_wrap``.

        # ``wrap`` puts the layer in a Fully Sharded wrapper automatically.
        linear_layer = wrap(nn.Linear(32, 32))

        # For best memory efficiency, add FairScale activation checkpointing.
        block = auto_wrap(
            checkpoint_wrapper(
                nn.Sequential(
                    nn.Linear(32, 32),
                    nn.ReLU()
                )
            )
        )

        self.model = nn.Sequential(
            linear_layer,
            nn.ReLU(),
            block
        )

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters())

model = MyModel()
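# 4 GPUs, FairScale Fully Sharded Data Parallel ('fsdp' plugin), 16-bit precision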
trainer = Trainer(gpus=4, plugins='fsdp', precision=16)
trainer.fit(model)
trainer.test()
trainer.predict()
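
As written, the gist elides the data side: the `...` stands in for the usual LightningModule hooks, so `trainer.fit(model)` has nothing to train on. Below is a minimal, hypothetical completion (the subclass name, dataset size, and shapes are illustrative assumptions, not part of the original gist) that makes the example runnable end-to-end on a 4-GPU machine with FairScale installed:

# Hypothetical completion (not from the original gist): the usual Lightning hooks
# plus random data matching the 32-feature Linear layers defined above.
from torch.utils.data import DataLoader, TensorDataset


class MyRunnableModel(MyModel):
    def training_step(self, batch, batch_idx):
        x, y = batch
        # self.model is built in configure_sharded_model before training starts
        return nn.functional.mse_loss(self.model(x), y)

    def train_dataloader(self):
        # 512 random samples with 32 features, matching nn.Linear(32, 32)
        x = torch.randn(512, 32)
        y = torch.randn(512, 32)
        return DataLoader(TensorDataset(x, y), batch_size=64)


if __name__ == "__main__":
    model = MyRunnableModel()
    trainer = Trainer(gpus=4, plugins="fsdp", precision=16)
    trainer.fit(model)

The point of building the layers in `configure_sharded_model` rather than `__init__` is that each submodule is sharded as it is created, so the full unsharded model never needs to be materialized on a single device.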