Created
August 23, 2019 02:52
-
-
Save nilbot/952d63ff5bbf16fe0de4a50d94dde228 to your computer and use it in GitHub Desktop.
reset_params()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def reset_parameters(self):
    """
    Properly initialize the weights, following the same recipe as:
        Xavier init:  http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf
        Kaiming init: https://arxiv.org/abs/1502.01852

    All tensors are modified in place through ``.data`` and nothing is
    returned. The attributes touched are ``weight``, ``weight_proj`` (only
    when ``n_proj > 0``), ``bias``, ``weight_c`` and ``scale_x``.

    When ``self.rescale`` is false the method returns right after resetting
    ``scale_x`` to 1, so the highway scaling, the dropout / normalized-input
    re-scaling and the weight-norm re-parameterization below are all skipped.
    """
    # initialize weights such that E[w_ij]=0 and Var[w_ij]=1/d
    # (a uniform distribution on [-a, a] has variance a^2/3, so a=sqrt(3/d))
    d = self.weight.size(0)
    val_range = (3.0 / d) ** 0.5
    self.weight.data.uniform_(-val_range, val_range)

    # 4-D view sharing storage with self.weight, so the in-place ops on `w`
    # below rescale slices of the real parameter. The last axis selects one
    # of the k sub-matrices.
    # NOTE(review): the second axis is presumably the direction -- confirm.
    w = self.weight.data.view(d, -1, self.n_out, self.k)

    if self.n_proj > 0:
        val_range_2 = (3.0 / self.weight_proj.size(0)) ** 0.5
        self.weight_proj.data.uniform_(-val_range_2, val_range_2)

    # initialize bias: everything is zeroed, then the tail of the vector
    # (from n_out, or from 2*n_out when bidirectional) is set to highway_bias
    self.bias.data.zero_()
    bias_val, n_out = self.highway_bias, self.n_out
    if self.bidirectional:
        self.bias.data[n_out * 2:].zero_().add_(bias_val)
    else:
        self.bias.data[n_out:].zero_().add_(bias_val)

    if not self.v1:
        # initialize weight_c such that E[w]=0 and Var[w]=1
        self.weight_c.data.uniform_(-3.0 ** 0.5, 3.0 ** 0.5)

        # rescale weight_c and the weight of sigmoid gates with a factor of sqrt(0.5)
        w[:, :, :, 1].mul_(0.5 ** 0.5)
        w[:, :, :, 2].mul_(0.5 ** 0.5)
        self.weight_c.data.mul_(0.5 ** 0.5)
    else:
        # v1 mode: weight_c is pinned to zero and excluded from training
        self.weight_c.data.zero_()
        self.weight_c.requires_grad = False

    # default scale; overwritten below only when rescaling is enabled
    self.scale_x.data[0] = 1
    if not self.rescale:
        return

    # scalar used to properly scale the highway output
    scale_val = (1 + math.exp(bias_val) * 2) ** 0.5
    self.scale_x.data[0] = scale_val
    if self.k == 4:
        # a fourth sub-matrix exists only when k == 4; it gets the same scale
        w[:, :, :, 3].mul_(scale_val)

    # re-scale weights for dropout and normalized input for better gradient flow
    if self.dropout > 0:
        w[:, :, :, 0].mul_((1 - self.dropout) ** 0.5)
    if self.rnn_dropout > 0:
        w.mul_((1 - self.rnn_dropout) ** 0.5)
    if self.is_input_normalized:
        w[:, :, :, 1].mul_(0.1)
        w[:, :, :, 2].mul_(0.1)
        self.weight_c.data.mul_(0.1)

    # re-parameterize when weight normalization is enabled
    if self.weight_norm:
        self.reset_weight_norm()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment