PyTorch implementation of BatchRenorm2d (for Conv2d)
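In brief (following the paper linked below): batch renormalization normalizes with the minibatch statistics as usual, y = r * (x - mu_B) / sigma_B + d, but adds two per-channel correction terms, r = clip(sigma_B / sigma, [1/r_max, r_max]) and d = clip((mu_B - mu) / sigma, [-d_max, d_max]), where mu and sigma are the moving averages. r and d are treated as constants during backpropagation, and r_max, d_max are ramped up from 1 and 0 over training, so early training behaves like plain batch norm.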
import torch
import torch.nn as nn
from torch.autograd import Variable
def r_d_max_func(itr):
    """Default max r and d schedule, as recommended in the paper."""
    if itr < 5000:
        return 1, 0
    if itr < 40000:
        r_max = 2 / 35000 * (itr - 5000) + 1
    else:
        r_max = 3
    if itr < 25000:
        d_max = 5 / 20000 * (itr - 5000)
    else:
        d_max = 5
    return r_max, d_max
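# A sketch of how the default schedule behaves (values computed from the
# function above): the clipping starts tight, so the layer acts like plain
# batch norm during warm-up, then gradually relaxes.
#   r_d_max_func(1000)   -> (1, 0)
#   r_d_max_func(22500)  -> (2.0, 4.375)
#   r_d_max_func(50000)  -> (3, 5)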
class BatchRenorm2d(nn.Module):
    """
    BatchRenorm2d (for Conv2d)

    Args:
    - num_features: the number of channels C of an expected (N, C, H, W)
      input, e.g. 3 for a batch of 32x32 RGB images.
    - r_d_max_func: a function that takes the iteration and returns the max
      r and d for that timestep. By default this uses the recommended values.
    - eps: epsilon parameter to avoid divide by zero.
    - momentum: momentum parameter from batch norm.

    Usage:
        bn = BatchRenorm2d(3)
        x = Variable(torch.rand((10, 3, 4, 2)))
        bn(x).size()  # (10, 3, 4, 2)

    Url:
    - https://gist.github.com/wassname/bd6a9c570e35bf8e9e161aecfcaa8d59

    Refs:
    - modified from rarilurelo's code (thanks): https://github.com/rarilurelo/batch_renormalization
    - original paper: https://arxiv.org/pdf/1702.03275.pdf
    """
    def __init__(self,
                 num_features,
                 r_d_max_func=r_d_max_func,
                 eps=1e-5,
                 momentum=0.1):
        super(BatchRenorm2d, self).__init__()
        self.num_features = num_features
        self.r_d_func = r_d_max_func
        self.eps = eps
        self.momentum = momentum
        # Running statistics are stored as (channels, 1, 1) so they
        # broadcast over (N, C, H, W) inputs
        self.register_buffer(
            'running_mean',
            torch.zeros(num_features).unsqueeze(-1).unsqueeze(-1))
        self.register_buffer(
            'running_var',
            torch.ones(num_features).unsqueeze(-1).unsqueeze(-1))
        self.reset_parameters()
        self.steps = 0
    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)

    def _check_input_dim(self, input):
        if input.size(1) != self.running_mean.nelement():
            raise ValueError('got {}-feature tensor, expected {}'
                             .format(input.size(1), self.num_features))
    def forward(self, input, itr=None):
        self._check_input_dim(input)
        self.steps += 1
        itr = itr if itr is not None else self.steps
        # Flatten to (batch * height * width, channels) so the statistics
        # are computed per channel
        input_flat = input.transpose(-1, 1).contiguous().view(-1, input.size(1))
        # Per-channel batch statistics, shaped (channels, 1, 1) so they
        # broadcast over the (N, C, H, W) input
        mean_b = input_flat.mean(0).unsqueeze(-1).unsqueeze(-1)
        var_b = input_flat.var(0).unsqueeze(-1).unsqueeze(-1)
        std_b = (var_b + self.eps).sqrt()
        # Normalise by the batch standard deviation (not the variance)
        bn = (input - mean_b.expand_as(input)) / std_b.expand_as(input)
        # Calculate the correction factors r and d
        r_max, d_max = self.r_d_func(itr)
        mean = Variable(self.running_mean)
        std = Variable((self.running_var + self.eps).sqrt())
        r = std_b / std
        d = (mean_b - mean) / std
        # Clamp, then detach to stop gradients flowing through r and d
        r = r.clamp(1 / r_max, r_max).detach()
        d = d.clamp(-d_max, d_max).detach()
        # Update the moving statistics in place (exponential moving average)
        self.running_mean += self.momentum * (mean_b.data - self.running_mean)
        self.running_var += self.momentum * (var_b.data - self.running_var)
        return bn * r.expand_as(input) + d.expand_as(input)
    def __repr__(self):
        return ('{name}({num_features}, eps={eps}, momentum={momentum})'
                .format(name=self.__class__.__name__, **self.__dict__))
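A minimal usage sketch (not part of the original gist; SmallNet and global_step are illustrative names): the layer keeps its own step counter, but passing itr explicitly keeps the r_max/d_max schedule aligned with the true training step, e.g. after resuming from a checkpoint.

import torch
import torch.nn as nn
from torch.autograd import Variable

class SmallNet(nn.Module):
    """Hypothetical two-layer net used only to demonstrate BatchRenorm2d."""
    def __init__(self):
        super(SmallNet, self).__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.brn = BatchRenorm2d(8)  # one running stat per conv output channel

    def forward(self, x, itr=None):
        return self.brn(self.conv(x), itr=itr)

net = SmallNet()
x = Variable(torch.rand(10, 3, 32, 32))
for global_step in range(1, 4):  # global_step would come from the training loop
    out = net(x, itr=global_step)
print(out.size())  # torch.Size([10, 8, 32, 32])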