Skip to content

Instantly share code, notes, and snippets.

@fergusq
Last active April 23, 2024 13:05
Show Gist options
  • Save fergusq/bc8e4f2d3d74027df9232f05d08f2ce4 to your computer and use it in GitHub Desktop.
Save fergusq/bc8e4f2d3d74027df9232f05d08f2ce4 to your computer and use it in GitHub Desktop.
Convert Marian NMT models to PyTorch models (transformers library format)
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Copyright 2024 Iikka Hauhio (removed OPUS/Tatoeba-specific code to allow converting arbitrary models)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import warnings
from pathlib import Path
from typing import Dict, List, Union
import numpy as np
import torch
from torch import nn
from transformers import MarianConfig, MarianMTModel, MarianTokenizer
def remove_suffix(text: str, suffix: str):
    """Return ``text`` with a trailing ``suffix`` removed, if present.

    Args:
        text: String to trim.
        suffix: Ending to strip from ``text``.

    Returns:
        ``text`` without ``suffix`` when ``text`` ends with it, otherwise
        ``text`` unchanged.
    """
    # An empty suffix must be a no-op: every string "ends with" "" and
    # ``text[:-0]`` is ``text[:0]``, which would wrongly return "".
    if suffix and text.endswith(suffix):
        return text[: -len(suffix)]
    return text
def remove_prefix(text: str, prefix: str):
    """Strip ``prefix`` from the start of ``text`` when it is there.

    Returns ``text`` unchanged when it does not start with ``prefix``.
    """
    return text[len(prefix):] if text.startswith(prefix) else text
def convert_encoder_layer(opus_dict, layer_prefix: str, converter: dict):
    """Build a state dict for one layer from the entries of ``opus_dict``.

    Every key of ``opus_dict`` is printed; only keys starting with
    ``layer_prefix`` are converted.  For those, the prefix is cut off, the
    remainder is looked up in ``converter`` to get the target name, and the
    array is transposed, turned into a tensor and squeezed.

    Args:
        opus_dict: Mapping from Marian parameter names to arrays.
        layer_prefix: Prefix selecting the layer, e.g. ``"encoder_l1_"``.
        converter: Mapping from prefix-less Marian names to target names.

    Returns:
        Dict mapping target parameter names to ``torch`` tensors.
    """
    converted = {}
    prefix_len = len(layer_prefix)
    for name in opus_dict:
        print("converting", name)
        if name.startswith(layer_prefix):
            # besides embeddings, everything must be transposed.
            tensor = torch.tensor(opus_dict[name].T).squeeze()
            converted[converter[name[prefix_len:]]] = tensor
    return converted
def load_layers_(layer_lst: nn.ModuleList, opus_state: dict, converter, is_decoder=False):
    """Copy converted Marian weights into each module of ``layer_lst`` in place.

    Layer ``i`` (0-based) of ``layer_lst`` receives the entries whose Marian
    name starts with ``encoder_l{i+1}_`` (or ``decoder_l{i+1}_`` when
    ``is_decoder`` is true).  Loading is non-strict.
    """
    side = "decoder" if is_decoder else "encoder"
    for number, module in enumerate(layer_lst, start=1):
        weights = convert_encoder_layer(opus_state, f"{side}_l{number}_", converter)
        module.load_state_dict(weights, strict=False)
def add_emb_entries(wemb, final_bias, n_special_tokens=1):
    """Append zero-initialised entries for extra special tokens.

    Args:
        wemb: 2-D embedding matrix; ``n_special_tokens`` zero rows are
            appended along axis 0.
        final_bias: 2-D bias array; ``n_special_tokens`` zero columns are
            appended along axis 1.
        n_special_tokens: Number of entries to add.

    Returns:
        Tuple ``(new_embs, new_bias)`` with the enlarged arrays.
    """
    _, d_model = wemb.shape
    new_embs = np.concatenate([wemb, np.zeros((n_special_tokens, d_model))])
    # One zero column per added token.  The previous (n_special_tokens, 1)
    # shape only lined up with the bias rows when n_special_tokens == 1.
    bias_to_add = np.zeros((final_bias.shape[0], n_special_tokens))
    new_bias = np.concatenate((final_bias, bias_to_add), axis=1)
    return new_embs, new_bias
def _cast_yaml_str(v):
bool_dct = {"true": True, "false": False}
if not isinstance(v, str):
return v
elif v in bool_dct:
return bool_dct[v]
try:
return int(v)
except (TypeError, ValueError):
return v
def cast_marian_config(raw_cfg: Dict[str, str]) -> Dict:
    """Return a copy of ``raw_cfg`` with every top-level value passed through ``_cast_yaml_str``."""
    casted = {}
    for name, raw_value in raw_cfg.items():
        casted[name] = _cast_yaml_str(raw_value)
    return casted
# Key of the checkpoint entry that holds the model's YAML configuration;
# load_config_from_state_dict decodes and parses the value stored under it.
CONFIG_KEY = "special:model.yml"
def load_config_from_state_dict(opus_dict):
    """Parse the YAML configuration embedded in a Marian checkpoint.

    The entry stored under ``CONFIG_KEY`` is treated as a sequence of
    character codes.  They are joined into a string, the last character is
    dropped (presumably a trailing terminator -- confirm against the Marian
    format), and the rest is parsed with ``yaml.BaseLoader`` before the
    values are cast by ``cast_marian_config``.
    """
    import yaml

    char_codes = opus_dict[CONFIG_KEY]
    yaml_text = "".join(map(chr, char_codes))
    parsed = yaml.load(yaml_text[:-1], Loader=yaml.BaseLoader)
    return cast_marian_config(parsed)
def find_model_file(dest_dir):  # this one better
    """Locate the single ``*.npz`` file directly inside ``dest_dir``.

    Args:
        dest_dir: Directory (``str`` or ``Path``) to search, non-recursively.

    Returns:
        ``Path`` of the only ``*.npz`` file found.

    Raises:
        ValueError: If the directory contains no ``*.npz`` file, or more
            than one.
    """
    model_files = list(Path(dest_dir).glob("*.npz"))
    # Report the empty case separately; it used to be described as
    # "more than one model file: []".
    if not model_files:
        raise ValueError(f"No model file (*.npz) found in {dest_dir}")
    if len(model_files) > 1:
        raise ValueError(f"Found more than one model file: {model_files}")
    return model_files[0]
def lmap(f, x) -> List:
    """Apply ``f`` to every item of ``x`` and return the results as a list."""
    return [f(item) for item in x]
def add_to_vocab_(vocab: Dict[str, int], special_tokens: List[str]):
    """Add missing special tokens to ``vocab`` in place.

    New tokens receive consecutive ids starting right after the largest id
    already in ``vocab`` (starting at 0 when ``vocab`` is empty).  Tokens
    that are already present keep their id.

    Args:
        vocab: Token-to-id mapping; modified in place.
        special_tokens: Tokens that must exist in ``vocab``.

    Returns:
        Number of tokens that were actually added.
    """
    # ``default=-1`` makes an empty vocabulary start at id 0 instead of
    # failing inside ``max()``.
    start = max(vocab.values(), default=-1) + 1
    added = 0
    for tok in special_tokens:
        if tok in vocab:
            continue
        vocab[tok] = start + added
        added += 1
    return added
def find_vocab_file(model_dir: Path):
    """Return the first ``*.vocab*`` file in ``model_dir`` ending in ``.yml`` or ``.vocab``.

    Raises ``IndexError`` when no such file exists.
    """
    candidates = []
    for path in model_dir.glob("*.vocab*"):
        if path.suffix in [".yml", ".vocab"]:
            candidates.append(path)
    return candidates[0]
def find_src_vocab_file(model_dir: Path):
    """Return the first ``*.src.vocab*`` file in ``model_dir`` ending in ``.yml`` or ``.vocab``.

    Raises ``IndexError`` when no such file exists.
    """
    candidates = []
    for path in model_dir.glob("*.src.vocab*"):
        if path.suffix in [".yml", ".vocab"]:
            candidates.append(path)
    return candidates[0]
def find_tgt_vocab_file(model_dir: Path):
    """Return the first ``*.trg.vocab*`` file in ``model_dir`` ending in ``.yml`` or ``.vocab``.

    Raises ``IndexError`` when no such file exists.
    """
    candidates = []
    for path in model_dir.glob("*.trg.vocab*"):
        if path.suffix in [".yml", ".vocab"]:
            candidates.append(path)
    return candidates[0]
def check_marian_cfg_assumptions(marian_cfg):
    """Verify that ``marian_cfg`` uses the settings this converter relies on.

    Settings are checked in a fixed order and the first one whose value
    differs from the expected value aborts the check.

    Raises:
        ValueError: When a setting has an unexpected value.
        KeyError: When ``marian_cfg`` lacks one of the checked settings.
    """
    expected_settings = (
        ("layer-normalization", False),
        ("right-left", False),
        ("transformer-ffn-depth", 2),
        ("transformer-aan-depth", 2),
        ("transformer-no-projection", False),
        ("transformer-postprocess-emb", "d"),
        ("transformer-postprocess", "dan"),  # Dropout, add, normalize
        ("transformer-preprocess", ""),
        ("type", "transformer"),
        ("ulr-dim-emb", 0),
        ("dec-cell-base-depth", 2),
        ("dec-cell-high-depth", 1),
        ("transformer-aan-nogate", False),
    )
    for setting, expected in expected_settings:
        found = marian_cfg[setting]
        if found != expected:
            raise ValueError(f"Unexpected config value for {setting} expected {expected} got {found}")
# Name of the checkpoint entry that OpusState reads as the final output bias.
BIAS_KEY = "decoder_ff_logit_out_b"
# Maps a Marian per-layer parameter name (with the "encoder_l<N>_" /
# "decoder_l<N>_" prefix already removed) to the parameter name expected by
# the corresponding layer module's state dict.
BART_CONVERTER = { # for each encoder and decoder layer
# Self-attention projections and their biases
"self_Wq": "self_attn.q_proj.weight",
"self_Wk": "self_attn.k_proj.weight",
"self_Wv": "self_attn.v_proj.weight",
"self_Wo": "self_attn.out_proj.weight",
"self_bq": "self_attn.q_proj.bias",
"self_bk": "self_attn.k_proj.bias",
"self_bv": "self_attn.v_proj.bias",
"self_bo": "self_attn.out_proj.bias",
# Layer norm after self-attention
"self_Wo_ln_scale": "self_attn_layer_norm.weight",
"self_Wo_ln_bias": "self_attn_layer_norm.bias",
# Feed-forward sub-layer and its layer norm
"ffn_W1": "fc1.weight",
"ffn_b1": "fc1.bias",
"ffn_W2": "fc2.weight",
"ffn_b2": "fc2.bias",
"ffn_ffn_ln_scale": "final_layer_norm.weight",
"ffn_ffn_ln_bias": "final_layer_norm.bias",
# Decoder Cross Attention
"context_Wk": "encoder_attn.k_proj.weight",
"context_Wo": "encoder_attn.out_proj.weight",
"context_Wq": "encoder_attn.q_proj.weight",
"context_Wv": "encoder_attn.v_proj.weight",
"context_bk": "encoder_attn.k_proj.bias",
"context_bo": "encoder_attn.out_proj.bias",
"context_bq": "encoder_attn.q_proj.bias",
"context_bv": "encoder_attn.v_proj.bias",
"context_Wo_ln_scale": "encoder_attn_layer_norm.weight",
"context_Wo_ln_bias": "encoder_attn_layer_norm.bias",
}
class OpusState:
    """A loaded Marian checkpoint plus the tokenizer and config derived from it.

    Construction reads the single ``*.npz`` file in ``source_dir``, parses the
    YAML configuration embedded in it, writes ``vocab.json`` (and, for
    separate vocabularies, ``target_vocab.json``) together with
    ``tokenizer_config.json`` into ``source_dir``, loads a ``MarianTokenizer``
    from that directory and builds a ``MarianConfig``.  ``load_marian_model``
    then instantiates the PyTorch model and copies the weights into it.
    """

    def __init__(self, source_dir, source_lang: str, target_lang: str, source_spm: str, target_spm: str, eos_token_id=0):
        """Load the checkpoint in ``source_dir`` and prepare tokenizer and config.

        Args:
            source_dir: Directory holding the ``.npz`` model, the vocabulary
                file(s) and optionally ``decoder.yml``.
            source_lang: Stored as ``source_lang`` in ``tokenizer_config.json``.
            target_lang: Stored as ``target_lang`` in ``tokenizer_config.json``.
            source_spm: Stored as ``source_spm`` in ``tokenizer_config.json``.
            target_spm: Stored as ``target_spm`` in ``tokenizer_config.json``.
            eos_token_id: Kept for backward compatibility; the value is
                replaced by the tokenizer's EOS id (or 0 when it has none).

        Raises:
            ValueError: On unsupported checkpoint contents (differing
                vocabulary sizes, ``Wpos``/``Wtype`` entries, vocabulary or
                hidden-size mismatches, unexpected Marian settings).
        """
        # Accept plain strings as well; the "/" joins and globbing need a Path.
        source_dir = Path(source_dir)
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.source_spm = source_spm
        self.target_spm = target_spm
        npz_path = find_model_file(source_dir)
        # Materialise every array, then close the archive instead of leaving
        # the file handle open for the lifetime of the object.
        with np.load(npz_path) as npz:
            self.state_dict = dict(npz)
        cfg = load_config_from_state_dict(self.state_dict)
        if cfg["dim-vocabs"][0] != cfg["dim-vocabs"][1]:
            raise ValueError(f"Source and target vocabulary sizes differ: {cfg['dim-vocabs']}")
        if "Wpos" in self.state_dict:
            raise ValueError("Wpos key in state dictionary")
        if cfg["tied-embeddings-all"]:
            cfg["tied-embeddings-src"] = True
            cfg["tied-embeddings"] = True
        self.share_encoder_decoder_embeddings = cfg["tied-embeddings-src"]

        # create the tokenizer here because we need to know the eos_token_id
        self.source_dir = source_dir
        self.tokenizer = self.load_tokenizer()
        # retrieve EOS token and set correctly
        tokenizer_has_eos_token_id = (
            hasattr(self.tokenizer, "eos_token_id") and self.tokenizer.eos_token_id is not None
        )
        eos_token_id = self.tokenizer.eos_token_id if tokenizer_has_eos_token_id else 0

        if cfg["tied-embeddings-src"]:
            self.wemb, self.final_bias = add_emb_entries(self.state_dict["Wemb"], self.state_dict[BIAS_KEY], 1)
            # The entry appended by add_emb_entries (last row) is used as padding id.
            self.pad_token_id = self.wemb.shape[0] - 1
            cfg["vocab_size"] = self.pad_token_id + 1
        else:
            self.wemb, _ = add_emb_entries(self.state_dict["encoder_Wemb"], self.state_dict[BIAS_KEY], 1)
            self.dec_wemb, self.final_bias = add_emb_entries(
                self.state_dict["decoder_Wemb"], self.state_dict[BIAS_KEY], 1
            )
            # still assuming that vocab size is same for encoder and decoder
            self.pad_token_id = self.wemb.shape[0] - 1
            cfg["vocab_size"] = self.pad_token_id + 1
            cfg["decoder_vocab_size"] = self.pad_token_id + 1

        if cfg["vocab_size"] != self.tokenizer.vocab_size:
            raise ValueError(
                f"Original vocab size {cfg['vocab_size']} and new vocab size {len(self.tokenizer.encoder)} mismatched."
            )

        self.state_keys = list(self.state_dict.keys())
        if "Wtype" in self.state_dict:
            raise ValueError("Wtype key in state dictionary")
        self._check_layer_entries()
        self.cfg = cfg
        hidden_size, _ = self.state_dict["encoder_l1_ffn_W1"].shape
        if hidden_size != cfg["dim-emb"]:
            # The key is "dim-emb"; the former "dim_emb" lookup raised KeyError
            # while building this very message.
            raise ValueError(f"Hidden size {hidden_size} and configured size {cfg['dim-emb']} mismatched")

        # Process decoder.yml
        if (source_dir / "decoder.yml").exists():
            decoder_yml = cast_marian_config(load_yaml_or_txt(source_dir / "decoder.yml"))
        else:
            decoder_yml = None
        check_marian_cfg_assumptions(cfg)
        self.hf_config = MarianConfig(
            vocab_size=cfg["vocab_size"],
            decoder_vocab_size=cfg.get("decoder_vocab_size", cfg["vocab_size"]),
            share_encoder_decoder_embeddings=cfg["tied-embeddings-src"],
            decoder_layers=cfg["dec-depth"],
            encoder_layers=cfg["enc-depth"],
            decoder_attention_heads=cfg["transformer-heads"],
            encoder_attention_heads=cfg["transformer-heads"],
            decoder_ffn_dim=cfg["transformer-dim-ffn"],
            encoder_ffn_dim=cfg["transformer-dim-ffn"],
            d_model=cfg["dim-emb"],
            activation_function=cfg["transformer-ffn-activation"],
            pad_token_id=self.pad_token_id,
            eos_token_id=eos_token_id,
            forced_eos_token_id=eos_token_id,
            bos_token_id=0,
            max_position_embeddings=cfg["dim-emb"],
            scale_embedding=True,
            normalize_embedding="n" in cfg["transformer-preprocess"],
            static_position_embeddings=not cfg["transformer-train-position-embeddings"],
            tie_word_embeddings=cfg["tied-embeddings"],
            dropout=0.1,  # see opus-mt-train repo/transformer-dropout param.
            # default: add_final_layer_norm=False,
            num_beams=decoder_yml["beam-size"] if decoder_yml is not None else 1,
            decoder_start_token_id=self.pad_token_id,
            bad_words_ids=[[self.pad_token_id]],
            max_length=512,
        )

    def _check_layer_entries(self):
        """Record the per-layer key names and warn when their count is unexpected."""
        self.encoder_l1 = self.sub_keys("encoder_l1")
        self.decoder_l1 = self.sub_keys("decoder_l1")
        self.decoder_l2 = self.sub_keys("decoder_l2")
        if len(self.encoder_l1) != 16:
            warnings.warn(f"Expected 16 keys for each encoder layer, got {len(self.encoder_l1)}")
        if len(self.decoder_l1) != 26:
            warnings.warn(f"Expected 26 keys for each decoder layer, got {len(self.decoder_l1)}")
        if len(self.decoder_l2) != 26:
            # Report the count of the layer that was actually checked.
            warnings.warn(f"Expected 26 keys for each decoder layer, got {len(self.decoder_l2)}")

    @property
    def extra_keys(self):
        """Checkpoint keys that are neither layer weights nor otherwise handled."""
        extra = []
        for k in self.state_keys:
            if (
                k.startswith("encoder_l")
                or k.startswith("decoder_l")
                or k in [CONFIG_KEY, "Wemb", "encoder_Wemb", "decoder_Wemb", "Wpos", "decoder_ff_logit_out_b"]
            ):
                continue
            else:
                extra.append(k)
        return extra

    def sub_keys(self, layer_prefix):
        """Return the checkpoint keys starting with ``layer_prefix``, with that prefix removed."""
        return [remove_prefix(k, layer_prefix) for k in self.state_dict if k.startswith(layer_prefix)]

    def load_tokenizer(self):
        """Write the tokenizer files into ``source_dir`` and load a ``MarianTokenizer`` from it."""
        # save tokenizer
        self.add_special_tokens_to_vocab(self.source_dir, not self.share_encoder_decoder_embeddings)
        return MarianTokenizer.from_pretrained(str(self.source_dir))

    def add_special_tokens_to_vocab(self, model_dir: Path, separate_vocab=False) -> None:
        """Add ``<pad>`` to the vocabulary file(s) and save them as JSON in ``model_dir``.

        With ``separate_vocab`` the source vocabulary is written to
        ``vocab.json`` and the target vocabulary to ``target_vocab.json``;
        otherwise the single shared vocabulary goes to ``vocab.json``.
        ``tokenizer_config.json`` is written in both cases.
        """
        print(separate_vocab)
        if separate_vocab:
            vocab = load_yaml_or_txt(find_src_vocab_file(model_dir))
            vocab = {k: int(v) for k, v in vocab.items()}
            add_to_vocab_(vocab, ["<pad>"])
            save_json(vocab, model_dir / "vocab.json")

            vocab = load_yaml_or_txt(find_tgt_vocab_file(model_dir))
            vocab = {k: int(v) for k, v in vocab.items()}
            add_to_vocab_(vocab, ["<pad>"])
            save_json(vocab, model_dir / "target_vocab.json")
            self.save_tokenizer_config(model_dir, separate_vocabs=separate_vocab)
        else:
            vocab = load_yaml_or_txt(find_vocab_file(model_dir))
            vocab = {k: int(v) for k, v in vocab.items()}
            num_added = add_to_vocab_(vocab, ["<pad>"])
            print(f"added {num_added} tokens to vocab")
            save_json(vocab, model_dir / "vocab.json")
            self.save_tokenizer_config(model_dir)

    def save_tokenizer_config(self, dest_dir: Path, separate_vocabs=False):
        """Write ``tokenizer_config.json`` with the languages, SPM names and vocab mode."""
        dct = {
            "target_lang": self.target_lang,
            "source_lang": self.source_lang,
            "source_spm": self.source_spm,
            "target_spm": self.target_spm,
            "separate_vocabs": separate_vocabs,
        }
        save_json(dct, dest_dir / "tokenizer_config.json")

    def load_marian_model(self) -> MarianMTModel:
        """Create a ``MarianMTModel`` from ``hf_config`` and fill it with the checkpoint weights.

        Raises:
            ValueError: If position embeddings are not static, the config has
                a ``hidden_size`` entry, the checkpoint has keys that were
                not converted, or the padding index does not match.
            NotImplementedError: If embedding layer normalisation is enabled.
        """
        state_dict, cfg = self.state_dict, self.hf_config

        if not cfg.static_position_embeddings:
            raise ValueError("config.static_position_embeddings should be True")
        model = MarianMTModel(cfg)

        if "hidden_size" in cfg.to_dict():
            raise ValueError("hidden_size is in config")
        load_layers_(
            model.model.encoder.layers,
            state_dict,
            BART_CONVERTER,
        )
        load_layers_(model.model.decoder.layers, state_dict, BART_CONVERTER, is_decoder=True)

        # handle tensors not associated with layers
        if self.cfg["tied-embeddings-src"]:
            wemb_tensor = nn.Parameter(torch.FloatTensor(self.wemb))
            bias_tensor = nn.Parameter(torch.FloatTensor(self.final_bias))
            assert model.model.shared is not None and model.model.encoder is not None
            model.model.shared.weight = wemb_tensor
            model.model.encoder.embed_tokens = model.model.decoder.embed_tokens = model.model.shared
        else:
            wemb_tensor = nn.Parameter(torch.FloatTensor(self.wemb))
            model.model.encoder.embed_tokens.weight = wemb_tensor

            decoder_wemb_tensor = nn.Parameter(torch.FloatTensor(self.dec_wemb))
            bias_tensor = nn.Parameter(torch.FloatTensor(self.final_bias))
            model.model.decoder.embed_tokens.weight = decoder_wemb_tensor

        model.final_logits_bias = bias_tensor

        # NOTE: __init__ rejects checkpoints containing "Wpos", so this branch
        # is not expected to run for instances built through the constructor.
        if "Wpos" in state_dict:
            print("Unexpected: got Wpos")
            wpos_tensor = torch.tensor(state_dict["Wpos"])
            model.model.encoder.embed_positions.weight = wpos_tensor
            model.model.decoder.embed_positions.weight = wpos_tensor

        if cfg.normalize_embedding:
            if "encoder_emb_ln_scale_pre" not in state_dict:
                raise ValueError("encoder_emb_ln_scale_pre is not in state dictionary")
            raise NotImplementedError("Need to convert layernorm_embedding")

        if self.extra_keys:
            raise ValueError(f"Failed to convert {self.extra_keys}")

        if model.get_input_embeddings().padding_idx != self.pad_token_id:
            raise ValueError(
                f"Padding tokens {model.get_input_embeddings().padding_idx} and {self.pad_token_id} mismatched"
            )
        return model
def convert(source_dir: Path, dest_dir: Path | str, source_lang: str, target_lang: str, source_spm: str, target_spm: str):
    """Convert the Marian model in ``source_dir`` and save it to ``dest_dir``.

    The tokenizer and the half-precision model are both written to
    ``dest_dir``; the saved model is loaded back once as a sanity check.

    Args:
        source_dir: Directory with the Marian ``.npz`` model and vocabulary.
        dest_dir: Output directory; created (including parents) if missing.
        source_lang: Source language passed on to ``OpusState``.
        target_lang: Target language passed on to ``OpusState``.
        source_spm: Source SentencePiece model name passed on to ``OpusState``.
        target_spm: Target SentencePiece model name passed on to ``OpusState``.
    """
    source_dir = Path(source_dir)
    dest_dir = Path(dest_dir)
    # parents=True so a nested output path does not fail on a missing parent.
    dest_dir.mkdir(parents=True, exist_ok=True)
    opus_state = OpusState(source_dir, source_lang, target_lang, source_spm, target_spm)
    # save tokenizer
    opus_state.tokenizer.save_pretrained(dest_dir)
    # save_json(opus_state.cfg, dest_dir / "marian_original_config.json")
    # ^^ Uncomment to save human readable marian config for debugging
    model = opus_state.load_marian_model()
    model = model.half()
    model.save_pretrained(dest_dir)
    model.from_pretrained(dest_dir)  # sanity check
def load_yaml_or_txt(path: Path) -> dict:
    """Load a vocabulary/config mapping from a ``.yml`` file or a plain-text file.

    ``.yml`` files are parsed with PyYAML's ``BaseLoader``.  Any other file is
    read as text with one entry per line: the part of each line before the
    first tab becomes the key and the 0-based line number becomes the value.

    Args:
        path: File to read (``Path`` or ``str``), UTF-8 encoded.

    Returns:
        The parsed YAML content, or the ``{token: line_index}`` mapping; an
        empty dict for an empty text file.
    """
    path = Path(path)
    if path.suffix == ".yml":
        # Imported lazily so plain-text vocabularies do not need PyYAML.
        import yaml

        with open(path, encoding="utf-8") as f:
            return yaml.load(f, Loader=yaml.BaseLoader)

    with open(path, encoding="utf-8") as f:
        text = f.read().strip()
    # "".split("\n") is [""], which used to yield the bogus entry {"": 0}.
    if not text:
        return {}
    return {line.split("\t")[0]: index for index, line in enumerate(text.split("\n"))}
def save_json(content: Union[Dict, List], path: Path) -> None:
    """Serialise ``content`` as JSON into the file at ``path``, overwriting it."""
    with open(path, mode="w") as handle:
        json.dump(content, fp=handle)
if __name__ == "__main__":
    # Command-line entry point: parse the arguments and run the conversion.
    parser = argparse.ArgumentParser(description="Convert arbitrary Marian model to PyTorch.")
    # Required parameters
    # --src is mandatory: without it Path(None) below fails with a TypeError.
    parser.add_argument("--src", type=str, required=True, help="path to marian model sub dir")
    parser.add_argument("--src-lang", type=str, help="source lang")
    parser.add_argument("--tgt-lang", type=str, help="target lang")
    parser.add_argument("--src-spm", type=str, default="source.spm", help="source sentencepiece model")
    parser.add_argument("--tgt-spm", type=str, default="target.spm", help="target sentencepiece model")
    parser.add_argument("--dest", type=str, default=None, help="Path to the output PyTorch model dir")
    args = parser.parse_args()

    source_dir = Path(args.src)
    if not source_dir.exists():
        raise ValueError(f"Source directory {source_dir} not found")
    # Default output directory is derived from the source directory's name.
    dest_dir = f"converted-{source_dir.name}" if args.dest is None else args.dest
    convert(source_dir, dest_dir, args.src_lang, args.tgt_lang, args.src_spm, args.tgt_spm)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment