Skip to content

Instantly share code, notes, and snippets.

Created March 11, 2021 01:53
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
What would you like to do?
ZeRO optimizer example
import wandb
from fastai.callback.wandb import WandbCallback
from fastai.distributed import *
torch.backends.cudnn.benchmark = True
from zero_optimizer import ZeroRedundancyOptimizer
def after_batch(self: WandbCallback):
    "Log hyper-parameters and training loss to W&B after each training batch."
    # fastai fires `after_batch` for validation batches too. Only training
    # batches should advance the W&B step/epoch counters (the smoothed loss is
    # a training-time statistic), so skip everything else.
    if not self.training:
        return
    self._wandb_step += 1
    # `n_iter` is fastai's batch count for the current dataloader, so this
    # accumulates a fractional epoch value.
    self._wandb_epoch += 1 / self.n_iter
    # Flatten per-parameter-group hyper-parameters into keys like `lr_0`.
    hypers = {f'{k}_{i}': v
              for i, h in enumerate(self.opt.hypers)
              for k, v in h.items()}
    wandb.log({'epoch': self._wandb_epoch,
               'train_loss': self.smooth_loss.clone().detach().cpu(),
               'raw_loss': self.loss.clone().detach().cpu(),
               **hypers},
              step=self._wandb_step)


# A bare module-level function is never invoked by fastai. Attach it to the
# callback class so it replaces the default `WandbCallback.after_batch`.
WandbCallback.after_batch = after_batch
# Training entry point. Builds CLIP dataloaders, callbacks and an optimizer
# selected by `opt`, then trains with `fit_flat_cos`. Training runs under
# fastai's distributed context when launched with several processes, and
# logs to W&B on rank 0.
# NOTE(review): this body was recovered from a scraped gist page. Indentation
# is lost and several lines are visibly missing (marked below), so it will
# not run as-is. The `Param(...)` annotations suggest a fastscript-style CLI
# entry point (e.g. `@call_parse`) whose decorator is not shown; confirm
# against the original.
def main(
size: Param("Image resolution", int)=224,
bs: Param("Batch Size", int)=256,
epochs: Param("Number of epochs for training", int)=1,
lr: Param("Learning rate for training", float)=5e-5,
opt: Param("Optimizer to use", str)="zero",
# NOTE(review): the closing `):` of the signature is missing here.
# Hard-coded switch for W&B logging; it is not exposed as a CLI parameter.
WANDB = True
# start wandb
# Only rank 0, and only when WANDB is set, starts a W&B run.
if rank_distrib() == 0 and WANDB:
wandb.init(project="XXX", entity="XXX");
# NOTE(review): one or more lines are missing here. The dict entries below,
# closed by `});`, look like arguments to a call such as
# `wandb.config.update({...})`; confirm.
"Optimizer": opt,
"Compute":"Multi GPU Distributed Loss",
"Training":"From Scratch"});
# dataloaders
# NOTE(review): `get_dls`, `cids` and `sample_valid_cids` are not defined in
# this view. Only the first 10,000 validation ids are passed in.
dls, clip_tokenizer = get_dls(cids, sample_valid_cids[:10000], size, bs)
if rank_distrib() == 0: print(len(dls.train_ds), len(dls.valid_ds))
# callbacks
# Training-set size in thousands; used only to build the checkpoint name.
ndata = len(dls.train_ds)//1000
modelname = f'XXX_{ndata}K_en_vitb32_bs{bs}_size{size}_epochs{epochs}_lr{lr}'
# Save the model whenever "retrieval_at_20" improves (higher is better via
# np.greater). That metric must be produced elsewhere (metric or trainer
# callback), since it is not defined in this view.
savemodel_cb = SaveModelCallback(monitor="retrieval_at_20", comp=np.greater, fname=modelname)
if num_distrib()>0:
print("Distributed training mode")
# clip_trainer_cb = DistributedCLIPTrainer()
clip_trainer_cb = CLIPTrainer()
# NOTE(review): an `else:` appears to be missing before the next line.
# Either way, both branches build the same `CLIPTrainer()`. The distributed
# variant is commented out, so the branch only changes the printed message.
print("Single gpu training mode")
clip_trainer_cb = CLIPTrainer()
cbs = [savemodel_cb, clip_trainer_cb]
# The W&B callback is attached only on rank 0 with WANDB enabled, matching
# the `wandb.init` gate above.
if rank_distrib() == 0 and WANDB: cbs += [WandbCallback(log_preds=False, log_model=False)]
# ZeRO
# Adapter that builds a ZeroRedundancyOptimizer around torch.optim.Adam and
# wraps it in fastai's OptimWrapper. Extra `**kwargs` are accepted but NOT
# forwarded, so any additional hyper-parameters fastai passes are dropped.
# NOTE(review): `zero_optimizer` is a local module; presumably it shards
# optimizer state across ranks (ZeRO). Confirm.
def zero(params, lr, **kwargs):
return OptimWrapper(ZeroRedundancyOptimizer(params, optimizer_class=torch.optim.Adam, lr=lr))
# Select the optimizer factory by name; any unrecognised value falls back to Adam.
# NOTE(review): `Larc` is not defined in this view; confirm it is imported elsewhere.
if opt == 'zero': opt_func = zero
elif opt == 'larc': opt_func = Larc
elif opt == 'ranger': opt_func = ranger
else: opt_func = Adam
# model
# NOTE(review): `checkpoint=True, checkpoint_nchunks=2` presumably enables
# gradient checkpointing in 2 chunks to save memory. Verify against the CLIP
# implementation (not shown).
vitb32_config_dict = vitb32_config(size, clip_tokenizer.context_length, clip_tokenizer.vocab_size)
clip_model = CLIP(**vitb32_config_dict, checkpoint=True, checkpoint_nchunks=2)
# `loss_func=noop`: presumably the loss is computed inside the CLIPTrainer
# callback instead. Confirm.
learner = Learner(dls, clip_model, loss_func=noop, cbs=cbs, opt_func=opt_func,
# NOTE(review): the remainder of this `Learner(...)` call is missing.
# fit
# With several processes, train inside fastai's `distrib_ctx()`; otherwise
# train directly. Both paths use the same flat-then-cosine schedule
# (pct_start=0.25).
if num_distrib()>0:
with learner.distrib_ctx():
print(f"num_distrib(): {num_distrib()}")
learner.fit_flat_cos(epochs, lr, pct_start=0.25)
else: learner.fit_flat_cos(epochs, lr, pct_start=0.25)
# end wandb
# NOTE(review): unlike `wandb.init` above, this call is not gated on
# `WANDB`. Consider using the same `rank_distrib() == 0 and WANDB` condition.
if rank_distrib() == 0: wandb.finish()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.