Skip to content

Instantly share code, notes, and snippets.

@crcrpar
Created September 20, 2018 14:50
Show Gist options
  • Save crcrpar/b5c7978f2c62d289da3578a3859c22f5 to your computer and use it in GitHub Desktop.
Save crcrpar/b5c7978f2c62d289da3578a3859c22f5 to your computer and use it in GitHub Desktop.
import chainer
from chainer.backends import cuda
import chainer.link_hooks
from chainer.link_hooks import _ForwardPreprocessCallbackArgs
from chainer import variable
def _get_axis(ndim, axis):
    """Return a permutation of ``range(ndim)`` with ``axis`` moved to the front.

    The remaining axes keep their relative order, so the result can be passed
    directly to a ``transpose`` call.

    Args:
        ndim (int): Number of dimensions of the array to be transposed.
        axis (int): Axis to move to the front. Negative values count from
            the last axis, as in NumPy.

    Returns:
        list of int: ``[axis, ...]`` followed by every other axis index in
        ascending order.

    Raises:
        ValueError: If ``axis`` is outside ``[-ndim, ndim)``.
    """
    if not -ndim <= axis < ndim:
        raise ValueError(
            'axis {} is out of bounds for an array of dimension {}'.format(
                axis, ndim))
    # Normalize negative axes; otherwise e.g. ``-1`` would appear next to
    # ``ndim - 1`` in the result, which is not a valid permutation.
    if axis < 0:
        axis += ndim
    return [axis] + [i for i in range(ndim) if i != axis]
def _norm_except_axis(xp, weight, axis, eps=1e-12):
    """Compute the L2 norm of ``weight`` over every axis except ``axis``.

    Args:
        xp: Array module providing NumPy-compatible ``sum`` and ``sqrt``
            (the caller passes ``link.xp``).
        weight: N-dimensional array.
        axis (int): Axis that is kept. Negative values count from the last
            axis.
        eps (float): Small constant added to the squared norm before the
            square root is taken.

    Returns:
        A 1-D array of length ``weight.shape[axis]`` whose ``i``-th entry is
        ``sqrt(sum(weight_i ** 2) + eps)``, where ``weight_i`` is the slice
        of ``weight`` at index ``i`` along ``axis``.

    Raises:
        ValueError: If ``axis`` is outside ``[-weight.ndim, weight.ndim)``.
    """
    ndim = weight.ndim
    if not -ndim <= axis < ndim:
        raise ValueError(
            'axis {} is out of bounds for an array of dimension {}'.format(
                axis, ndim))
    if axis < 0:
        axis += ndim
    # Reduce directly over the other axes instead of transposing first.
    # The axes are passed as a tuple because NumPy-style reductions expect
    # an int or a tuple of ints for ``axis``, not a list.
    reduce_axes = tuple(i for i in range(ndim) if i != axis)
    squared_weight_norm = xp.sum(weight * weight, axis=reduce_axes)
    return xp.sqrt(squared_weight_norm + eps)
class WeightNormalization(chainer.LinkHook):
    """Link hook that reparameterizes ``link.W`` as ``g * V / ||V||``.

    When the hook is added to a link, the link's current weight is kept as
    the direction parameter ``V`` (stored as ``link._V``) and a new
    parameter ``link.g`` is created, initialized to the norm of the weight
    taken over every axis except ``axis``. Before each forward call,
    ``link.W`` is replaced by the normalized weight; after the call it is
    restored to the raw parameter.

    Args:
        weight_dependent_initialization (bool): Must be ``False``;
            weight dependent initialization is not implemented.
        axis (int): Axis of the weight that is kept when computing the
            norm. Defaults to the class attribute ``axis`` (``0``).
        eps (float): Constant added to the squared norm for numerical
            stability. Defaults to the class attribute ``eps``.

    Raises:
        NotImplementedError: If ``weight_dependent_initialization`` is
            truthy.
    """

    name = 'WeightNormalization'
    # Class-level defaults used when the constructor receives ``None``.
    axis = 0
    eps = 1e-12

    def __init__(self, weight_dependent_initialization=False,
                 axis=None, eps=None):
        # Only the unsupported mode is rejected; the default (``False``)
        # must be usable.
        if weight_dependent_initialization:
            raise NotImplementedError(
                'Weight dependent initialization is not supported now.'
            )
        self._wd_init = weight_dependent_initialization
        self._initialized = not self._wd_init
        if axis is None:
            axis = WeightNormalization.axis
        if eps is None:
            eps = WeightNormalization.eps
        self.axis = axis
        self.eps = eps

    def added(self, link):
        """Register ``g`` on ``link`` and keep the raw weight as ``_V``.

        Raises:
            ValueError: If ``link`` has no ``W`` or ``W`` has not been
                initialized yet (its array is ``None``).
        """
        W = getattr(link, 'W', None)
        if W is None or W.array is None:
            raise ValueError(
                "Because link's weight is not initialized, "
                "it is impossible to apply Weight Normalization via LinkHook."
            )
        xp = link.xp
        initial_g = _norm_except_axis(xp, W.array, self.axis, self.eps)
        # Register ``g`` inside ``init_scope`` so that it becomes a
        # parameter of the link rather than a plain attribute.
        with link.init_scope():
            link.g = variable.Parameter(initial_g)
        # ``_V`` is only an alias of the already registered ``W``
        # parameter, so it is deliberately set outside ``init_scope``.
        link._V = W

    def deleted(self, link):
        """Remove the attributes this hook added to ``link``."""
        del link._V
        del link.g

    def forward_preprocess(self, cb_args):
        """Replace ``link.W`` with the normalized weight before forward."""
        # Imported locally because the module only imports the ``chainer``
        # top-level package and its hook/variable submodules.
        import chainer.functions as F

        assert isinstance(cb_args, _ForwardPreprocessCallbackArgs)
        link = cb_args.link
        weight = link._V
        ndim = weight.ndim
        axis = self.axis + ndim if self.axis < 0 else self.axis
        reduce_axes = tuple(i for i in range(ndim) if i != axis)
        # The norm is computed with chainer functions so that it is part
        # of the computational graph, and with ``keepdims=True`` so that
        # it lines up with ``axis`` of the weight instead of its last axis.
        squared_norm = F.sum(weight * weight, axis=reduce_axes, keepdims=True)
        norm = F.sqrt(squared_norm + self.eps)
        scale = F.reshape(link.g, norm.shape) / norm
        normalized_weight = F.broadcast_to(scale, weight.shape) * weight
        link.W = normalized_weight

    def forward_postprocess(self, cb_args):
        """Point ``link.W`` back at the raw parameter after forward.

        The graph built during forward already references the normalized
        weight, so ``link.W`` can be a parameter again between calls.
        """
        link = cb_args.link
        link.W = link._V

    def _weight_dependent_initialization(self, *args, **kwargs):
        """Placeholder for data dependent initialization (unsupported)."""
        raise NotImplementedError
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment