@jojonki
Last active March 13, 2020 05:18
Apply exponential moving average decay for variables in PyTorch
# How to apply exponential moving average decay for variables?
# https://discuss.pytorch.org/t/how-to-apply-exponential-moving-average-decay-for-variables/10856/2
import torch
import torch.nn as nn

class EMA(nn.Module):
    def __init__(self, mu):
        super(EMA, self).__init__()
        self.mu = mu

    def forward(self, x, last_average):
        # blend the new value x with the running average
        new_average = self.mu * x + (1 - self.mu) * last_average
        return new_average

ema = EMA(0.999)
x = torch.rand(5, requires_grad=True)
average = torch.zeros(5, requires_grad=True)
average = ema(x, average)

# A variant that keeps a named shadow copy of every registered parameter
# and updates it after each optimizer step.
class EMA(nn.Module):
    def __init__(self, mu):
        super(EMA, self).__init__()
        self.mu = mu
        self.shadow = {}

    def register(self, name, val):
        self.shadow[name] = val.clone()

    def forward(self, name, x):
        assert name in self.shadow
        new_average = self.mu * x + (1.0 - self.mu) * self.shadow[name]
        self.shadow[name] = new_average.clone()
        return new_average

ema = EMA(0.999)
for name, param in model.named_parameters():
    if param.requires_grad:
        ema.register(name, param.data)

# in the batch training loop
# for batch in batches:
optimizer.step()
for name, param in model.named_parameters():
    if param.requires_grad:
        param.data = ema(name, param.data)
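
For comparison, a minimal sketch of the same idea using the weight-averaging helper that ships with more recent PyTorch (torch.optim.swa_utils.AveragedModel, available since 1.6); the decay value and the ema_avg helper name are illustrative choices, not part of the original gist, and model is assumed to be defined elsewhere as above:

import torch
from torch.optim.swa_utils import AveragedModel

decay = 0.999

def ema_avg(averaged_param, current_param, num_averaged):
    # conventional EMA: keep most of the running average, mix in a little of the new value
    return decay * averaged_param + (1 - decay) * current_param

ema_model = AveragedModel(model, avg_fn=ema_avg)

# in the training loop, after optimizer.step():
ema_model.update_parameters(model)
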
@longxianlei

It's wrong here. In most cases, mu is used to weight the shadow (the estimated x_hat), and (1.0 - mu) is used to weight the observation input_x.

@glenn-jocher

I agree, I think this implementation is backwards.
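
For reference, a minimal sketch of the conventional formulation the commenters describe, where mu weights the existing shadow (the running estimate) and (1.0 - mu) weights the new observation; the names mirror the gist above:

import torch.nn as nn

class EMA(nn.Module):
    def __init__(self, mu):
        super(EMA, self).__init__()
        self.mu = mu
        self.shadow = {}

    def register(self, name, val):
        self.shadow[name] = val.clone()

    def forward(self, name, x):
        assert name in self.shadow
        # mu weights the shadow (running estimate); (1 - mu) weights the observation x
        new_average = self.mu * self.shadow[name] + (1.0 - self.mu) * x
        self.shadow[name] = new_average.clone()
        return new_average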
