from fastai.vision.all import *
from torch.distributions import Beta
from copy import deepcopy
__all__ = ["ELR", "ELR_plusA", "ELR_plusB"]
class ELR(Callback):
    '''
    The values selected in the paper are β = 0.7 and λ = 3 for symmetric noise, β = 0.9 and λ = 1 for
    asymmetric noise on CIFAR-10, and β = 0.9 and λ = 7 for CIFAR-100.
    https://arxiv.org/pdf/2007.00151.pdf - Algorithm 1
    Warning: β and λ might need a grid search on other datasets.
    TODO: Try increasing λ as training progresses, perhaps with sched_exp.
    '''
run_valid=False
def __init__(self, reg_beta=0.9, reg_lambda=1, sched_func=None, lambda_min=None, lambda_max=None, num_classes=None):
store_attr()
def after_fit(self):
self.learn.loss_func = self.old_lf
def before_fit(self):
#initialize empty targets
nc = ifnone(self.num_classes, self.dls.c)
self.t = torch.zeros(len(self.dls.train_ds), nc).to(self.dls.device)
self.old_lf,self.learn.loss_func = self.learn.loss_func,self.lf
def after_batch(self):
# if using scheduler
if self.sched_func is not None:
self.reg_lambda = self.sched_func(self.lambda_min, self.lambda_max, self.pct_train)
def after_pred(self):
# compute probas
self.y_pred = F.softmax(self.pred, dim=-1)
self.y_pred = torch.clamp(self.y_pred, 1e-4, 1.0-1e-4)
def after_step(self):
# get training batch indexes
idxs = self.dl._DataLoader__idxs
b_idxs = idxs[self.iter*self.dl.bs:(self.iter+1)*self.dl.bs]
# temporal ensembling
y_pred_ = self.y_pred.data.detach()
self.t[b_idxs] = self.reg_beta * self.t[b_idxs] + (1-self.reg_beta) * (y_pred_)
def lf(self, pred, *yb):
if not self.training: return self.old_lf(pred, *yb)
# get training batch indexes
idxs = self.dl._DataLoader__idxs
b_idxs = idxs[self.iter*self.dl.bs:(self.iter+1)*self.dl.bs]
ce_loss = F.cross_entropy(pred, *yb)
elr_reg = ((1-(self.t[b_idxs] * self.y_pred).sum(dim=1)).log()).mean()
return ce_loss + self.reg_lambda * elr_reg
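
# --- Usage sketch (not part of the original gist) ---
# A minimal example of attaching ELR to a fastai Learner, assuming you already have a
# labelled `dls` (e.g. ImageDataLoaders); `cnn_learner`, `resnet18`, `CrossEntropyLossFlat`
# and `sched_exp` all come from fastai.vision.all. The β/λ values follow the symmetric-noise
# setting quoted in the docstring above; the commented-out variant sketches the TODO of
# ramping λ with sched_exp, whose (start, end, pos) signature matches `sched_func`.
def _elr_usage_example(dls):
    cbs = ELR(reg_beta=0.7, reg_lambda=3)
    # Ramp λ from 1 to 7 over training instead of keeping it fixed:
    # cbs = ELR(reg_beta=0.7, sched_func=sched_exp, lambda_min=1., lambda_max=7.)
    learn = cnn_learner(dls, resnet18, loss_func=CrossEntropyLossFlat(), cbs=cbs)
    learn.fit_one_cycle(5, 3e-3)
    return learn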
class ELR_plusA(Callback):
    '''
    ELR + MixUp (https://arxiv.org/pdf/2007.00151.pdf - Algorithm 2, without the weight
    averaging used in ELR_plusB below).
    '''
run_after,run_valid = [Normalize],False
def __init__(self, reg_beta=0.9, reg_lambda=1, sched_func=None, lambda_min=None, lambda_max=None, num_classes=None,
alpha=0.4):
store_attr()
self.distrib = Beta(tensor(alpha), tensor(alpha))
def after_fit(self):
self.learn.loss_func = self.old_lf
def before_fit(self):
#initialize empty targets
nc = ifnone(self.num_classes, self.dls.c)
self.t = torch.zeros(len(self.dls.train_ds), nc).to(self.dls.device)
self.old_lf,self.learn.loss_func = self.learn.loss_func,self.lf
def before_batch(self):
# if using scheduler
if self.sched_func is not None:
self.reg_lambda = self.sched_func(self.lambda_min, self.lambda_max, self.pct_train)
# batch idxs
idxs = self.dl._DataLoader__idxs
b_idxs = idxs[self.iter*self.dl.bs:(self.iter+1)*self.dl.bs]
#mixup
lam = self.distrib.sample((self.y.size(0),)).squeeze().to(self.x.device)
lam = torch.stack([lam, 1-lam], 1)
self.lam = lam.max(1)[0]
shuffle = torch.randperm(self.y.size(0)).to(self.x.device)
# mixed preds : tb
nt_dims = len(self.t.size())
self.t1 = tuple(L(self.t[b_idxs]).itemgot(shuffle))
self.tb = tuple(L(self.t1,self.t[b_idxs]).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=nt_dims-1)))[0]
# shuffled targets
self.yb1 = tuple(L(self.yb).itemgot(shuffle))
# raw inputs for temporal ensembling
self.x_raw = self.x.clone()
# mixed inputs : xb
nx_dims = len(self.x.size())
xb1 = tuple(L(self.xb).itemgot(shuffle))
self.learn.xb = tuple(L(xb1,self.xb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=nx_dims-1)))
def lf(self, pred, *yb):
if not self.training: return self.old_lf(pred, *yb)
# mixed cross entropy loss
ce1,ce = F.cross_entropy(pred, *self.yb1, reduction='none'), F.cross_entropy(pred, *self.yb, reduction='none')
ce_loss = torch.lerp(ce1, ce, self.lam).mean()
# mixed regularization
y_pred = F.softmax(pred, dim=-1)
y_pred = torch.clamp(y_pred, 1e-4, 1.0-1e-4)
elr_reg = ((1-(self.tb * y_pred).sum(dim=1)).log()).mean()
return ce_loss + self.reg_lambda * elr_reg
def after_step(self):
# get training batch indexes
idxs = self.dl._DataLoader__idxs
b_idxs = idxs[self.iter*self.dl.bs:(self.iter+1)*self.dl.bs]
        # network evaluation & temporal ensembling on the raw (un-mixed) inputs
        with torch.no_grad():
            y_pred = self.model.eval()(self.x_raw)
            y_pred = F.softmax(y_pred, dim=-1)
            y_pred = torch.clamp(y_pred, 1e-4, 1.0-1e-4)
        self.model.train()  # restore training mode after the extra forward pass
        self.t[b_idxs] = self.reg_beta * self.t[b_idxs] + (1-self.reg_beta) * y_pred
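
# --- Usage sketch (not part of the original gist) ---
# ELR_plusA mixes the inputs itself, so use it *instead of* fastai's MixUp callback,
# not together with it. `dls` is assumed to be your own labelled DataLoaders; alpha is
# the Beta(α, α) parameter used for the MixUp interpolation weights.
def _elr_plus_a_usage_example(dls):
    learn = cnn_learner(dls, resnet18, loss_func=CrossEntropyLossFlat(),
                        cbs=ELR_plusA(reg_beta=0.9, reg_lambda=1, alpha=0.4))
    learn.fit_one_cycle(5, 3e-3)
    return learn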
def sigmoid_rampup(pos):
"""Exponential rampup"""
phase = 1.0 - pos
return float(np.exp(-5.0 * phase * phase))
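
# Illustrative values (not in the original gist): with pos = pct_train, the ramp is
#   sigmoid_rampup(0.0) = exp(-5)    ≈ 0.0067
#   sigmoid_rampup(0.5) = exp(-1.25) ≈ 0.2865
#   sigmoid_rampup(1.0) = exp(0)     =  1.0
# so the EMA momentum used by ELR_plusB below starts near zero and approaches ema_alpha
# as training progresses.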
class ELR_plusB(Callback):
    '''
    https://arxiv.org/pdf/2007.00151.pdf - Algorithm 2: ELR + MixUp + weight averaging,
    using a single network with an EMA copy instead of two trained models.
    '''
run_after,run_valid = [Normalize],False
def __init__(self, reg_beta=0.9, reg_lambda=1,
sched_func=None, lambda_min=None, lambda_max=None,
num_classes=None,
mixup_alpha=1., ema_alpha=0.997):
store_attr()
self.distrib = Beta(tensor(mixup_alpha), tensor(mixup_alpha))
def after_fit(self):
self.learn.loss_func = self.old_lf
def before_fit(self):
#initialize empty targets
nc = ifnone(self.num_classes, self.dls.c)
self.t = torch.zeros(len(self.dls.train_ds), nc).to(self.dls.device)
self.old_lf,self.learn.loss_func = self.learn.loss_func,self.lf
# initialize ema target model
        self.ema_model = deepcopy(self.learn.model).eval()
for p in self.ema_model.parameters(): p.data.zero_()
def before_batch(self):
# if using scheduler
if self.sched_func is not None: self.reg_lambda = self.sched_func(self.lambda_min, self.lambda_max, self.pct_train)
# batch idxs
idxs = self.dl._DataLoader__idxs
b_idxs = idxs[self.iter*self.dl.bs:(self.iter+1)*self.dl.bs]
#mixup
lam = self.distrib.sample((self.y.size(0),)).squeeze().to(self.x.device)
lam = torch.stack([lam, 1-lam], 1)
self.lam = lam.max(1)[0]
shuffle = torch.randperm(self.y.size(0)).to(self.x.device)
# mixed preds : tb
nt_dims = len(self.t.size())
self.t1 = tuple(L(self.t[b_idxs]).itemgot(shuffle))
self.tb = tuple(L(self.t1,self.t[b_idxs]).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=nt_dims-1)))[0]
# shuffled targets
self.yb1 = tuple(L(self.yb).itemgot(shuffle))
# raw inputs for temporal ensembling
self.x_raw = self.x.clone()
# mixed inputs : xb
nx_dims = len(self.x.size())
xb1 = tuple(L(self.xb).itemgot(shuffle))
self.learn.xb = tuple(L(xb1,self.xb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=nx_dims-1)))
def lf(self, pred, *yb):
if not self.training: return self.old_lf(pred, *yb)
# mixed cross entropy loss
ce1,ce = F.cross_entropy(pred, *self.yb1, reduction='none'), F.cross_entropy(pred, *self.yb, reduction='none')
ce_loss = torch.lerp(ce1, ce, self.lam).mean()
# mixed regularization
y_pred = F.softmax(pred, dim=-1)
y_pred = torch.clamp(y_pred, 1e-4, 1.0-1e-4)
elr_reg = ((1-(self.tb * y_pred).sum(dim=1)).log()).mean()
return ce_loss + self.reg_lambda * elr_reg
def after_step(self):
# get training batch indexes
idxs = self.dl._DataLoader__idxs
b_idxs = idxs[self.iter*self.dl.bs:(self.iter+1)*self.dl.bs]
# update target model
self.update_ema_model()
        # network evaluation with the EMA target model & temporal ensembling on the raw inputs
        with torch.no_grad():
            y_pred = self.ema_model.eval()(self.x_raw)
            y_pred = F.softmax(y_pred, dim=-1)
            y_pred = torch.clamp(y_pred, 1e-4, 1.0-1e-4)
        self.t[b_idxs] = self.reg_beta * self.t[b_idxs] + (1-self.reg_beta) * y_pred
def update_ema_model(self):
alpha = sigmoid_rampup(self.pct_train)*self.ema_alpha
for ema_param, param in zip(self.ema_model.parameters(), self.model.parameters()):
            ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
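
# --- Usage sketch (not part of the original gist) ---
# ELR_plusB adds EMA weight averaging on top of ELR + MixUp, so only a single network is
# trained. `dls` is assumed to be your own labelled DataLoaders; ema_alpha is the maximum
# EMA momentum, reached gradually via sigmoid_rampup above.
def _elr_plus_b_usage_example(dls):
    learn = cnn_learner(dls, resnet18, loss_func=CrossEntropyLossFlat(),
                        cbs=ELR_plusB(reg_beta=0.9, reg_lambda=1,
                                      mixup_alpha=1., ema_alpha=0.997))
    learn.fit_one_cycle(5, 3e-3)
    return learn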