Created
April 11, 2022 21:15
-
-
Save nbroad1881/d974a8e931cd179fd46191fe5750fd6c to your computer and use it in GitHub Desktop.
As studied in https://arxiv.org/pdf/1905.09788.pdf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def reinit_model_weights(model, n_layers, config):
    """Re-initialize the top ``n_layers`` encoder layers and the output head.

    Re-initializing the last few transformer layers before fine-tuning can
    improve downstream performance (https://arxiv.org/pdf/1905.09788.pdf).

    Args:
        model: wrapper module exposing ``model.backbone`` (the transformer)
            and ``model.output`` (the task head). Adjust the attribute names
            to match your own model.
        n_layers: number of encoder layers, counted from the top, to reset.
        config: transformer config; ``config.initializer_range`` supplies the
            std of the normal distribution used for re-initialization.
    """
    # Bug fix: `std` was previously undefined — take it from the config,
    # matching how HF transformers initialize weights.
    std = config.initializer_range
    # use whatever you named your transformer module
    backbone = model.backbone
    encoder_layers = backbone.encoder.layer
    reinit_layers(encoder_layers, n_layers, std)
    # use whatever you named the output
    reinit_modules([model.output], std)
def reinit_layers(layers, n_layers, std):
    """Re-initialize every sub-module of the top ``n_layers`` entries of ``layers``.

    Args:
        layers: indexable sequence of encoder layers (e.g. ``ModuleList``).
        n_layers: how many layers, counted from the end, to reset.
        std: std of the normal distribution used for re-initialization.
    """
    top_layers = layers[-n_layers:]
    for encoder_layer in top_layers:
        reinit_modules(encoder_layer.modules(), std)
def reinit_modules(modules, std, reinit_embeddings=False):
    """Reset the parameters of each module to a fresh initialization.

    Linear weights are drawn from N(0, std) with biases zeroed; LayerNorm is
    reset to weight=1, bias=0. Embeddings are only touched when
    ``reinit_embeddings`` is True, in which case the padding row (if any) is
    zeroed after sampling. All other module types are left untouched.

    Args:
        modules: iterable of ``torch.nn.Module`` instances.
        std: std of the normal distribution for Linear/Embedding weights.
        reinit_embeddings: whether ``torch.nn.Embedding`` weights are reset.
    """
    for m in modules:
        if isinstance(m, torch.nn.Linear):
            m.weight.data.normal_(mean=0.0, std=std)
            if m.bias is not None:
                m.bias.data.zero_()
            continue
        if isinstance(m, torch.nn.Embedding) and reinit_embeddings:
            m.weight.data.normal_(mean=0.0, std=std)
            # Keep the padding vector at zero, as torch's own init does.
            if m.padding_idx is not None:
                m.weight.data[m.padding_idx].zero_()
            continue
        if isinstance(m, torch.nn.LayerNorm):
            m.bias.data.zero_()
            m.weight.data.fill_(1.0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment