Convert Marian NMT models to PyTorch models (transformers library format)
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Copyright 2024 Iikka Hauhio (removed OPUS/Tatoeba-specific code to allow converting arbitrary models)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
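# Example invocation (a sketch; the script filename is arbitrary). The source directory
# is expected to contain exactly one *.npz Marian checkpoint, a *.vocab.yml or *.vocab
# file, the SentencePiece models (source.spm/target.spm by default), and optionally a
# decoder.yml whose beam-size is copied into the generated config:
#
#     python convert_marian.py --src /path/to/marian-model --src-lang fi --tgt-lang en --dest converted-model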
import argparse
import json
import warnings
from pathlib import Path
from typing import Dict, List, Union

import numpy as np
import torch
from torch import nn

from transformers import MarianConfig, MarianMTModel, MarianTokenizer
def remove_suffix(text: str, suffix: str):
    if text.endswith(suffix):
        return text[: -len(suffix)]
    return text  # return the text unchanged if the suffix is absent


def remove_prefix(text: str, prefix: str):
    if text.startswith(prefix):
        return text[len(prefix):]
    return text  # return the text unchanged if the prefix is absent
def convert_encoder_layer(opus_dict, layer_prefix: str, converter: dict):
    sd = {}
    for k in opus_dict:
        if not k.startswith(layer_prefix):
            continue
        print("converting", k)
        stripped = remove_prefix(k, layer_prefix)
        v = opus_dict[k].T  # besides embeddings, everything must be transposed.
        sd[converter[stripped]] = torch.tensor(v).squeeze()
    return sd
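# A minimal sketch of what convert_encoder_layer produces, assuming `opus_dict` is the
# loaded .npz state dict and `converter` is BART_CONVERTER (defined below):
#
#     sd = convert_encoder_layer(opus_dict, "encoder_l1_", BART_CONVERTER)
#     # the Marian tensor "encoder_l1_self_Wq" ends up (transposed) under the
#     # transformers key "self_attn.q_proj.weight"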
def load_layers_(layer_lst: nn.ModuleList, opus_state: dict, converter, is_decoder=False):
    for i, layer in enumerate(layer_lst):
        layer_tag = f"decoder_l{i + 1}_" if is_decoder else f"encoder_l{i + 1}_"
        sd = convert_encoder_layer(opus_state, layer_tag, converter)
        layer.load_state_dict(sd, strict=False)
def add_emb_entries(wemb, final_bias, n_special_tokens=1):
    vsize, d_model = wemb.shape
    embs_to_add = np.zeros((n_special_tokens, d_model))
    new_embs = np.concatenate([wemb, embs_to_add])
    bias_to_add = np.zeros((n_special_tokens, 1))
    new_bias = np.concatenate((final_bias, bias_to_add), axis=1)
    return new_embs, new_bias
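# Shape note (derived from the code above): wemb has shape (vocab_size, d_model) and
# final_bias has shape (1, vocab_size); one zero row/column is appended for the extra
# <pad> token, giving (vocab_size + 1, d_model) and (1, vocab_size + 1).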
def _cast_yaml_str(v):
    bool_dct = {"true": True, "false": False}
    if not isinstance(v, str):
        return v
    elif v in bool_dct:
        return bool_dct[v]
    try:
        return int(v)
    except (TypeError, ValueError):
        return v


def cast_marian_config(raw_cfg: Dict[str, str]) -> Dict:
    return {k: _cast_yaml_str(v) for k, v in raw_cfg.items()}
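# For example: cast_marian_config({"enc-depth": "6", "tied-embeddings": "true", "type": "transformer"})
# returns {"enc-depth": 6, "tied-embeddings": True, "type": "transformer"}.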
CONFIG_KEY = "special:model.yml"


def load_config_from_state_dict(opus_dict):
    import yaml

    cfg_str = "".join([chr(x) for x in opus_dict[CONFIG_KEY]])
    yaml_cfg = yaml.load(cfg_str[:-1], Loader=yaml.BaseLoader)
    return cast_marian_config(yaml_cfg)
def find_model_file(dest_dir):
    model_files = list(Path(dest_dir).glob("*.npz"))
    if len(model_files) != 1:
        raise ValueError(f"Expected exactly one .npz model file, found: {model_files}")
    model_file = model_files[0]
    return model_file
def lmap(f, x) -> List:
    return list(map(f, x))


def add_to_vocab_(vocab: Dict[str, int], special_tokens: List[str]):
    start = max(vocab.values()) + 1
    added = 0
    for tok in special_tokens:
        if tok in vocab:
            continue
        vocab[tok] = start + added
        added += 1
    return added
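# For example (a sketch): add_to_vocab_({"<unk>": 0, "</s>": 1}, ["<pad>"]) mutates the
# dict to {"<unk>": 0, "</s>": 1, "<pad>": 2} and returns 1, the number of tokens added.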
def find_vocab_file(model_dir: Path):
    return [p for p in model_dir.glob("*.vocab*") if p.suffix in [".yml", ".vocab"]][0]


def find_src_vocab_file(model_dir: Path):
    return [p for p in model_dir.glob("*.src.vocab*") if p.suffix in [".yml", ".vocab"]][0]


def find_tgt_vocab_file(model_dir: Path):
    return [p for p in model_dir.glob("*.trg.vocab*") if p.suffix in [".yml", ".vocab"]][0]
def check_marian_cfg_assumptions(marian_cfg):
    assumed_settings = {
        "layer-normalization": False,
        "right-left": False,
        "transformer-ffn-depth": 2,
        "transformer-aan-depth": 2,
        "transformer-no-projection": False,
        "transformer-postprocess-emb": "d",
        "transformer-postprocess": "dan",  # Dropout, add, normalize
        "transformer-preprocess": "",
        "type": "transformer",
        "ulr-dim-emb": 0,
        "dec-cell-base-depth": 2,
        "dec-cell-high-depth": 1,
        "transformer-aan-nogate": False,
    }
    for k, v in assumed_settings.items():
        actual = marian_cfg[k]
        if actual != v:
            raise ValueError(f"Unexpected config value for {k}: expected {v}, got {actual}")
BIAS_KEY = "decoder_ff_logit_out_b"
BART_CONVERTER = {  # for each encoder and decoder layer
    "self_Wq": "self_attn.q_proj.weight",
    "self_Wk": "self_attn.k_proj.weight",
    "self_Wv": "self_attn.v_proj.weight",
    "self_Wo": "self_attn.out_proj.weight",
    "self_bq": "self_attn.q_proj.bias",
    "self_bk": "self_attn.k_proj.bias",
    "self_bv": "self_attn.v_proj.bias",
    "self_bo": "self_attn.out_proj.bias",
    "self_Wo_ln_scale": "self_attn_layer_norm.weight",
    "self_Wo_ln_bias": "self_attn_layer_norm.bias",
    "ffn_W1": "fc1.weight",
    "ffn_b1": "fc1.bias",
    "ffn_W2": "fc2.weight",
    "ffn_b2": "fc2.bias",
    "ffn_ffn_ln_scale": "final_layer_norm.weight",
    "ffn_ffn_ln_bias": "final_layer_norm.bias",
    # Decoder Cross Attention
    "context_Wk": "encoder_attn.k_proj.weight",
    "context_Wo": "encoder_attn.out_proj.weight",
    "context_Wq": "encoder_attn.q_proj.weight",
    "context_Wv": "encoder_attn.v_proj.weight",
    "context_bk": "encoder_attn.k_proj.bias",
    "context_bo": "encoder_attn.out_proj.bias",
    "context_bq": "encoder_attn.q_proj.bias",
    "context_bv": "encoder_attn.v_proj.bias",
    "context_Wo_ln_scale": "encoder_attn_layer_norm.weight",
    "context_Wo_ln_bias": "encoder_attn_layer_norm.bias",
}
class OpusState:
    def __init__(self, source_dir, source_lang: str, target_lang: str, source_spm: str, target_spm: str, eos_token_id=0):
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.source_spm = source_spm
        self.target_spm = target_spm
        npz_path = find_model_file(source_dir)
        self.state_dict = np.load(npz_path)
        cfg = load_config_from_state_dict(self.state_dict)
        if cfg["dim-vocabs"][0] != cfg["dim-vocabs"][1]:
            raise ValueError(f"Unequal source/target vocab sizes in dim-vocabs: {cfg['dim-vocabs']}")
        if "Wpos" in self.state_dict:
            raise ValueError("Wpos key in state dictionary")
        self.state_dict = dict(self.state_dict)
        if cfg["tied-embeddings-all"]:
            cfg["tied-embeddings-src"] = True
            cfg["tied-embeddings"] = True
        self.share_encoder_decoder_embeddings = cfg["tied-embeddings-src"]

        # create the tokenizer here because we need to know the eos_token_id
        self.source_dir = source_dir
        self.tokenizer = self.load_tokenizer()
        # retrieve EOS token and set correctly
        tokenizer_has_eos_token_id = (
            hasattr(self.tokenizer, "eos_token_id") and self.tokenizer.eos_token_id is not None
        )
        eos_token_id = self.tokenizer.eos_token_id if tokenizer_has_eos_token_id else 0

        if cfg["tied-embeddings-src"]:
            self.wemb, self.final_bias = add_emb_entries(self.state_dict["Wemb"], self.state_dict[BIAS_KEY], 1)
            self.pad_token_id = self.wemb.shape[0] - 1
            cfg["vocab_size"] = self.pad_token_id + 1
        else:
            self.wemb, _ = add_emb_entries(self.state_dict["encoder_Wemb"], self.state_dict[BIAS_KEY], 1)
            self.dec_wemb, self.final_bias = add_emb_entries(
                self.state_dict["decoder_Wemb"], self.state_dict[BIAS_KEY], 1
            )
            # still assuming that vocab size is same for encoder and decoder
            self.pad_token_id = self.wemb.shape[0] - 1
            cfg["vocab_size"] = self.pad_token_id + 1
            cfg["decoder_vocab_size"] = self.pad_token_id + 1

        if cfg["vocab_size"] != self.tokenizer.vocab_size:
            raise ValueError(
                f"Original vocab size {cfg['vocab_size']} and new vocab size {self.tokenizer.vocab_size} mismatched."
            )

        self.state_keys = list(self.state_dict.keys())
        if "Wtype" in self.state_dict:
            raise ValueError("Wtype key in state dictionary")
        self._check_layer_entries()
        self.cfg = cfg
        hidden_size, intermediate_shape = self.state_dict["encoder_l1_ffn_W1"].shape
        if hidden_size != cfg["dim-emb"]:
            raise ValueError(f"Hidden size {hidden_size} and configured size {cfg['dim-emb']} mismatched")
        # Process decoder.yml
        if (source_dir / "decoder.yml").exists():
            decoder_yml = cast_marian_config(load_yaml_or_txt(source_dir / "decoder.yml"))
        else:
            decoder_yml = None

        check_marian_cfg_assumptions(cfg)
        self.hf_config = MarianConfig(
            vocab_size=cfg["vocab_size"],
            decoder_vocab_size=cfg.get("decoder_vocab_size", cfg["vocab_size"]),
            share_encoder_decoder_embeddings=cfg["tied-embeddings-src"],
            decoder_layers=cfg["dec-depth"],
            encoder_layers=cfg["enc-depth"],
            decoder_attention_heads=cfg["transformer-heads"],
            encoder_attention_heads=cfg["transformer-heads"],
            decoder_ffn_dim=cfg["transformer-dim-ffn"],
            encoder_ffn_dim=cfg["transformer-dim-ffn"],
            d_model=cfg["dim-emb"],
            activation_function=cfg["transformer-ffn-activation"],
            pad_token_id=self.pad_token_id,
            eos_token_id=eos_token_id,
            forced_eos_token_id=eos_token_id,
            bos_token_id=0,
            max_position_embeddings=cfg["dim-emb"],
            scale_embedding=True,
            normalize_embedding="n" in cfg["transformer-preprocess"],
            static_position_embeddings=not cfg["transformer-train-position-embeddings"],
            tie_word_embeddings=cfg["tied-embeddings"],
            dropout=0.1,  # see opus-mt-train repo/transformer-dropout param.
            # default: add_final_layer_norm=False,
            num_beams=decoder_yml["beam-size"] if decoder_yml is not None else 1,
            decoder_start_token_id=self.pad_token_id,
            bad_words_ids=[[self.pad_token_id]],
            max_length=512,
        )
    def _check_layer_entries(self):
        self.encoder_l1 = self.sub_keys("encoder_l1")
        self.decoder_l1 = self.sub_keys("decoder_l1")
        self.decoder_l2 = self.sub_keys("decoder_l2")
        if len(self.encoder_l1) != 16:
            warnings.warn(f"Expected 16 keys for each encoder layer, got {len(self.encoder_l1)}")
        if len(self.decoder_l1) != 26:
            warnings.warn(f"Expected 26 keys for each decoder layer, got {len(self.decoder_l1)}")
        if len(self.decoder_l2) != 26:
            warnings.warn(f"Expected 26 keys for each decoder layer, got {len(self.decoder_l2)}")
    @property
    def extra_keys(self):
        extra = []
        for k in self.state_keys:
            if (
                k.startswith("encoder_l")
                or k.startswith("decoder_l")
                or k in [CONFIG_KEY, "Wemb", "encoder_Wemb", "decoder_Wemb", "Wpos", "decoder_ff_logit_out_b"]
            ):
                continue
            else:
                extra.append(k)
        return extra
    def sub_keys(self, layer_prefix):
        return [remove_prefix(k, layer_prefix) for k in self.state_dict if k.startswith(layer_prefix)]

    def load_tokenizer(self):
        # patch the vocab files on disk, then load the tokenizer from the source directory
        self.add_special_tokens_to_vocab(self.source_dir, not self.share_encoder_decoder_embeddings)
        return MarianTokenizer.from_pretrained(str(self.source_dir))
    def add_special_tokens_to_vocab(self, model_dir: Path, separate_vocab=False) -> None:
        if separate_vocab:
            vocab = load_yaml_or_txt(find_src_vocab_file(model_dir))
            vocab = {k: int(v) for k, v in vocab.items()}
            num_added = add_to_vocab_(vocab, ["<pad>"])
            save_json(vocab, model_dir / "vocab.json")

            vocab = load_yaml_or_txt(find_tgt_vocab_file(model_dir))
            vocab = {k: int(v) for k, v in vocab.items()}
            num_added = add_to_vocab_(vocab, ["<pad>"])
            save_json(vocab, model_dir / "target_vocab.json")
            self.save_tokenizer_config(model_dir, separate_vocabs=separate_vocab)
        else:
            vocab = load_yaml_or_txt(find_vocab_file(model_dir))
            vocab = {k: int(v) for k, v in vocab.items()}
            num_added = add_to_vocab_(vocab, ["<pad>"])
            print(f"added {num_added} tokens to vocab")
            save_json(vocab, model_dir / "vocab.json")
            self.save_tokenizer_config(model_dir)
    def save_tokenizer_config(self, dest_dir: Path, separate_vocabs=False):
        dct = {
            "target_lang": self.target_lang,
            "source_lang": self.source_lang,
            "source_spm": self.source_spm,
            "target_spm": self.target_spm,
            "separate_vocabs": separate_vocabs,
        }
        save_json(dct, dest_dir / "tokenizer_config.json")
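    # Example tokenizer_config.json written above (language codes and spm file names are
    # placeholders): {"target_lang": "en", "source_lang": "fi", "source_spm": "source.spm",
    # "target_spm": "target.spm", "separate_vocabs": false}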
    def load_marian_model(self) -> MarianMTModel:
        state_dict, cfg = self.state_dict, self.hf_config

        if not cfg.static_position_embeddings:
            raise ValueError("config.static_position_embeddings should be True")
        model = MarianMTModel(cfg)

        if "hidden_size" in cfg.to_dict():
            raise ValueError("hidden_size is in config")
        load_layers_(
            model.model.encoder.layers,
            state_dict,
            BART_CONVERTER,
        )
        load_layers_(model.model.decoder.layers, state_dict, BART_CONVERTER, is_decoder=True)

        # handle tensors not associated with layers
        if self.cfg["tied-embeddings-src"]:
            wemb_tensor = nn.Parameter(torch.FloatTensor(self.wemb))
            bias_tensor = nn.Parameter(torch.FloatTensor(self.final_bias))
            assert model.model.shared is not None and model.model.encoder is not None
            model.model.shared.weight = wemb_tensor
            model.model.encoder.embed_tokens = model.model.decoder.embed_tokens = model.model.shared
        else:
            wemb_tensor = nn.Parameter(torch.FloatTensor(self.wemb))
            model.model.encoder.embed_tokens.weight = wemb_tensor

            decoder_wemb_tensor = nn.Parameter(torch.FloatTensor(self.dec_wemb))
            bias_tensor = nn.Parameter(torch.FloatTensor(self.final_bias))
            model.model.decoder.embed_tokens.weight = decoder_wemb_tensor

        model.final_logits_bias = bias_tensor

        if "Wpos" in state_dict:
            print("Unexpected: got Wpos")
            wpos_tensor = torch.tensor(state_dict["Wpos"])
            model.model.encoder.embed_positions.weight = wpos_tensor
            model.model.decoder.embed_positions.weight = wpos_tensor

        if cfg.normalize_embedding:
            if "encoder_emb_ln_scale_pre" not in state_dict:
                raise ValueError("encoder_emb_ln_scale_pre is not in state dictionary")
            raise NotImplementedError("Need to convert layernorm_embedding")

        if self.extra_keys:
            raise ValueError(f"Failed to convert {self.extra_keys}")

        if model.get_input_embeddings().padding_idx != self.pad_token_id:
            raise ValueError(
                f"Padding tokens {model.get_input_embeddings().padding_idx} and {self.pad_token_id} mismatched"
            )
        return model
def convert(source_dir: Path, dest_dir: Path | str, source_lang: str, target_lang: str, source_spm: str, target_spm: str):
    dest_dir = Path(dest_dir)
    dest_dir.mkdir(exist_ok=True)

    opus_state = OpusState(source_dir, source_lang, target_lang, source_spm, target_spm)

    # save tokenizer
    opus_state.tokenizer.save_pretrained(dest_dir)

    # save_json(opus_state.cfg, dest_dir / "marian_original_config.json")
    # ^^ Uncomment to save human readable marian config for debugging

    model = opus_state.load_marian_model()
    model = model.half()
    model.save_pretrained(dest_dir)
    model.from_pretrained(dest_dir)  # sanity check
def load_yaml_or_txt(path: Path) -> dict:
    import yaml

    if path.suffix == ".yml":
        with open(path, encoding="utf-8") as f:
            return yaml.load(f, Loader=yaml.BaseLoader)
    else:
        with open(path, encoding="utf-8") as f:
            return {k.split("\t")[0]: v for v, k in enumerate(f.read().strip().split("\n"))}
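# For a YAML vocab the "<token>: <id>" mapping is returned as-is; for a plain-text
# .vocab file (one token per line, optionally followed by a tab-separated score) the
# token's id is simply its line number. For example, assuming a two-line file:
#
#     "<unk>\t0\n</s>\t0"  ->  {"<unk>": 0, "</s>": 1}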
def save_json(content: Union[Dict, List], path: Path) -> None:
    with open(path, "w") as f:
        json.dump(content, f)
if __name__ == "__main__":
    """
    Convert an arbitrary Marian model to PyTorch.
    """
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--src", type=str, required=True, help="path to the Marian model directory")
    parser.add_argument("--src-lang", type=str, help="source language code")
    parser.add_argument("--tgt-lang", type=str, help="target language code")
    parser.add_argument("--src-spm", type=str, default="source.spm", help="source sentencepiece model")
    parser.add_argument("--tgt-spm", type=str, default="target.spm", help="target sentencepiece model")
    parser.add_argument("--dest", type=str, default=None, help="path to the output PyTorch model dir")
    args = parser.parse_args()

    source_dir = Path(args.src)
    if not source_dir.exists():
        raise ValueError(f"Source directory {source_dir} not found")
    dest_dir = f"converted-{source_dir.name}" if args.dest is None else args.dest
    convert(source_dir, dest_dir, args.src_lang, args.tgt_lang, args.src_spm, args.tgt_spm)
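# Loading the converted model afterwards (a sketch; the path and sample sentence are
# placeholders) uses the standard transformers API:
#
#     from transformers import MarianMTModel, MarianTokenizer
#     tokenizer = MarianTokenizer.from_pretrained("converted-model")
#     model = MarianMTModel.from_pretrained("converted-model")
#     batch = tokenizer(["Hyvää huomenta!"], return_tensors="pt")
#     print(tokenizer.batch_decode(model.generate(**batch), skip_special_tokens=True))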