from fastai.vision.all import *
from torch.distributions import Beta
from copy import deepcopy

__all__ = ["ELR", "ELR_plusA", "ELR_plusB"]
class ELR(Callback):
    '''
    Early-Learning Regularization (ELR) for training under label noise.
    The values selected in the paper are β = 0.7 and λ = 3 for symmetric noise,
    β = 0.9 and λ = 1 for asymmetric noise on CIFAR-10, and β = 0.9 and λ = 7 for CIFAR-100.
    https://arxiv.org/pdf/2007.00151.pdf - Algorithm 1
    Warning: β and λ may need a grid search on other datasets.
    TODO: Try increasing λ as training progresses, perhaps with sched_exp.
    '''
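    # Training-time objective implemented in `lf` below:
    #   L = CE(pred, y) + λ · mean_i log(1 − ⟨t_i, p_i⟩)
    # where t_i is the temporal-ensembled (EMA) target for sample i and p_i its
    # clamped softmax prediction; the regularizer pushes p_i toward t_i.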
    run_valid = False

    def __init__(self, reg_beta=0.9, reg_lambda=1, sched_func=None, lambda_min=None, lambda_max=None, num_classes=None):
        store_attr()

    def after_fit(self):
        self.learn.loss_func = self.old_lf

    def before_fit(self):
        # initialize empty targets and swap in the ELR loss
        nc = ifnone(self.num_classes, self.dls.c)
        self.t = torch.zeros(len(self.dls.train_ds), nc).to(self.dls.device)
        self.old_lf, self.learn.loss_func = self.learn.loss_func, self.lf

    def after_batch(self):
        # if using a λ scheduler
        if self.sched_func is not None:
            self.reg_lambda = self.sched_func(self.lambda_min, self.lambda_max, self.pct_train)

    def after_pred(self):
        # compute clamped probabilities
        self.y_pred = F.softmax(self.pred, dim=-1)
        self.y_pred = torch.clamp(self.y_pred, 1e-4, 1.0-1e-4)

    def after_step(self):
        # get training batch indexes
        idxs = self.dl._DataLoader__idxs
        b_idxs = idxs[self.iter*self.dl.bs:(self.iter+1)*self.dl.bs]
        # temporal ensembling: EMA of the predicted probabilities
        y_pred_ = self.y_pred.detach()
        self.t[b_idxs] = self.reg_beta * self.t[b_idxs] + (1-self.reg_beta) * y_pred_

    def lf(self, pred, *yb):
        if not self.training: return self.old_lf(pred, *yb)
        # get training batch indexes
        idxs = self.dl._DataLoader__idxs
        b_idxs = idxs[self.iter*self.dl.bs:(self.iter+1)*self.dl.bs]
        ce_loss = F.cross_entropy(pred, *yb)
        elr_reg = (1 - (self.t[b_idxs] * self.y_pred).sum(dim=1)).log().mean()
        return ce_loss + self.reg_lambda * elr_reg
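

# ---------------------------------------------------------------------------
# Usage sketch (not part of the callback itself): attach ELR to a standard
# fastai classification Learner. The dataset and architecture below are
# illustrative assumptions only.
# ---------------------------------------------------------------------------
def _demo_elr():
    path = untar_data(URLs.IMAGENETTE_160)
    dls = ImageDataLoaders.from_folder(path, valid='val', item_tfms=Resize(160))
    # fixed λ, the paper's symmetric-noise setting
    learn = cnn_learner(dls, resnet18, metrics=accuracy,
                        cbs=ELR(reg_beta=0.7, reg_lambda=3))
    # per the TODO above, λ can instead be ramped up during training with
    # fastai's sched_exp annealer:
    #   ELR(reg_beta=0.7, sched_func=sched_exp, lambda_min=1, lambda_max=7)
    learn.fit_one_cycle(1)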


class ELR_plusA(Callback):
    '''
    ELR + MixUp (Algorithm 2, single-model variant)
    '''
    run_after,run_valid = [Normalize],False

    def __init__(self, reg_beta=0.9, reg_lambda=1, sched_func=None, lambda_min=None, lambda_max=None, num_classes=None,
                 alpha=0.4):
        store_attr()
        self.distrib = Beta(tensor(alpha), tensor(alpha))

    def after_fit(self):
        self.learn.loss_func = self.old_lf

    def before_fit(self):
        # initialize empty targets and swap in the ELR loss
        nc = ifnone(self.num_classes, self.dls.c)
        self.t = torch.zeros(len(self.dls.train_ds), nc).to(self.dls.device)
        self.old_lf, self.learn.loss_func = self.learn.loss_func, self.lf

    def before_batch(self):
        # if using a λ scheduler
        if self.sched_func is not None:
            self.reg_lambda = self.sched_func(self.lambda_min, self.lambda_max, self.pct_train)
        # batch idxs
        idxs = self.dl._DataLoader__idxs
        b_idxs = idxs[self.iter*self.dl.bs:(self.iter+1)*self.dl.bs]
        # mixup coefficients, following fastai's MixUp: keep λ ≥ 0.5
        lam = self.distrib.sample((self.y.size(0),)).squeeze().to(self.x.device)
        lam = torch.stack([lam, 1-lam], 1)
        self.lam = lam.max(1)[0]
        shuffle = torch.randperm(self.y.size(0)).to(self.x.device)
        # mixed soft targets : tb
        nt_dims = len(self.t.size())
        self.t1 = self.t[b_idxs][shuffle]
        self.tb = torch.lerp(self.t1, self.t[b_idxs], unsqueeze(self.lam, n=nt_dims-1))
        # shuffled targets
        self.yb1 = tuple(L(self.yb).itemgot(shuffle))
        # raw (unmixed) inputs for temporal ensembling
        self.x_raw = self.x.clone()
        # mixed inputs : xb
        nx_dims = len(self.x.size())
        xb1 = tuple(L(self.xb).itemgot(shuffle))
        self.learn.xb = tuple(L(xb1,self.xb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=nx_dims-1)))
    def lf(self, pred, *yb):
        if not self.training: return self.old_lf(pred, *yb)
        # mixed cross entropy loss
        ce1,ce = F.cross_entropy(pred, *self.yb1, reduction='none'), F.cross_entropy(pred, *self.yb, reduction='none')
        ce_loss = torch.lerp(ce1, ce, self.lam).mean()
        # ELR regularization against the mixed soft targets
        y_pred = F.softmax(pred, dim=-1)
        y_pred = torch.clamp(y_pred, 1e-4, 1.0-1e-4)
        elr_reg = (1 - (self.tb * y_pred).sum(dim=1)).log().mean()
        return ce_loss + self.reg_lambda * elr_reg
    def after_step(self):
        # get training batch indexes
        idxs = self.dl._DataLoader__idxs
        b_idxs = idxs[self.iter*self.dl.bs:(self.iter+1)*self.dl.bs]
        # evaluate the network on the unmixed inputs & update the temporal ensemble
        with torch.no_grad():
            y_pred = self.model.eval()(self.x_raw)
            y_pred = F.softmax(y_pred, dim=-1)
            y_pred = torch.clamp(y_pred, 1e-4, 1.0-1e-4)
            self.t[b_idxs] = self.reg_beta * self.t[b_idxs] + (1-self.reg_beta) * y_pred
        self.model.train()  # restore training mode after the eval pass
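

# Usage sketch (illustrative): ELR_plusA mixes the inputs itself, so it is used
# alone rather than together with fastai's MixUp callback; e.g. in `_demo_elr`
# above, swap the callback for:
#   cbs=ELR_plusA(reg_beta=0.9, reg_lambda=1, alpha=0.4)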


def sigmoid_rampup(pos):
    """Sigmoid-shaped exponential rampup from ≈0 at pos=0 to 1 at pos=1."""
    phase = 1.0 - pos
    return float(np.exp(-5.0 * phase * phase))
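# Worked values: sigmoid_rampup(0.0) = exp(-5) ≈ 0.0067, sigmoid_rampup(0.5) =
# exp(-1.25) ≈ 0.2865, sigmoid_rampup(1.0) = 1.0. In ELR_plusB below this scales
# the EMA decay, so early in training the target model closely tracks the
# student (which is why its weights can start from zero) and only later settles
# into a slow-moving average with decay ema_alpha.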


class ELR_plusB(Callback):
    '''
    https://arxiv.org/pdf/2007.00151.pdf - Algorithm 2: ELR + MixUp + weight averaging
    (a single network with an EMA target model instead of two networks)
    '''
    run_after,run_valid = [Normalize],False

    def __init__(self, reg_beta=0.9, reg_lambda=1,
                 sched_func=None, lambda_min=None, lambda_max=None,
                 num_classes=None,
                 mixup_alpha=1., ema_alpha=0.997):
        store_attr()
        self.distrib = Beta(tensor(mixup_alpha), tensor(mixup_alpha))

    def after_fit(self):
        self.learn.loss_func = self.old_lf

    def before_fit(self):
        # initialize empty targets and swap in the ELR loss
        nc = ifnone(self.num_classes, self.dls.c)
        self.t = torch.zeros(len(self.dls.train_ds), nc).to(self.dls.device)
        self.old_lf, self.learn.loss_func = self.learn.loss_func, self.lf
        # initialize the EMA target model; the rampup in update_ema_model lets
        # it start from zeroed weights
        self.ema_model = deepcopy(self.learn.model).eval()
        for p in self.ema_model.parameters(): p.data.zero_()
    def before_batch(self):
        # if using a λ scheduler
        if self.sched_func is not None: self.reg_lambda = self.sched_func(self.lambda_min, self.lambda_max, self.pct_train)
        # batch idxs
        idxs = self.dl._DataLoader__idxs
        b_idxs = idxs[self.iter*self.dl.bs:(self.iter+1)*self.dl.bs]
        # mixup coefficients, following fastai's MixUp: keep λ ≥ 0.5
        lam = self.distrib.sample((self.y.size(0),)).squeeze().to(self.x.device)
        lam = torch.stack([lam, 1-lam], 1)
        self.lam = lam.max(1)[0]
        shuffle = torch.randperm(self.y.size(0)).to(self.x.device)
        # mixed soft targets : tb
        nt_dims = len(self.t.size())
        self.t1 = self.t[b_idxs][shuffle]
        self.tb = torch.lerp(self.t1, self.t[b_idxs], unsqueeze(self.lam, n=nt_dims-1))
        # shuffled targets
        self.yb1 = tuple(L(self.yb).itemgot(shuffle))
        # raw (unmixed) inputs for temporal ensembling
        self.x_raw = self.x.clone()
        # mixed inputs : xb
        nx_dims = len(self.x.size())
        xb1 = tuple(L(self.xb).itemgot(shuffle))
        self.learn.xb = tuple(L(xb1,self.xb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=nx_dims-1)))
    def lf(self, pred, *yb):
        if not self.training: return self.old_lf(pred, *yb)
        # mixed cross entropy loss
        ce1,ce = F.cross_entropy(pred, *self.yb1, reduction='none'), F.cross_entropy(pred, *self.yb, reduction='none')
        ce_loss = torch.lerp(ce1, ce, self.lam).mean()
        # ELR regularization against the mixed soft targets
        y_pred = F.softmax(pred, dim=-1)
        y_pred = torch.clamp(y_pred, 1e-4, 1.0-1e-4)
        elr_reg = (1 - (self.tb * y_pred).sum(dim=1)).log().mean()
        return ce_loss + self.reg_lambda * elr_reg
    def after_step(self):
        # get training batch indexes
        idxs = self.dl._DataLoader__idxs
        b_idxs = idxs[self.iter*self.dl.bs:(self.iter+1)*self.dl.bs]
        # update target model
        self.update_ema_model()
        # evaluate the EMA target model on the unmixed inputs & update the temporal ensemble
        with torch.no_grad():
            y_pred = self.ema_model(self.x_raw)
            y_pred = F.softmax(y_pred, dim=-1)
            y_pred = torch.clamp(y_pred, 1e-4, 1.0-1e-4)
            self.t[b_idxs] = self.reg_beta * self.t[b_idxs] + (1-self.reg_beta) * y_pred
    def update_ema_model(self):
        # ramp the EMA decay from ≈0 up to ema_alpha over training
        alpha = sigmoid_rampup(self.pct_train) * self.ema_alpha
        for ema_param, param in zip(self.ema_model.parameters(), self.model.parameters()):
            ema_param.data.mul_(alpha).add_(param.data, alpha=1-alpha)
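

# ---------------------------------------------------------------------------
# Usage sketch (illustrative assumptions): ELR_plusB trains a single network
# plus its EMA copy, so it is the cheaper stand-in for the paper's two-network
# ELR+:
#
#   learn = cnn_learner(dls, resnet18, metrics=accuracy,
#                       cbs=ELR_plusB(reg_beta=0.7, reg_lambda=3,
#                                     mixup_alpha=1., ema_alpha=0.997))
#   learn.fit_one_cycle(10)
# ---------------------------------------------------------------------------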