@crowsonkb · Last active February 28, 2021
Parameter averaging for PyTorch: an exponential moving average (EMA) wrapper and a polynomial-decay averaging (PDA) wrapper, each a drop-in nn.Module around an existing model.
"""Exponential moving average for PyTorch. Adapted from
https://www.zijianhu.com/post/pytorch/ema/.
"""
from copy import deepcopy
import torch
from torch import nn
class EMA(nn.Module):
def __init__(self, model, decay):
super().__init__()
self.model = model
self.decay = decay
self.register_buffer('accum', torch.tensor(1.))
self._biased = deepcopy(self.model)
self.average = deepcopy(self.model)
for param in self._biased.parameters():
param.detach_().zero_()
for param in self.average.parameters():
param.detach_().zero_()
self.update()
@torch.no_grad()
def update(self):
if not self.training:
raise RuntimeError('Update should only be called during training')
self.accum *= self.decay
model_params = dict(self.model.named_parameters())
biased_params = dict(self._biased.named_parameters())
average_params = dict(self.average.named_parameters())
assert model_params.keys() == biased_params.keys() == average_params.keys()
for name, param in model_params.items():
biased_params[name].mul_(self.decay)
biased_params[name].add_((1 - self.decay) * param)
average_params[name].copy_(biased_params[name])
average_params[name].div_(1 - self.accum)
model_buffers = dict(self.model.named_buffers())
biased_buffers = dict(self._biased.named_buffers())
average_buffers = dict(self.average.named_buffers())
assert model_buffers.keys() == biased_buffers.keys() == average_buffers.keys()
for name, buffer in model_buffers.items():
biased_buffers[name].copy_(buffer)
average_buffers[name].copy_(buffer)
def forward(self, *args, **kwargs):
if self.training:
return self.model(*args, **kwargs)
return self.average(*args, **kwargs)
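
A minimal training-loop sketch using the EMA class above (not part of the gist; the linear model, optimizer settings, and random batches are placeholders): call update() after each optimizer step, then switch to eval mode to run the averaged weights.

model = nn.Linear(10, 2)
ema = EMA(model, decay=0.999)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

for step in range(100):
    x = torch.randn(8, 10)                # placeholder batch
    y = torch.randint(0, 2, (8,))
    opt.zero_grad()
    loss = nn.functional.cross_entropy(ema(x), y)  # training mode: live model
    loss.backward()
    opt.step()
    ema.update()  # fold the new parameters into the average

ema.eval()
with torch.no_grad():
    preds = ema(torch.randn(8, 10))       # eval mode: averaged copy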
"""Polynomial-decay averaging for PyTorch. See https://arxiv.org/abs/1212.1824."""
from copy import deepcopy
import torch
from torch import nn
class PDA(nn.Module):
def __init__(self, model, eta=0):
super().__init__()
self.model = model
self.eta = eta
self.register_buffer('t', torch.tensor(0))
self.average = deepcopy(self.model)
for param in self.average.parameters():
param.detach_().zero_()
self.update()
@torch.no_grad()
def update(self):
if not self.training:
raise RuntimeError('Update should only be called during training')
self.t += 1
model_params = dict(self.model.named_parameters())
average_params = dict(self.average.named_parameters())
assert model_params.keys() == average_params.keys()
weight = (1 + self.eta) / (self.t + self.eta)
for name, param in model_params.items():
average_params[name].mul_(1 - weight)
average_params[name].add_(weight * param)
model_buffers = dict(self.model.named_buffers())
average_buffers = dict(self.average.named_buffers())
assert model_buffers.keys() == average_buffers.keys()
for name, buffer in model_buffers.items():
average_buffers[name].copy_(buffer)
def forward(self, *args, **kwargs):
if self.training:
return self.model(*args, **kwargs)
return self.average(*args, **kwargs)
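
PDA is used the same way; a sketch assuming the class above, again with placeholder model and batches. With eta=0 the result after N steps is the uniform average of all N iterates.

model = nn.Linear(10, 2)
pda = PDA(model, eta=0)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

for step in range(100):
    x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))  # placeholder batch
    opt.zero_grad()
    loss = nn.functional.cross_entropy(pda(x), y)
    loss.backward()
    opt.step()
    pda.update()

pda.eval()  # forward passes now use the averaged parameters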