@gautierdag
Created July 5, 2021 10:24
PyTorch BERT Layer-wise Learning Rate Decay
import torch
from transformers import AutoModel


def get_bert_layerwise_lr_groups(bert_model, learning_rate=1e-5, layer_decay=0.9):
    """
    Gets parameter groups with a learning rate decayed based on depth in the network.
    Layers closer to the output get a higher learning rate.

    Args:
        bert_model: A huggingface BERT-like model (must expose an embeddings layer and an encoder)
        learning_rate: The learning rate at the output end of the network
        layer_decay: How much to decay the learning rate per layer of depth (0.9-0.95 recommended)

    Returns:
        grouped_parameters (list): list of parameter groups with their decayed learning rates
    """
    n_layers = len(bert_model.encoder.layer) + 1  # + 1 for the embeddings layer
    # The embeddings are the deepest "layer", so they get the most decayed learning rate
    embedding_decayed_lr = learning_rate * (layer_decay ** (n_layers + 1))
    grouped_parameters = [
        {"params": bert_model.embeddings.parameters(), "lr": embedding_decayed_lr}
    ]
    # depth runs from 1 (bottom encoder layer) to n_layers - 1 (top encoder layer,
    # which ends up with learning_rate * layer_decay ** 2)
    for depth in range(1, n_layers):
        decayed_lr = learning_rate * (layer_decay ** (n_layers + 1 - depth))
        grouped_parameters.append(
            {"params": bert_model.encoder.layer[depth - 1].parameters(), "lr": decayed_lr}
        )
    return grouped_parameters


# Example:
model = AutoModel.from_pretrained("roberta-base")
lr_groups = get_bert_layerwise_lr_groups(model, learning_rate=1e-5)
optimizer = torch.optim.AdamW(lr_groups, lr=1e-5, weight_decay=0)
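As a quick sanity check (a minimal sketch, assuming the roberta-base checkpoint above with its 12 encoder layers), you can print the learning rate assigned to each parameter group and confirm that groups closer to the output receive higher rates:

# Optional sanity check: group 0 is the embeddings, the last group is the top encoder layer.
for i, group in enumerate(lr_groups):
    print(f"group {i}: lr = {group['lr']:.2e}")
# With layer_decay=0.9 and 12 encoder layers this prints roughly
# 2.29e-06 for the embeddings up to 8.10e-06 for the top encoder layer.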