Skip to content

Instantly share code, notes, and snippets.

@wdhorton
Created April 22, 2018 17:48
Show Gist options
  • Save wdhorton/a35ee80e695e1422e86c3e493cb503c2 to your computer and use it in GitHub Desktop.
Save wdhorton/a35ee80e695e1422e86c3e493cb503c2 to your computer and use it in GitHub Desktop.
# Per-epoch snapshots of the model parameters. Each entry is a list of NumPy
# arrays, one array per parameter tensor, appended by SaveModelParams.
params = []


# callback for storing the params of the model after each epoch
class SaveModelParams(Callback):
    """Record a snapshot of every model parameter at the end of each epoch.

    Snapshots are appended to the module-level ``params`` list.

    NOTE(review): ``Callback`` is defined outside this snippet (presumably the
    training library's callback base class) -- confirm against the imports.
    """

    def __init__(self, model):
        """Remember the model whose parameters will be snapshotted.

        Args:
            model: Object exposing ``parameters()``, which yields tensors
                that have a ``.data`` attribute.
        """
        self.model = model

    def on_epoch_end(self, metrics):
        """Append a copy of the current parameters to ``params``.

        Args:
            metrics: Unused; accepted to match the callback hook signature.
        """
        # ``.cpu()`` returns the very same tensor when it already lives on the
        # CPU, and ``.numpy()`` shares memory with that tensor. Without an
        # explicit copy, every stored "snapshot" would alias the live weights
        # and silently change as training continues.
        params.append(
            [p.data.cpu().numpy().copy() for p in self.model.parameters()]
        )
# basic setup and training of the model
# Single source of truth for the epoch count, so the call to fit() and the
# comparison loop below cannot drift apart.
n_epochs = 3
net2 = SimpleNet([32*32*3, 40, 10])
learn2 = ConvLearner.from_model_data(net2, data)
lr = 2e-2
learn2.fit(lr, n_epochs, use_swa=True, callbacks=[SaveModelParams(learn2.model)])

# grab the params from the SWA model
swa_model_params = [p.data.cpu().numpy() for p in learn2.swa_model.parameters()]

# ``params`` holds one list of arrays per epoch. zip(*params, swa_model_params)
# regroups them per parameter tensor: all per-epoch snapshots of that tensor
# followed by its SWA counterpart. The starred target collects the snapshots
# whatever the number of recorded epochs, instead of hard-coding three names.
for *epoch_snapshots, p_swa_model in zip(*params, swa_model_params):
    # check for equality up to a certain tolerance: the SWA weights are
    # expected to match the element-wise mean of the per-epoch weights
    epoch_mean = np.mean(np.stack(epoch_snapshots), axis=0)
    print(np.isclose(p_swa_model, epoch_mean))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment