@KeremTurgutlu · Created March 2, 2021
Fastai WANDB Callback with DDP
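
Training script for a ViT-B/32 CLIP-style model with fastai. The key DDP detail: wandb is initialized, and the WandbCallback attached, on the rank-0 process only, so a distributed run logs a single wandb run instead of one per GPU. A launch example follows the script.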
# imports (assumed; the gist omits them)
import math
import numpy as np
import wandb
from fastcore.script import call_parse, Param
from fastai.vision.all import *
from fastai.distributed import *
from fastai.callback.wandb import WandbCallback
# project-specific pieces assumed to be defined/imported elsewhere:
# cids, sample_valid_cids, get_dls, vitb32_config, CustomCLIP,
# CLIPTrainer, DistributedCLIPTrainer, RetrievalAtK

@call_parse
def main(
    size: Param("Image resolution", int)=224,
    bs: Param("Batch Size", int)=128,
    epochs: Param("Number of epochs for training", int)=1,
    lr: Param("Learning rate for training", float)=5e-5):
    WANDB = True
    # start wandb
    if rank_distrib() == 0 and WANDB:
        wandb.init(project="xxx", entity="xxx")
        wandb.config.update({"Arch": "ViT-B/32",
                             "Size": size,
                             "BS": bs,
                             "Compute": "Single GPU Non Distributed Loss",
                             "Training": "From Scratch"})
    # dataloaders
    dls, clip_tokenizer = get_dls(cids, sample_valid_cids[:10000], size, bs)
    if rank_distrib() == 0: print(len(dls.train_ds), len(dls.valid_ds))
    # callbacks
    ndata = len(dls.train_ds)//1000
    modelname = f'xxx'
    savemodel_cb = SaveModelCallback(monitor="retrieval_at_20", comp=np.greater, fname=modelname)
    if num_distrib() > 0:
        print("Distributed training mode")
        clip_trainer_cb = DistributedCLIPTrainer()
    else:
        print("Single gpu training mode")
        clip_trainer_cb = CLIPTrainer()
    cbs = [savemodel_cb, clip_trainer_cb]
    # attach WandbCallback on rank 0 only; adding it on every process would log duplicate runs
    if rank_distrib() == 0 and WANDB: cbs += [WandbCallback(log_preds=False, log_model=False)]
    # model
    vitb32_config_dict = vitb32_config(size, clip_tokenizer.context_length, clip_tokenizer.vocab_size)
    clip_model = CustomCLIP(**vitb32_config_dict)
    # loss_func=noop: the contrastive loss is handled inside the CLIPTrainer callbacks
    learner = Learner(dls, clip_model, loss_func=noop, cbs=cbs,
                      metrics=[RetrievalAtK(k=5),
                               RetrievalAtK(k=20),
                               RetrievalAtK(k="mean"),
                               RetrievalAtK(k="median")])
    learner.to_fp16()
    learner.unfreeze()
    # fit
    if num_distrib() > 0:
        with learner.distrib_ctx():
            print(f"num_distrib(): {num_distrib()}")
            # scale lr by sqrt(world size) to compensate for the larger effective batch
            lr *= math.sqrt(num_distrib())
            learner.fit_flat_cos(epochs, lr, pct_start=0.25)
    else:
        learner.fit_flat_cos(epochs, lr, pct_start=0.25)
    # end wandb (guard with WANDB so this doesn't run when logging is disabled)
    if rank_distrib() == 0 and WANDB: wandb.finish()
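
A minimal launch sketch, assuming the script above is saved as train_clip.py (a hypothetical filename): run it directly for single-GPU training, or through fastai's launcher, which spawns one process per visible GPU and sets the environment variables that distrib_ctx() reads.

    # single process / single GPU
    python train_clip.py --size 224 --bs 128 --epochs 1 --lr 5e-5

    # DDP across all visible GPUs via fastai's launcher
    python -m fastai.launch train_clip.py --size 224 --bs 128 --epochs 1 --lr 5e-5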