Skip to content

Instantly share code, notes, and snippets.

@nilbot
Created August 23, 2019 02:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nilbot/952d63ff5bbf16fe0de4a50d94dde228 to your computer and use it in GitHub Desktop.
Save nilbot/952d63ff5bbf16fe0de4a50d94dde228 to your computer and use it in GitHub Desktop.
reset_params()
def reset_parameters(self):
    """
    Properly initialize the weights, following the same recipe as:
        Xavier init:  http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf
        Kaiming init: https://arxiv.org/abs/1502.01852

    All parameters are re-initialized in place:

    * ``weight`` (and ``weight_proj`` when ``n_proj > 0``) are drawn from a
      zero-mean uniform distribution with variance ``1 / size(0)``.
    * ``bias`` is zeroed, then its trailing slice (from ``n_out``, or from
      ``2 * n_out`` when bidirectional) is set to ``highway_bias``.
    * ``weight_c`` is drawn uniformly and scaled by ``sqrt(0.5)``; when
      ``v1`` is set it is zeroed and excluded from gradient updates instead.
    * ``scale_x[0]`` is set to 1, or to ``sqrt(1 + 2 * exp(highway_bias))``
      when ``rescale`` is enabled.

    The dropout / normalized-input re-scaling and the weight-normalization
    re-parameterization do not depend on ``rescale``, so they are applied
    whether or not ``rescale`` is enabled.
    """
    # initialize weights such that E[w_ij]=0 and Var[w_ij]=1/d
    # (uniform on [-a, a] has variance a^2/3, hence a = sqrt(3/d))
    d = self.weight.size(0)
    val_range = (3.0 / d) ** 0.5
    self.weight.data.uniform_(-val_range, val_range)

    # 4-D view of the weight data; the in-place ops on `w` below therefore
    # modify self.weight itself. The last axis selects one of the k
    # sub-matrices.
    w = self.weight.data.view(d, -1, self.n_out, self.k)

    if self.n_proj > 0:
        val_range_2 = (3.0 / self.weight_proj.size(0)) ** 0.5
        self.weight_proj.data.uniform_(-val_range_2, val_range_2)

    # initialize bias: zero everywhere, highway_bias on the trailing slice
    self.bias.data.zero_()
    bias_val, n_out = self.highway_bias, self.n_out
    if self.bidirectional:
        self.bias.data[n_out * 2:].zero_().add_(bias_val)
    else:
        self.bias.data[n_out:].zero_().add_(bias_val)

    if not self.v1:
        # initialize weight_c such that E[w]=0 and Var[w]=1
        self.weight_c.data.uniform_(-3.0 ** 0.5, 3.0 ** 0.5)

        # rescale weight_c and the weight of sigmoid gates with a factor of sqrt(0.5)
        w[:, :, :, 1].mul_(0.5 ** 0.5)
        w[:, :, :, 2].mul_(0.5 ** 0.5)
        self.weight_c.data.mul_(0.5 ** 0.5)
    else:
        self.weight_c.data.zero_()
        self.weight_c.requires_grad = False

    self.scale_x.data[0] = 1
    if self.rescale:
        # scalar used to properly scale the highway output
        scale_val = (1 + math.exp(bias_val) * 2) ** 0.5
        self.scale_x.data[0] = scale_val
        if self.k == 4:
            w[:, :, :, 3].mul_(scale_val)

    # re-scale weights for dropout and normalized input for better gradient flow
    # (deliberately outside the `rescale` branch: these apply in both cases)
    if self.dropout > 0:
        w[:, :, :, 0].mul_((1 - self.dropout) ** 0.5)
    if self.rnn_dropout > 0:
        w.mul_((1 - self.rnn_dropout) ** 0.5)
    if self.is_input_normalized:
        w[:, :, :, 1].mul_(0.1)
        w[:, :, :, 2].mul_(0.1)
        self.weight_c.data.mul_(0.1)

    # re-parameterize when weight normalization is enabled
    if self.weight_norm:
        self.reset_weight_norm()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment