Created
September 20, 2018 14:50
-
-
Save crcrpar/b5c7978f2c62d289da3578a3859c22f5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import chainer | |
from chainer.backends import cuda | |
import chainer.link_hooks | |
from chainer.link_hooks import _ForwardPreprocessCallbackArgs | |
from chainer import variable | |
def _get_axis(ndim, axis): | |
axes = [axis] | |
for i in range(ndim): | |
if i != axis: | |
axes.append(i) | |
return axes | |
def _norm_except_axis(xp, weight, axis, eps=1e-12):
    """Compute the L2 norm of ``weight`` over every axis except ``axis``.

    Args:
        xp: Array module used for the computation. It must provide
            ``transpose``, ``sum`` and ``sqrt``.
        weight: Array to reduce.
        axis (int): Axis that is kept.
        eps (float): Value added to the squared norm before the square root.

    Returns:
        Array of shape ``(weight.shape[axis],)`` holding
        ``sqrt(sum(weight ** 2) + eps)`` for each index along ``axis``.

    """
    ndim = weight.ndim
    # Bring ``axis`` to position 0 so the reduction covers positions 1..ndim-1.
    weight_T = xp.transpose(weight, _get_axis(ndim, axis))
    # The reduction axes must be a tuple: NumPy rejects a list here with
    # ``TypeError: 'list' object cannot be interpreted as an integer``.
    reduce_axes = tuple(range(1, ndim))
    squared_weight_norm = xp.sum(weight_T * weight_T, axis=reduce_axes)
    return xp.sqrt(squared_weight_norm + eps)
class WeightNormalization(chainer.LinkHook):
    """Link hook that replaces ``link.W`` with ``g * V / norm(V)``.

    When the hook is added, the link's current ``W`` parameter is kept as
    ``link._V`` and a new parameter ``link.g`` is created from the norm of
    ``W`` taken over every axis except ``axis``. Before each forward call
    ``link.W`` is set to the normalized weight; after the call the raw
    parameter is put back.

    Args:
        weight_dependent_initialization (bool): Must be ``False``; the
            weight dependent initialization is not implemented.
        axis (int): Axis of ``W`` that is kept when computing the norm.
            Defaults to the class attribute ``axis``.
        eps (float): Value added to the squared norm before the square
            root. Defaults to the class attribute ``eps``.

    Raises:
        NotImplementedError: If ``weight_dependent_initialization`` is true.

    """

    name = 'WeightNormalization'
    # Class-level defaults used when the constructor receives ``None``.
    axis = 0
    eps = 1e-12

    def __init__(self, weight_dependent_initialization=False,
                 axis=None, eps=None):
        # Reject the unsupported mode. The check used to be inverted, which
        # made the default ``False`` raise and the hook impossible to build.
        if weight_dependent_initialization:
            raise NotImplementedError(
                'Weight dependent initialization is not supported now.'
            )
        self._wd_init = weight_dependent_initialization
        self._initialized = not self._wd_init
        if axis is None:
            axis = WeightNormalization.axis
        if eps is None:
            eps = WeightNormalization.eps
        self.axis = axis
        self.eps = eps

    def added(self, link):
        """Register `g` and copy `W` to `_V`.

        Raises:
            ValueError: If ``link.W`` is ``None`` or has no array yet.

        """
        W = link.W
        # A parameter can exist while its array is still unallocated, in
        # which case there is nothing to take the norm of.
        if W is None or W.array is None:
            raise ValueError(
                "Because link's weight is not initialized, "
                "it is impossible to apply Weight Normalization via LinkHook."
            )
        xp = link.xp
        initial_g = _norm_except_axis(xp, W.array, self.axis, self.eps)
        # Plain attribute assignment: ``_V`` is only an alias of the
        # parameter that stays registered under the name ``W``.
        link._V = W
        # Parameters are registered on a link only inside ``init_scope``.
        with link.init_scope():
            link.g = variable.Parameter(initial_g)

    def deleted(self, link):
        """Remove the attributes that :meth:`added` put on ``link``."""
        del link._V
        del link.g

    def forward_preprocess(self, cb_args):
        """Set ``link.W`` to the normalized weight before the forward call."""
        assert isinstance(cb_args, _ForwardPreprocessCallbackArgs)
        link = cb_args.link
        xp = link.xp
        weight = link._V
        # NOTE(review): the norm is computed on the raw array, so presumably
        # no gradient flows through it -- confirm this is intended.
        norm = _norm_except_axis(xp, weight.array, self.axis, self.eps)
        # ``g`` and ``norm`` are one-dimensional; give them length-1 axes
        # everywhere but ``self.axis`` so they broadcast against ``weight``.
        shape = [1] * weight.ndim
        shape[self.axis] = -1
        shape = tuple(shape)
        normalized_weight = link.g.reshape(shape) * weight / norm.reshape(shape)
        link.W = normalized_weight

    def forward_postprocess(self, cb_args):
        """Put the raw parameter back on ``link.W`` after the forward call."""
        link = cb_args.link
        # Otherwise the name ``W`` would keep pointing at the normalized
        # variable instead of the parameter stored in ``_V``.
        link.W = link._V

    def _weight_dependent_initialization(self, *args, **kwargs):
        """Placeholder for the unsupported weight dependent initialization."""
        raise NotImplementedError
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment