# Created March 2, 2021 22:54
# Fastai WANDB Callback with DDP
def main(
    size: Param("Image resolution", int) = 224,
    bs: Param("Batch Size", int) = 128,
    epochs: Param("Number of epochs for training", int) = 1,
    lr: Param("Learning rate for training", float) = 5e-5,
):
    """Train a ``CustomCLIP`` model with fastai, optionally logging to W&B.

    Uses ``DistributedCLIPTrainer`` when ``num_distrib() > 0`` and
    ``CLIPTrainer`` otherwise. W&B initialisation, the ``WandbCallback`` and
    ``wandb.finish()`` only run on the process where ``rank_distrib() == 0``.

    Args:
        size: Image resolution, forwarded to ``get_dls`` and ``vitb32_config``.
        bs: Batch size, forwarded to ``get_dls``.
        epochs: Number of epochs passed to ``fit_flat_cos``.
        lr: Base learning rate; multiplied by ``sqrt(num_distrib())`` when
            training in distributed mode.
    """
    WANDB = True

    # start wandb (rank 0 only, so a multi-process job opens a single run)
    if rank_distrib() == 0 and WANDB:
        wandb.init(project="xxx", entity="xxx")
        # NOTE(review): the opening of this call was missing from the source;
        # reconstructed as `wandb.config.update` from the surviving dict tail.
        # Confirm, and restore any config keys that were lost with it.
        wandb.config.update({
            "Compute": "Single GPU Non Distributed Loss",
            "Training": "From Scratch",
        })

    # dataloaders
    # `cids` and `sample_valid_cids` are not defined in this function; they
    # must be provided at module level.
    dls, clip_tokenizer = get_dls(cids, sample_valid_cids[:10000], size, bs)
    if rank_distrib() == 0:
        print(len(dls.train_ds), len(dls.valid_ds))

    # callbacks
    modelname = "xxx"
    savemodel_cb = SaveModelCallback(
        monitor="retrieval_at_20", comp=np.greater, fname=modelname
    )
    if num_distrib() > 0:
        print("Distributed training mode")
        clip_trainer_cb = DistributedCLIPTrainer()
    else:
        # Without this `else`, the single-GPU trainer would always replace
        # the distributed one chosen above.
        print("Single gpu training mode")
        clip_trainer_cb = CLIPTrainer()
    cbs = [savemodel_cb, clip_trainer_cb]
    if rank_distrib() == 0 and WANDB:
        cbs.append(WandbCallback(log_preds=False, log_model=False))

    # model
    vitb32_config_dict = vitb32_config(
        size, clip_tokenizer.context_length, clip_tokenizer.vocab_size
    )
    clip_model = CustomCLIP(**vitb32_config_dict)
    # NOTE(review): the tail of this call was truncated in the source.
    # `savemodel_cb` monitors "retrieval_at_20", so a metric reporting that
    # name presumably belonged here -- TODO restore the lost arguments.
    learner = Learner(dls, clip_model, loss_func=noop, cbs=cbs)

    # fit
    if num_distrib() > 0:
        with learner.distrib_ctx():
            print(f"num_distrib(): {num_distrib()}")
            # Scale the base learning rate by sqrt(number of processes).
            lr *= math.sqrt(num_distrib())
            learner.fit_flat_cos(epochs, lr, pct_start=0.25)
    else:
        learner.fit_flat_cos(epochs, lr, pct_start=0.25)

    # end wandb (same guard as the init above, so they stay paired)
    if rank_distrib() == 0 and WANDB:
        wandb.finish()
