@crowsonkb
Created February 26, 2021 19:09
Biased EMA for PyTorch
"""Exponential moving average for PyTorch. Adapted from
https://www.zijianhu.com/post/pytorch/ema/.
"""
from copy import deepcopy
import torch
from torch import nn
class EMA(nn.Module):
def __init__(self, model, decay):
super().__init__()
self.model = model
self.decay = decay
self.average = deepcopy(self.model)
for param in self.average.parameters():
param.detach_()
@torch.no_grad()
def update(self):
if not self.training:
raise RuntimeError('Update should only be called during training')
model_params = dict(self.model.named_parameters())
average_params = dict(self.average.named_parameters())
assert model_params.keys() == average_params.keys()
for name, param in model_params.items():
average_params[name].mul_(self.decay)
average_params[name].add_((1 - self.decay) * param)
model_buffers = dict(self.model.named_buffers())
average_buffers = dict(self.average.named_buffers())
assert model_buffers.keys() == average_buffers.keys()
for name, buffer in model_buffers.items():
average_buffers[name].copy_(buffer)
def forward(self, *args, **kwargs):
if self.training:
return self.model(*args, **kwargs)
return self.average(*args, **kwargs)
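
A minimal usage sketch follows; the linear model, optimizer settings, decay value, and synthetic data are illustrative assumptions, not part of the gist.

# Usage sketch (illustrative; model, optimizer, and decay are assumptions).
model = nn.Linear(10, 1)
ema = EMA(model, decay=0.999)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

ema.train()
for _ in range(100):
    x = torch.randn(32, 10)
    loss = (ema(x) - 1).pow(2).mean()  # training mode: forwards through the live model
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema.update()  # fold the freshly updated parameters into the running average

ema.eval()
with torch.no_grad():
    preds = ema(torch.randn(32, 10))  # eval mode: forwards through the averaged copy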