Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save stefan-it/1eaf537a853aee6ab8fe5a425bf09ed7 to your computer and use it in GitHub Desktop.
Save stefan-it/1eaf537a853aee6ab8fe5a425bf09ed7 to your computer and use it in GitHub Desktop.
XLM-R conversion
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert RoBERTa checkpoint."""
from __future__ import absolute_import, division, print_function
import argparse
import logging
import numpy as np
import torch
import pathlib
from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
from fairseq.modules import TransformerSentenceEncoderLayer
from transformers.modeling_bert import (BertConfig, BertEncoder,
BertIntermediate, BertLayer,
BertModel, BertOutput,
BertSelfAttention,
BertSelfOutput)
from transformers.modeling_roberta import (RobertaEmbeddings,
RobertaForMaskedLM,
RobertaForSequenceClassification,
RobertaModel)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
SAMPLE_TEXT = 'Hello world! cécé herlolip'
def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_folder_path, classification_head):
"""
Copy/paste/tweak roberta's weights to our BERT structure.
"""
roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path, bpe = 'sentencepiece')
roberta.eval() # disable dropout
config = BertConfig(
vocab_size_or_config_json_file=250002,
hidden_size=roberta.args.encoder_embed_dim,
num_hidden_layers=roberta.args.encoder_layers,
num_attention_heads=roberta.args.encoder_attention_heads,
intermediate_size=roberta.args.encoder_ffn_embed_dim,
max_position_embeddings=514,
type_vocab_size=1,
layer_norm_eps=1e-5, # PyTorch default used in fairseq
)
if classification_head:
config.num_labels = roberta.args.num_classes
print("Our BERT config:", config)
model = RobertaForSequenceClassification(config) if classification_head else RobertaForMaskedLM(config)
model.eval()
# Now let's copy all the weights.
# Embeddings
roberta_sent_encoder = roberta.model.decoder.sentence_encoder
model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight
model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight
model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(model.roberta.embeddings.token_type_embeddings.weight) # just zero them out b/c RoBERTa doesn't use them.
model.roberta.embeddings.LayerNorm.weight = roberta_sent_encoder.emb_layer_norm.weight
model.roberta.embeddings.LayerNorm.bias = roberta_sent_encoder.emb_layer_norm.bias
for i in range(config.num_hidden_layers):
# Encoder: start of layer
layer: BertLayer = model.roberta.encoder.layer[i]
roberta_layer: TransformerSentenceEncoderLayer = roberta_sent_encoder.layers[i]
### self attention
self_attn: BertSelfAttention = layer.attention.self
assert(
roberta_layer.self_attn.k_proj.weight.data.shape == \
roberta_layer.self_attn.q_proj.weight.data.shape == \
roberta_layer.self_attn.v_proj.weight.data.shape == \
torch.Size((config.hidden_size, config.hidden_size))
)
self_attn.query.weight.data = roberta_layer.self_attn.q_proj.weight
self_attn.query.bias.data = roberta_layer.self_attn.q_proj.bias
self_attn.key.weight.data = roberta_layer.self_attn.k_proj.weight
self_attn.key.bias.data = roberta_layer.self_attn.k_proj.bias
self_attn.value.weight.data = roberta_layer.self_attn.v_proj.weight
self_attn.value.bias.data = roberta_layer.self_attn.v_proj.bias
### self-attention output
self_output: BertSelfOutput = layer.attention.output
assert(
self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape
)
self_output.dense.weight = roberta_layer.self_attn.out_proj.weight
self_output.dense.bias = roberta_layer.self_attn.out_proj.bias
self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight
self_output.LayerNorm.bias = roberta_layer.self_attn_layer_norm.bias
### intermediate
intermediate: BertIntermediate = layer.intermediate
assert(
intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
)
intermediate.dense.weight = roberta_layer.fc1.weight
intermediate.dense.bias = roberta_layer.fc1.bias
### output
bert_output: BertOutput = layer.output
assert(
bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
)
bert_output.dense.weight = roberta_layer.fc2.weight
bert_output.dense.bias = roberta_layer.fc2.bias
bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight
bert_output.LayerNorm.bias = roberta_layer.final_layer_norm.bias
#### end of layer
if classification_head:
model.classifier.dense.weight = roberta.model.classification_heads['mnli'].dense.weight
model.classifier.dense.bias = roberta.model.classification_heads['mnli'].dense.bias
model.classifier.out_proj.weight = roberta.model.classification_heads['mnli'].out_proj.weight
model.classifier.out_proj.bias = roberta.model.classification_heads['mnli'].out_proj.bias
else:
# LM Head
model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight
model.lm_head.dense.bias = roberta.model.decoder.lm_head.dense.bias
model.lm_head.layer_norm.weight = roberta.model.decoder.lm_head.layer_norm.weight
model.lm_head.layer_norm.bias = roberta.model.decoder.lm_head.layer_norm.bias
model.lm_head.decoder.weight = roberta.model.decoder.lm_head.weight
model.lm_head.bias = roberta.model.decoder.lm_head.bias
# Let's check that we get the same results.
input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1
our_output = model(input_ids)[0]
if classification_head:
their_output = roberta.model.classification_heads['mnli'](roberta.extract_features(input_ids))
else:
their_output = roberta.model(input_ids)[0]
print(our_output.shape, their_output.shape)
max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-7
success = torch.allclose(our_output, their_output, atol=1e-3)
print(
"Do both models output the same tensors?",
"🔥" if success else "💩"
)
if not success:
raise Exception("Something went wRoNg")
pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
print(f"Saving model to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--roberta_checkpoint_path",
default = None,
type = str,
required = True,
help = "Path the official PyTorch dump.")
parser.add_argument("--pytorch_dump_folder_path",
default = None,
type = str,
required = True,
help = "Path to the output PyTorch model.")
parser.add_argument("--classification_head",
action = "store_true",
help = "Whether to convert a final classification head.")
args = parser.parse_args()
convert_roberta_checkpoint_to_pytorch(
args.roberta_checkpoint_path,
args.pytorch_dump_folder_path,
args.classification_head
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment