Created
May 1, 2018 23:59
-
-
Save domarps/1c011322c511a5b994436ce9cf460b06 to your computer and use it in GitHub Desktop.
batchnorm fwd
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def batchnorm_forward(x, gamma, beta, bn_param):
    """
    Forward pass for batch normalization.

    During training the sample mean and (uncorrected) sample variance are
    computed from minibatch statistics and used to normalize the incoming data.
    During training we also keep an exponentially decaying running mean of the
    mean and variance of each feature, and these averages are used to normalize
    data at test-time.

    At each timestep we update the running averages for mean and variance using
    an exponential decay based on the momentum parameter:

    running_mean = momentum * running_mean + (1 - momentum) * sample_mean
    running_var = momentum * running_var + (1 - momentum) * sample_var

    Note that the batch normalization paper suggests a different test-time
    behavior: they compute sample mean and variance for each feature using a
    large number of training images rather than using a running average. For
    this implementation we have chosen to use running averages instead since
    they do not require an additional estimation step; the torch7
    implementation of batch normalization also uses running averages.

    Input:
    - x: Data of shape (N, D)
    - gamma: Scale parameter of shape (D,)
    - beta: Shift parameter of shape (D,)
    - bn_param: Dictionary with the following keys:
      - mode: 'train' or 'test'; required
      - eps: Constant for numeric stability (default 1e-5)
      - momentum: Constant for running mean / variance (default 0.9)
      - running_mean: Array of shape (D,) giving running mean of features
        (defaults to zeros)
      - running_var: Array of shape (D,) giving running variance of features
        (defaults to zeros)
      The (possibly updated) running_mean and running_var are written back
      into bn_param on every successful call.

    Returns a tuple of:
    - out: of shape (N, D)
    - cache: In 'train' mode, a dict consumed by the backward passes with keys
      'scaled_x' (the centered input x - sample_mean), 'normalized_x',
      'gamma', 'ivar' (1 / sqrt(sample_var + eps)) and 'sqrtvar'
      (sqrt(sample_var + eps)). None in 'test' mode.

    Raises:
    - ValueError: if bn_param['mode'] is neither 'train' nor 'test'.
    """
    mode = bn_param['mode']
    eps = bn_param.get('eps', 1e-5)
    momentum = bn_param.get('momentum', 0.9)

    N, D = x.shape
    running_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))
    running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))

    out, cache = None, None
    if mode == 'train':
        # Per-feature minibatch statistics; np.var is the uncorrected
        # (ddof=0) variance used by the batch-norm paper.
        sample_mean = np.mean(x, axis=0)
        sample_var = np.var(x, axis=0)

        # Compute the centered data and the standard deviation once and reuse
        # them for both the output and the cache.
        centered = x - sample_mean
        std = np.sqrt(sample_var + eps)
        normalized_x = centered / std
        out = gamma * normalized_x + beta

        # Exponential moving averages, used instead of batch statistics at
        # test time.
        running_mean = momentum * running_mean + (1 - momentum) * sample_mean
        running_var = momentum * running_var + (1 - momentum) * sample_var

        # NOTE: 'scaled_x' actually holds the *centered* input (x - mean);
        # the key name is kept because the backward passes look it up.
        cache = {
            'scaled_x': centered,
            'normalized_x': normalized_x,
            'gamma': gamma,
            'ivar': 1. / std,
            'sqrtvar': std,
        }
    elif mode == 'test':
        # Normalize the whole batch with the running statistics accumulated
        # during training, then scale and shift.
        out = gamma * (x - running_mean) / np.sqrt(running_var + eps) + beta
    else:
        raise ValueError('Invalid forward batchnorm mode "%s"' % mode)

    # Store the updated running statistics back into bn_param.
    bn_param['running_mean'] = running_mean
    bn_param['running_var'] = running_var

    return out, cache
def batchnorm_backward(dout, cache):
    """
    Backward pass for batch normalization, staged node by node.

    The training-time forward computation is treated as this graph (per
    feature column, reductions over the batch axis), and the upstream
    gradient is pushed back through each node in reverse order:

        centered = x - mean(x)              -> cache['scaled_x']
        var      = mean(centered ** 2)
        std      = sqrt(var + eps)          -> cache['sqrtvar']
        inv_std  = 1 / std                  -> cache['ivar']
        xhat     = centered * inv_std       -> cache['normalized_x']
        out      = gamma * xhat + beta

    Inputs:
    - dout: Upstream derivatives, of shape (N, D)
    - cache: Dictionary of intermediates from batchnorm_forward ('train'
      mode); the keys 'normalized_x', 'gamma', 'ivar', 'scaled_x' and
      'sqrtvar' are read.

    Returns a tuple of:
    - dx: Gradient with respect to inputs x, of shape (N, D)
    - dgamma: Gradient with respect to scale parameter gamma, of shape (D,)
    - dbeta: Gradient with respect to shift parameter beta, of shape (D,)
    """
    num_rows, _ = dout.shape

    xhat = cache['normalized_x']
    gamma = cache['gamma']
    inv_std = cache['ivar']
    centered = cache['scaled_x']
    std = cache['sqrtvar']

    # out = gamma * xhat + beta: the parameters receive column-wise sums.
    dbeta = np.sum(dout, axis=0)
    dgamma = np.sum(dout * xhat, axis=0)

    # Gradient arriving at xhat.
    dxhat = dout * gamma

    # xhat = centered * inv_std: one branch per factor.
    dcentered_direct = dxhat * inv_std
    dinv_std = np.sum(dxhat * centered, axis=0)

    # inv_std = 1 / std, then std = sqrt(var + eps).
    dstd = -1 * dinv_std * (1 / std ** 2)
    dvariance = 0.5 * dstd * (1 / std)

    # var = mean(centered ** 2): spread evenly over the batch, then chain
    # through the square.
    dsquares = 1 / num_rows * dvariance * np.ones_like(dout)
    dcentered_via_var = dsquares * 2 * centered

    # centered = x - mu with mu = mean(x): x gets the combined centered
    # gradient plus an equal share of the gradient flowing into mu.
    dcentered = dcentered_via_var + dcentered_direct
    dmu = -1 * np.sum(dcentered, axis=0)
    dx = dcentered + 1 / num_rows * dmu * np.ones_like(dout)

    return dx, dgamma, dbeta
def batchnorm_backward_alt(dout, cache):
    """
    Alternative backward pass for batch normalization.

    Uses the closed-form simplification of the batch-norm gradient instead of
    walking the computation graph node by node:

        dx = (gamma / (N * std)) * (N * dout - sum(dout) - xhat * sum(dout * xhat))

    where both sums run over the batch axis, xhat is the normalized input and
    std = sqrt(var + eps). Those two sums are exactly dbeta and dgamma, so
    they are reused.

    Note: This implementation expects to receive the same cache variable as
    batchnorm_backward, but does not use all of the values in the cache.

    Inputs:
    - dout: Upstream derivatives, of shape (N, D)
    - cache: Dictionary of intermediates from batchnorm_forward ('train'
      mode); the keys 'normalized_x', 'gamma' and 'ivar' are read.

    Returns a tuple of:
    - dx: Gradient with respect to inputs x, of shape (N, D)
    - dgamma: Gradient with respect to scale parameter gamma, of shape (D,)
    - dbeta: Gradient with respect to shift parameter beta, of shape (D,)
    """
    N, _ = dout.shape

    # Index the cache directly (as batchnorm_backward does) so a missing key
    # raises KeyError here instead of surfacing later as an error on None.
    normalized_x = cache['normalized_x']
    gamma = cache['gamma']
    ivar = cache['ivar']

    dbeta = np.sum(dout, axis=0)
    dgamma = np.sum(dout * normalized_x, axis=0)

    # ivar == 1 / std, and dbeta / dgamma are the two batch-axis sums of the
    # closed form above.
    dx = (gamma * ivar / N) * (N * dout - dbeta - normalized_x * dgamma)

    return dx, dgamma, dbeta
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment