Skip to content

Instantly share code, notes, and snippets.

@wdhorton
Created April 1, 2018 17:58
Show Gist options
  • Save wdhorton/d4ef94540bba9003340ed14843bd2803 to your computer and use it in GitHub Desktop.
Save wdhorton/d4ef94540bba9003340ed14843bd2803 to your computer and use it in GitHub Desktop.
SWA callback
class SWA(Callback):
    """Keep a running average of a model's weights over training epochs.

    From the end of epoch number ``swa_start`` (counting from 1) onwards, the
    parameters of ``model`` are folded into ``swa_model`` after every epoch,
    so that ``swa_model`` holds the equal-weight mean of all snapshots taken
    so far (stochastic weight averaging).

    Only the tensors yielded by ``parameters()`` are averaged.
    NOTE(review): assuming the models are ``torch.nn.Module`` instances,
    buffers such as BatchNorm running statistics are not touched here and
    presumably need to be recomputed for ``swa_model`` separately -- confirm
    against the training loop that uses this callback.
    """

    def __init__(self, model, swa_model, swa_start):
        """Store the models and the epoch at which averaging starts.

        Args:
            model: The model being trained; its parameters are read.
            swa_model: The model whose parameters are overwritten in place
                with the running average.
            swa_start: 1-based epoch number whose end triggers the first
                averaging step.
        """
        super().__init__()
        self.model, self.swa_model, self.swa_start = model, swa_model, swa_start
        # Initialise the counters here as well, so update_average_model()
        # can be called before on_train_begin() has run.
        self.epoch = 0
        self.swa_n = 0

    def on_train_begin(self):
        """Reset the epoch counter and the number of averaged snapshots."""
        self.epoch = 0
        self.swa_n = 0

    def on_epoch_end(self, metrics):
        """Fold the current weights into the average once averaging is due.

        Args:
            metrics: Passed in by the caller; not used by this callback.
        """
        # self.epoch is 0-based, swa_start is compared against the 1-based
        # number of the epoch that has just finished.
        if (self.epoch + 1) >= self.swa_start:
            self.update_average_model()
            # swa_n counts snapshots already contained in the average.
            self.swa_n += 1
        self.epoch += 1

    def update_average_model(self):
        """Update ``swa_model`` in place with the current ``model`` weights.

        Each averaged parameter becomes
        ``(swa * swa_n + model) / (swa_n + 1)``. Parameters are paired
        positionally; ``zip`` stops at the shorter of the two parameter
        sequences.
        """
        model_params = self.model.parameters()
        swa_params = self.swa_model.parameters()
        for model_param, swa_param in zip(model_params, swa_params):
            if self.swa_n == 0:
                # First snapshot: copy instead of computing swa * 0 + model,
                # because NaN/inf already present in swa_model would survive
                # the multiplication by zero and corrupt the average forever.
                swa_param.data.copy_(model_param.data)
                continue
            swa_param.data *= self.swa_n
            swa_param.data += model_param.data
            swa_param.data /= (self.swa_n + 1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment