Skip to content

Instantly share code, notes, and snippets.

@nbroad1881
Created April 11, 2022 21:15
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nbroad1881/d974a8e931cd179fd46191fe5750fd6c to your computer and use it in GitHub Desktop.
def reinit_model_weights(model, n_layers, config):
    """Re-initialize the top ``n_layers`` encoder layers and the output head.

    Args:
        model: wrapper holding a transformer under ``model.backbone`` and a
            task head under ``model.output`` (rename to match your model).
        n_layers: number of encoder layers, counted from the top, to reset.
        config: model config; ``config.initializer_range`` supplies the std
            of the normal distribution used for re-initialization (Hugging
            Face convention).
    """
    # BUG FIX: the original referenced an undefined name ``std`` and never
    # used ``config``; derive the std from the config as intended.
    std = config.initializer_range
    # use whatever you named your transformer module
    backbone = model.backbone
    encoder_layers = backbone.encoder.layer
    reinit_layers(encoder_layers, n_layers, std)
    # use whatever you named the output
    reinit_modules([model.output], std)
def reinit_layers(layers, n_layers, std):
    """Re-initialize the parameters of the last ``n_layers`` entries of ``layers``.

    Args:
        layers: sequence of transformer layers (e.g. ``encoder.layer``).
        n_layers: how many layers, counted from the end, to re-initialize.
        std: standard deviation for the normal weight initialization.
    """
    # BUG FIX: with n_layers == 0, ``layers[-0:]`` is ``layers[:]`` and the
    # original re-initialized EVERY layer instead of none. Guard explicitly.
    if n_layers <= 0:
        return
    for layer in layers[-n_layers:]:
        reinit_modules(layer.modules(), std)
def reinit_modules(modules, std, reinit_embeddings=False):
    """Re-initialize the parameters of the given modules in place.

    Linear weights are drawn from N(0, std) and biases zeroed; LayerNorm is
    reset to weight=1, bias=0; Embeddings are only touched when
    ``reinit_embeddings`` is True (padding row re-zeroed afterwards).

    Args:
        modules: iterable of ``torch.nn.Module`` objects (e.g. ``m.modules()``).
        std: standard deviation for the normal weight initialization.
        reinit_embeddings: also reset ``torch.nn.Embedding`` weights.
    """
    for module in modules:
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif reinit_embeddings and isinstance(module, torch.nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # keep the padding vector at zero, matching fresh HF init
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, torch.nn.LayerNorm):
            # BUG FIX: LayerNorm(elementwise_affine=False) has weight/bias
            # set to None; the original crashed with AttributeError on it.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment