@skaae
Last active May 11, 2016 13:02
import numpy as np
import theano
import theano.tensor as T
from theano import ifelse
from .. import init
from .. import nonlinearities
from .base import Layer
__all__ = [
    "BatchNormalizationLayer",
]
class BatchNormalizationLayer(Layer):
    """
    Batch normalization layer [1].

    The user is required to set up updates for the learned parameters (gamma
    and beta). The values necessary for creating the updates can be obtained
    by passing a dict as the moving_avg_hooks keyword to get_output().

    REF:
     [1] http://arxiv.org/abs/1502.03167

    :parameters:
        - incoming : `Layer` instance
            The layer from which this layer will obtain its input.
        - gamma : initializer for the scale parameters
            (default: init.Uniform([0.95, 1.05])).
        - beta : initializer for the shift parameters
            (default: init.Constant(0.)).
        - nonlinearity : callable or None (default: lasagne.nonlinearities.rectify)
            The nonlinearity that is applied to the layer activations. If None
            is provided, the layer will be linear.
        - epsilon : scalar float
            Stabilizes training. Setting this too close to zero will result
            in NaNs.

    :usage:
        >>> import itertools
        >>> from lasagne.layers import InputLayer, BatchNormalizationLayer, DenseLayer
        >>> from lasagne.nonlinearities import linear, rectify
        >>> l_in = InputLayer((100, 20))
        >>> l_dense = DenseLayer(l_in, 50, nonlinearity=linear)
        >>> l_bn = BatchNormalizationLayer(l_dense, nonlinearity=rectify)
        >>> hooks, input, updates = {}, T.matrix(), []
        >>> l_out = l_bn.get_output(
        ...     input, deterministic=False, moving_avg_hooks=hooks)
        >>> mulfac = 1.0/100.0
        >>> batchnormparams = list(itertools.chain(
        ...     *[i[1] for i in hooks['BatchNormalizationLayer:movingavg']]))
        >>> batchnormvalues = list(itertools.chain(
        ...     *[i[0] for i in hooks['BatchNormalizationLayer:movingavg']]))
        >>> for tensor, param in zip(batchnormvalues, batchnormparams):
        ...     updates.append((param, (1.0-mulfac)*param + mulfac*tensor))
        >>> # append these updates to your normal update list
    """
    def __init__(self, incoming,
                 gamma=init.Uniform([0.95, 1.05]),
                 beta=init.Constant(0.),
                 nonlinearity=nonlinearities.rectify,
                 epsilon=0.001,
                 **kwargs):
        super(BatchNormalizationLayer, self).__init__(incoming, **kwargs)
        if nonlinearity is None:
            self.nonlinearity = nonlinearities.identity
        else:
            self.nonlinearity = nonlinearity

        # learned per-feature scale (gamma) and shift (beta)
        self.num_units = int(np.prod(self.input_shape[1:]))
        self.gamma = self.create_param(gamma, (self.num_units,),
                                       name="BatchNormalizationLayer:gamma")
        self.beta = self.create_param(beta, (self.num_units,),
                                      name="BatchNormalizationLayer:beta")
        self.epsilon = epsilon

        # shared variables holding the moving averages used at inference time
        self.mean_inference = theano.shared(
            np.zeros((1, self.num_units), dtype=theano.config.floatX),
            borrow=True,
            broadcastable=(True, False))
        self.mean_inference.name = "shared:mean"
        self.variance_inference = theano.shared(
            np.zeros((1, self.num_units), dtype=theano.config.floatX),
            borrow=True,
            broadcastable=(True, False))
        self.variance_inference.name = "shared:variance"
    def get_params(self):
        return [self.gamma, self.beta]

    def get_output_shape_for(self, input_shape):
        return input_shape
    def get_output_for(self, input, moving_avg_hooks=None,
                       deterministic=False, *args, **kwargs):
        if input.ndim > 2:
            output_shape = input.shape
            input = input.flatten(2)

        if deterministic is False:
            # normalize with the statistics of the current mini-batch
            m = T.mean(input, axis=0, keepdims=True)
            v = T.sqrt(T.var(input, axis=0, keepdims=True) + self.epsilon)
            m.name = "tensor:mean"
            v.name = "tensor:variance"

            # expose the batch statistics together with the corresponding
            # shared variables so the user can build moving-average updates
            key = "BatchNormalizationLayer:movingavg"
            if key not in moving_avg_hooks:
                moving_avg_hooks[key] = []
            moving_avg_hooks[key].append(
                [[m, v], [self.mean_inference, self.variance_inference]])
        else:
            # at inference time, use the accumulated moving averages instead
            m = self.mean_inference
            v = self.variance_inference

        input_hat = (input - m) / v             # normalize
        y = self.gamma*input_hat + self.beta    # scale and shift

        # NOTE: input was flattened to 2 dimensions above, so this check can
        # never be true here (see the discussion below)
        if input.ndim > 2:
            y = T.reshape(y, output_shape)
        return self.nonlinearity(y)
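
For completeness, here is a minimal end-to-end sketch of the hook mechanism described in the docstring, from collecting the batch statistics to compiling a training function. The layer sizes, the decay factor mulfac, the dummy loss and the plain-SGD updates are illustrative assumptions, not part of the gist:

import itertools
import theano
import theano.tensor as T
from lasagne.layers import InputLayer, DenseLayer, BatchNormalizationLayer
from lasagne.nonlinearities import linear, rectify

l_in = InputLayer((100, 20))
l_dense = DenseLayer(l_in, 50, nonlinearity=linear)
l_bn = BatchNormalizationLayer(l_dense, nonlinearity=rectify)

x = T.matrix('x')
hooks = {}
# deterministic=False: batch statistics are used and exposed via the hooks dict
out = l_bn.get_output(x, deterministic=False, moving_avg_hooks=hooks)

# exponential-moving-average updates for the inference-time mean/std
mulfac = 1.0 / 100.0
pairs = hooks['BatchNormalizationLayer:movingavg']
batch_stats = list(itertools.chain(*[p[0] for p in pairs]))   # per-batch mean, std
shared_stats = list(itertools.chain(*[p[1] for p in pairs]))  # shared mean, std
avg_updates = [(shared, (1.0 - mulfac) * shared + mulfac * stat)
               for stat, shared in zip(batch_stats, shared_stats)]

# combine with ordinary parameter updates before compiling (plain SGD on a
# dummy loss, purely for illustration)
loss = T.mean(out ** 2)
train_params = l_bn.get_params()                              # [gamma, beta]
grads = T.grad(loss, train_params)
sgd_updates = [(p, p - 0.01 * g) for p, g in zip(train_params, grads)]
train_fn = theano.function([x], loss, updates=sgd_updates + avg_updates)

# at test time, request deterministic output so the moving averages are used
test_out = l_bn.get_output(x, deterministic=True)
predict_fn = theano.function([x], test_out)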
@DediGadot

The bug is in the final reshape in get_output_for: you decide whether to reshape the output based on input.ndim, yet you flattened the input at the top of the method, so input.ndim is always 2 by that point and the output is never reshaped.
I added a Boolean variable specifying whether the input was indeed flattened.

Cheers,
David
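
For reference, the fix described above might look something like this (a sketch; the flag name `flattened` is illustrative, and the tensor-naming lines are omitted for brevity):

def get_output_for(self, input, moving_avg_hooks=None,
                   deterministic=False, *args, **kwargs):
    # remember whether the input was flattened; input.ndim is 2 afterwards
    flattened = input.ndim > 2
    if flattened:
        output_shape = input.shape
        input = input.flatten(2)

    if deterministic is False:
        m = T.mean(input, axis=0, keepdims=True)
        v = T.sqrt(T.var(input, axis=0, keepdims=True) + self.epsilon)
        key = "BatchNormalizationLayer:movingavg"
        if key not in moving_avg_hooks:
            moving_avg_hooks[key] = []
        moving_avg_hooks[key].append(
            [[m, v], [self.mean_inference, self.variance_inference]])
    else:
        m = self.mean_inference
        v = self.variance_inference

    input_hat = (input - m) / v           # normalize
    y = self.gamma*input_hat + self.beta  # scale and shift

    if flattened:                         # use the flag, not input.ndim
        y = T.reshape(y, output_shape)
    return self.nonlinearity(y)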
