# Gist by @KeremTurgutlu, created June 14, 2021 07:55
import gc
from fastai.vision.all import *
from torch.cuda.amp import autocast, GradScaler
from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
from sam import SAM  # external SAM/ASAM optimizer module providing first_step/second_step
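# The `sam` module above is not included in this gist. The sketch below is an
# assumption (modelled on the published SAM/ASAM formulation), kept commented out
# so it does not shadow the real import. It shows only the interface this script
# relies on: a constructor wrapping a base optimizer, `first_step` (perturb the
# weights towards the local worst case within an L2 ball of radius rho), and
# `second_step` (restore the weights, then take the base-optimizer step).
#
# class SAM(torch.optim.Optimizer):
#     def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
#         defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
#         super().__init__(params, defaults)
#         self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
#         self.param_groups = self.base_optimizer.param_groups
#
#     @torch.no_grad()
#     def first_step(self, zero_grad=False):
#         grad_norm = self._grad_norm()
#         for group in self.param_groups:
#             scale = group["rho"] / (grad_norm + 1e-12)
#             for p in group["params"]:
#                 if p.grad is None: continue
#                 self.state[p]["old_p"] = p.data.clone()
#                 e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
#                 p.add_(e_w)  # move to w + e(w)
#         if zero_grad: self.zero_grad()
#
#     @torch.no_grad()
#     def second_step(self, zero_grad=False):
#         for group in self.param_groups:
#             for p in group["params"]:
#                 if p.grad is None: continue
#                 p.data = self.state[p]["old_p"]  # restore w
#         self.base_optimizer.step()  # sharpness-aware update of w
#         if zero_grad: self.zero_grad()
#
#     def _grad_norm(self):
#         shared_device = self.param_groups[0]["params"][0].device
#         return torch.norm(torch.stack([
#             ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
#             for group in self.param_groups for p in group["params"] if p.grad is not None
#         ]), p=2)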
class FastaiSched:
    "Linear warmup over the first 10% of training, then cosine annealing back towards 0."
    def __init__(self, optimizer, max_lr):
        self.optimizer = optimizer
        self.lr_sched = combine_scheds([0.1, 0.9], [SchedLin(1e-8, max_lr), SchedCos(max_lr, 1e-8)])
        self.update(0)

    def update(self, pos):
        "Set the learning rate of every param group; `pos` is training progress in [0, 1]."
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = self.lr_sched(pos)
class ProgressTracker:
    def __init__(self, dls, total_epochs):
        self.iter = 0
        self.epoch = 0
        self.total_iter = len(dls.train) * total_epochs

    @property
    def train_pct(self):
        return self.iter / self.total_iter
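# Illustrative wiring of the two helpers above (commented out; `model` and `dls`
# are placeholders): the tracker turns the running iteration count into a 0-1
# position, which the scheduler maps onto the warmup + cosine curve.
#
# opt = torch.optim.SGD(model.parameters(), lr=0.1)
# sched = FastaiSched(opt, max_lr=0.1)
# tracker = ProgressTracker(dls, total_epochs=15)
# for xb, yb in dls.train:           # inside the batch loop
#     ...                            # forward / backward / optimizer step
#     tracker.iter += 1
#     sched.update(tracker.train_pct)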
from torch.distributions import Beta

def mixup_batch(xb, yb, lam):
    "Mix each sample with a randomly chosen partner from the same batch."
    shuffle = torch.randperm(yb.size(0)).to(xb.device)
    xb1, yb1 = xb[shuffle], yb[shuffle]
    nx_dims = len(xb.size())
    # lerp(start=xb1, end=xb, weight=lam): xb1 + lam * (xb - xb1)
    xb = torch.lerp(xb1, xb, unsqueeze(lam, n=nx_dims-1))
    return xb, yb, yb1
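# Hypothetical sanity check (not in the original gist): mixup_batch should keep
# the batch shapes unchanged and return the original and shuffled label tensors.
def _mixup_batch_demo():
    xb, yb = torch.randn(8, 3, 32, 32), torch.randint(0, 4, (8,))
    lam = Beta(tensor(0.4), tensor(0.4)).sample((8,)).squeeze()
    xb_mix, yb_a, yb_b = mixup_batch(xb, yb, lam)
    assert xb_mix.shape == xb.shape and yb_a.shape == yb_b.shape == yb.shape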
def epoch_train(model, dls, loss_fn, optimizer, scheduler, progress_tracker,
                grad_clip_max_norm=None, mixup=False, mixup_alpha=0.4):
    "One training epoch with SAM's two forward-backward passes per batch."
    model.train()
    losses = []
    if mixup: beta_distrib = Beta(tensor(mixup_alpha), tensor(mixup_alpha))
    for xb, yb in progress_bar(dls.train):
        if mixup:
            # per-sample mixing coefficients
            lam = beta_distrib.sample((yb.size(0),)).squeeze().to(xb.device)
            xb, yb, yb1 = mixup_batch(xb, yb, lam)

        # first forward-backward pass: gradient at the current weights
        with torch.cuda.amp.autocast():
            out = model(xb)
            if mixup:
                with NoneReduce(loss_fn) as lf:
                    # NoneReduce yields per-sample losses; reduce to a scalar before backward
                    loss = torch.lerp(lf(out, yb1), lf(out, yb), lam).mean()
            else:
                loss = loss_fn(out, yb)
        loss.backward()
        if grad_clip_max_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_max_norm)
        optimizer.first_step(zero_grad=True)  # perturb weights towards the local worst case

        # second forward-backward pass: gradient at the perturbed weights
        with torch.cuda.amp.autocast():
            out = model(xb)
            if mixup:
                with NoneReduce(loss_fn) as lf:
                    loss = torch.lerp(lf(out, yb1), lf(out, yb), lam).mean()
            else:
                loss = loss_fn(out, yb)
        loss.backward()
        if grad_clip_max_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_max_norm)
        optimizer.second_step(zero_grad=True)  # restore weights and apply the base-optimizer update

        losses.append(to_detach(loss))
        # advance the learning-rate schedule
        progress_tracker.iter += 1
        scheduler.update(progress_tracker.train_pct)
    return losses
def epoch_validate(model, loss_fn, dls):
    "Compute validation loss and mean average precision over the full validation set."
    model.eval()
    preds, targs = [], []
    with torch.no_grad():
        for xb, yb in progress_bar(dls.valid):
            pred = model(xb).cpu()
            preds += [pred]
            targs += [yb.cpu()]
    preds, targs = torch.cat(preds), torch.cat(targs)
    loss = loss_fn(preds, targs)
    score = sklearn_mean_ap(preds.softmax(dim=1), targs)
    return loss, score
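# `sklearn_mean_ap` is defined elsewhere in the author's notebook. A plausible
# stand-in (an assumption, kept commented out so it does not shadow the real
# helper) would macro-average per-class average precision over one-hot targets:
#
# from sklearn.metrics import average_precision_score
# def sklearn_mean_ap(probs, targs):
#     onehot = torch.zeros_like(probs).scatter_(1, targs.view(-1, 1).long(), 1.)
#     return average_precision_score(onehot.numpy(), probs.numpy(), average="macro")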
def train_sam(arch, lr, epochs=15, bs=32, size=512, fastai_augs=False, cropped=False, re=False, easy=False,
              mixup=False, save=False, debug=False, folds_range=None, ckpt=None, clip_max_norm=3.):
    "Cross-validated training: one SAM-trained model per fold, with early stopping on mean AP."
    if folds_range is None:
        if debug: folds_range = range(0, 1)
        else: folds_range = range(5)
    if debug:
        # overwrite during debug
        save = False
        epochs = 1

    for fold_idx in folds_range:
        print(f"training fold {fold_idx}")
        # dls
        dls = get_dls(fold_idx, train_df, study_ids, study_label_ids, bs=bs, size=size,
                      fastai_augs=fastai_augs, cropped=cropped, re=re, easy=easy, debug=debug)
        # model
        model = get_classification_model(arch)
        model.to(default_device())

        if save:
            fname = f"{arch}"
            fname += f"-sz{size}"
            if re: fname += "-re"
            if mixup: fname += "-mixup"
            if cropped: fname += "-cropped"
            fname += f"-SAM-fold{fold_idx}"

        # SAM wraps a base optimizer (SGD here) and adds the sharpness-aware perturbation step
        base_optimizer = torch.optim.SGD
        optimizer = SAM(model.parameters(), base_optimizer, rho=0.05, adaptive=True, lr=lr, momentum=0.9, weight_decay=5e-4)
        loss_fn = LabelSmoothingCrossEntropyFlat()
        scheduler = FastaiSched(optimizer, max_lr=lr)
        progress_tracker = ProgressTracker(dls, epochs)

        if ckpt:
            load_model(f"models/{ckpt}-fold{fold_idx}.pth", model, None, device=default_device())

        # training
        best_score = 0
        score_not_improved = 0
        patience = 3
        res = []
        for epoch in range(epochs):
            # train and validate epoch
            train_losses = epoch_train(model, dls, loss_fn, optimizer, scheduler, progress_tracker,
                                       grad_clip_max_norm=clip_max_norm, mixup=mixup)
            valid_loss, valid_score = epoch_validate(model, loss_fn, dls)

            # print logs
            train_loss = torch.stack(train_losses).mean()
            row = [epoch, train_loss.item(), valid_loss.item(), valid_score]
            res.append(row)
            print(f"epoch: {row[0]} train_loss: {row[1]} valid_loss: {row[2]} valid_score: {row[3]}")

            # save best model (only when saving is enabled, e.g. not in debug mode)
            if valid_score > best_score:
                if save: save_model(f"models/{fname}.pth", model, None)
                best_score = valid_score
            else:
                score_not_improved += 1

            # early stop
            if score_not_improved > patience: break

        # save logs
        if save:
            res_df = pd.DataFrame(res, columns=['epoch', 'train_loss', 'valid_loss', 'sklearn_mean_ap'])
            res_df.to_csv(history/f"{fname}.csv", index=False)

        del model, optimizer, dls
        gc.collect()
        torch.cuda.empty_cache()
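# Example invocation (illustrative; the architecture name and hyperparameters are
# placeholders, and get_dls / get_classification_model / train_df / study_ids /
# study_label_ids / history are assumed to be defined in the surrounding notebook):
#
# train_sam("resnet50", lr=1e-2, epochs=15, bs=32, size=512,
#           mixup=True, save=True, folds_range=range(5))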