Skip to content

Instantly share code, notes, and snippets.

@KeremTurgutlu
Last active July 26, 2022 03:10
Show Gist options
  • Save KeremTurgutlu/4ec36c40843cbbef0710e1d6e1c83151 to your computer and use it in GitHub Desktop.
EMA and SWA callbacks for different model averaging techniques
from fastai.vision.all import *
__all__ = ["EMA", "SWA"]
class EMA(Callback):
    "Exponential moving average of model weights, see https://fastai.github.io/timmdocs/training_modelEMA"
    order,run_valid=5,False

    def __init__(self, decay=0.9999):
        """
        decay: weight on the running average; closer to 1.0 -> slower-moving EMA.
        """
        super().__init__()
        self.decay = decay
        self.switched = False

    def before_fit(self):
        "Lazily create the frozen EMA copy once, so repeated fits keep accumulating."
        if not hasattr(self, "ema_model"):
            print("Init EMA model")
            self.ema_model = deepcopy(self.learn.model)
            for param_k in self.ema_model.parameters():
                param_k.requires_grad = False

    @torch.no_grad()
    def _update(self):
        "EMA-average learnable parameters and sync non-learnable buffers."
        for param_k, param_q in zip(self.ema_model.parameters(), self.learn.model.parameters()):
            param_k.data = param_k.data * self.decay + param_q.data * (1. - self.decay)
        # Fix: buffers (e.g. BatchNorm running mean/var) were never updated after the
        # initial deepcopy, so the EMA model evaluated with stale statistics. Copy them
        # verbatim from the live model (they are running estimates, not averaged weights).
        for buf_k, buf_q in zip(self.ema_model.buffers(), self.learn.model.buffers()):
            buf_k.copy_(buf_q)

    def after_step(self):
        "Momentum update target model"
        self._update()

    def switch_model(self):
        "Toggle `learn.model` between the live model and the EMA copy (e.g. for validation)."
        if self.switched:
            self.learn.model = self.original_model
            self.switched = False
            print("Switched to original model")
        else:
            self.original_model = self.learn.model
            self.learn.model = self.ema_model
            self.switched = True
            print("Switched to EMA model")
class SWA(Callback):
    "https://arxiv.org/pdf/1803.05407.pdf (Use with fit_sgdr_*)"
    order,run_valid=5,False

    def __init__(self, pcts:List[float], swa_start=0):
        """
        pcts: pcts of end of each cycle in terms of pct_train
        swa_start: at which cycle end (index into `pcts`) to start averaging
        """
        super().__init__()
        self.swa_start = swa_start
        self.pcts = pcts
        self.switched = False
        self.swa_n = 0          # number of checkpoints averaged so far
        self.curr_pct_idx = 0   # index of the next cycle-end pct to trigger on
        self.curr_pct = self.pcts[self.curr_pct_idx]

    def before_fit(self):
        "Lazily create the frozen SWA copy once, so repeated fits keep accumulating."
        print("training mode:",self.learn.training)
        if not hasattr(self, "swa_model"):
            print("Init SWA model")
            self.swa_model = deepcopy(self.learn.model)
            for param_k in self.swa_model.parameters():
                param_k.requires_grad = False

    def after_step(self):
        "Update SWA model at given pcts"
        if self.curr_pct_idx < len(self.pcts) and self.pct_train >= self.curr_pct:
            # Fix: `swa_start` was stored but never used — honor it by skipping the
            # first `swa_start` cycle ends (matching the documented contract).
            if self.curr_pct_idx >= self.swa_start:
                print(f"Updating swa model at pct_train: {self.pct_train}")
                self.update_average_model()
            self.curr_pct_idx += 1
            # Fix: guard the lookup — previously this raised IndexError once the
            # last entry of `pcts` was consumed during training.
            if self.curr_pct_idx < len(self.pcts):
                self.curr_pct = self.pcts[self.curr_pct_idx]

    def after_fit(self):
        "Average final checkpoint if it wasn't"
        if (self.curr_pct_idx == len(self.pcts)-1) and (np.round(self.pct_train, 2) >= self.curr_pct):
            print(f"Updating final swa model at pct_train: {self.pct_train}")
            self.update_average_model()
            self.curr_pct_idx += 1

    @torch.no_grad()
    def update_average_model(self):
        "Fold the current weights into the running mean held by `swa_model`."
        # Incremental mean: new_avg = (avg*n + x) / (n+1); consistent with EMA's
        # no-grad update style.
        for model_param, swa_param in zip(self.model.parameters(), self.swa_model.parameters()):
            swa_param.data = (swa_param.data*self.swa_n + model_param.data) / (self.swa_n + 1)
        self.swa_n += 1

    def switch_model(self):
        "Toggle `learn.model` between the live model and the SWA copy (e.g. for validation)."
        if self.switched:
            self.learn.model = self.original_model
            self.switched = False
            print("Switched to original model")
        else:
            self.original_model = self.learn.model
            self.learn.model = self.swa_model
            self.switched = True
            print("Switched to SWA model")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment