PyTorch implementation of BatchRenorm2d (for Conv2d), by @wassname (last active April 5, 2018).
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from torch.nn.parameter import Parameter
from torch.autograd import Variable


def r_d_max_func(itr):
    "Default max r and d provider as recommended in paper."
    if itr < 5000:
        return 1, 0
    if itr < 40000:
        r_max = 2 / 35000 * (itr - 5000) + 1
    else:
        r_max = 3
    if itr < 25000:
        d_max = 5 / 20000 * (itr - 5000)
    else:
        d_max = 5
    return r_max, d_max
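

# Quick, non-essential sanity check of the schedule above (a sketch, not part of the module):
# r_max ramps from 1 to 3 over iterations 5k-40k and d_max ramps from 0 to 5 over 5k-25k.
#
#     for itr in (0, 5000, 22500, 40000):
#         print(itr, r_d_max_func(itr))
#     # roughly: (1, 0), (1.0, 0.0), (2.0, 4.375), (3, 5)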


class BatchRenorm2d(nn.Module):
    """
    BatchRenorm2d (for Conv2d).

    Args:
    - num_features: number of channels, e.g. 3 for a (3, 32, 32) image tensor.
    - r_d_max_func: a function that takes the iteration and returns the max r and d
      for that timestep. By default this uses the values recommended in the paper.
    - eps: epsilon parameter to avoid division by zero.
    - momentum: momentum parameter from batch norm.

    Usage:
        bn = BatchRenorm2d(3)
        x = Variable(torch.rand((10, 3, 4, 2)))
        bn(x).size()  # (10, 3, 4, 2)

    Url:
    - https://gist.github.com/wassname/bd6a9c570e35bf8e9e161aecfcaa8d59

    Refs:
    - modified from rarilurelo's code (thanks): https://github.com/rarilurelo/batch_renormalization
    - original paper: https://arxiv.org/pdf/1702.03275.pdf
    """

    def __init__(self,
                 num_features,
                 r_d_max_func=r_d_max_func,
                 eps=1e-5,
                 momentum=0.1):
        super(BatchRenorm2d, self).__init__()
        self.num_features = num_features
        self.r_d_max_func = r_d_max_func
        self.eps = eps
        self.momentum = momentum
        # Running statistics are stored as (C, 1, 1) so they broadcast over (N, C, H, W) inputs.
        self.register_buffer(
            'running_mean',
            torch.zeros(num_features).unsqueeze(-1).unsqueeze(-1))
        self.register_buffer(
            'running_var',
            torch.ones(num_features).unsqueeze(-1).unsqueeze(-1))
        self.reset_parameters()
        self.steps = 0

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)

    def _check_input_dim(self, input):
        if input.size(1) != self.running_mean.nelement():
            raise ValueError('got {}-feature tensor, expected {}'
                             .format(input.size(1), self.num_features))

    def forward(self, input, itr=None):
        self._check_input_dim(input)
        self.steps += 1
        itr = itr if itr is not None else self.steps
        # Flatten to (batch * height * width, channels) so statistics are per channel.
        input_flat = input.transpose(-1, 1).contiguous().view((-1, input.size(1)))
        # Per-channel batch statistics, shaped (C, 1, 1) so they expand over the input.
        mean_b_c = input_flat.mean(0).unsqueeze(-1).unsqueeze(-1)
        var_b_c = input_flat.var(0).unsqueeze(-1).unsqueeze(-1)
        mean_b = mean_b_c.expand_as(input)
        std_b = (var_b_c + self.eps).sqrt().expand_as(input)
        # Normalise with the batch statistics, as in plain batch norm.
        bn = (input - mean_b) / std_b
        # Calculate correction factors r and d from the running statistics.
        r_max, d_max = self.r_d_max_func(itr)
        mean = Variable(self.running_mean.unsqueeze(0)).expand_as(input)
        std = (Variable(self.running_var.unsqueeze(0)) + self.eps).sqrt().expand_as(input)
        r = std_b / std
        d = (mean_b - mean) / std
        # Clamp, then detach to stop gradients flowing through r and d.
        r = r.clamp(1 / r_max, r_max).detach()
        d = d.clamp(-d_max, d_max).detach()
        # Update the running statistics, keeping the (C, 1, 1) buffer shape.
        self.running_mean += self.momentum * (mean_b_c.data - self.running_mean)
        self.running_var += self.momentum * (var_b_c.data - self.running_var)
        return bn * r + d

    def __repr__(self):
        return ('{name}({num_features}, eps={eps}, momentum={momentum})'
                .format(name=self.__class__.__name__, **self.__dict__))
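

# Minimal usage sketch (mirrors the docstring; driving the schedule with an explicit `itr`
# from a training loop is an assumption about intended use, not part of the original gist):
if __name__ == '__main__':
    bn = BatchRenorm2d(3)
    x = Variable(torch.rand(10, 3, 4, 2))
    y = bn(x)                   # schedule driven by the module's internal step counter
    y2 = bn(x, itr=50000)       # or pass the training iteration explicitly
    print(y.size(), y2.size())  # both are (10, 3, 4, 2)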