@alumae
Created March 20, 2023 13:40
Converts a Whisper model fine-tuned with ESPnet into Hugging Face Transformers format (WhisperForConditionalGeneration), and saves the matching tokenizer and processor alongside it.
import argparse
from collections import OrderedDict

import torch
from torch import nn
from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor, WhisperTokenizer


def remove_ignore_keys_(state_dict):
    # Drop top-level keys that have no counterpart in the Transformers model.
    ignore_keys = ["layers", "blocks"]
    for k in ignore_keys:
        state_dict.pop(k, None)

WHISPER_MAPPING = OrderedDict([
    ("decoder.decoders", "decoder"),
    ("encoder.encoders", "encoder"),
    ("blocks", "layers"),
    ("mlp.0", "fc1"),
    ("mlp.2", "fc2"),
    ("mlp_ln", "final_layer_norm"),
    (".attn.query", ".self_attn.q_proj"),
    (".attn.key", ".self_attn.k_proj"),
    (".attn.value", ".self_attn.v_proj"),
    (".attn_ln", ".self_attn_layer_norm"),
    (".attn.out", ".self_attn.out_proj"),
    (".cross_attn.query", ".encoder_attn.q_proj"),
    (".cross_attn.key", ".encoder_attn.k_proj"),
    (".cross_attn.value", ".encoder_attn.v_proj"),
    (".cross_attn_ln", ".encoder_attn_layer_norm"),
    (".cross_attn.out", ".encoder_attn.out_proj"),
    ("decoder.ln.", "decoder.layer_norm."),
    ("encoder.ln.", "encoder.layer_norm."),
    ("token_embedding", "embed_tokens"),
    ("encoder.positional_embedding", "encoder.embed_positions.weight"),
    ("decoder.positional_embedding", "decoder.embed_positions.weight"),
    ("ln_post", "layer_norm"),
])

def rename_keys(s_dict):
    # Rewrite every state-dict key according to WHISPER_MAPPING (substring replacement).
    keys = list(s_dict.keys())
    for key in keys:
        new_key = key
        for k, v in WHISPER_MAPPING.items():
            if k in new_key:
                new_key = new_key.replace(k, v)
        print(f"{key} -> {new_key}")
        s_dict[new_key] = s_dict.pop(key)
    return s_dict

def make_linear_from_emb(emb):
    # Build a bias-free output projection whose weights are tied to the token embedding.
    vocab_size, emb_size = emb.weight.shape
    lin_layer = nn.Linear(emb_size, vocab_size, bias=False)
    lin_layer.weight.data = emb.weight.data
    return lin_layer

def convert_espnet_whisper_to_tfms(espnet_checkpoint, pytorch_dump_folder_path, whisper_config_id):
    state_dict = torch.load(espnet_checkpoint, map_location="cpu")
    # The ESPnet checkpoint stores the output projection as the decoder token embedding.
    proj_out_weights = state_dict["decoder.decoders.token_embedding.weight"]
    remove_ignore_keys_(state_dict)
    rename_keys(state_dict)
    tie_embeds = True

    config = WhisperConfig.from_pretrained(whisper_config_id)
    model = WhisperForConditionalGeneration(config)

    # Load non-strictly; only the positional embeddings are allowed to be missing.
    missing, unexpected = model.model.load_state_dict(state_dict, strict=False)
    if len(missing) > 0 and not set(missing) <= {
        "encoder.embed_positions.weight",
        "decoder.embed_positions.weight",
    }:
        raise ValueError(
            "Only `encoder.embed_positions.weight` and `decoder.embed_positions.weight` are allowed to be missing,"
            f" but all the following weights are missing {missing}"
        )

    if tie_embeds:
        model.proj_out = make_linear_from_emb(model.model.decoder.embed_tokens)
    else:
        model.proj_out.weight.data = proj_out_weights

    model.save_pretrained(pytorch_dump_folder_path)

    tokenizer = WhisperTokenizer.from_pretrained(whisper_config_id)
    tokenizer.save_pretrained(pytorch_dump_folder_path)
    processor = WhisperProcessor.from_pretrained(whisper_config_id)
    processor.save_pretrained(pytorch_dump_folder_path)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--whisper-config-id", required=True, type=str, help="Whisper config ID, e.g. openai/whisper-medium")
    parser.add_argument("--espnet_checkpoint", required=True, type=str, help="Path to the ESPnet model checkpoint")
    parser.add_argument("--pytorch_dump_folder_path", required=True, type=str, help="Path to the output PyTorch model in Hugging Face format")
    args = parser.parse_args()

    convert_espnet_whisper_to_tfms(args.espnet_checkpoint, args.pytorch_dump_folder_path, args.whisper_config_id)
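
For reference, a minimal usage sketch. The checkpoint path and output folder below are hypothetical placeholders, and the example assumes the functions defined above are in scope and that the model was fine-tuned from openai/whisper-medium; adjust all three to your own setup.

from transformers import WhisperForConditionalGeneration, WhisperProcessor

# Hypothetical paths -- substitute your own ESPnet experiment and output directory.
convert_espnet_whisper_to_tfms(
    "exp/whisper_finetune/valid.acc.ave.pth",  # --espnet_checkpoint (hypothetical)
    "whisper-medium-espnet-hf",                # --pytorch_dump_folder_path (hypothetical)
    "openai/whisper-medium",                   # --whisper-config-id
)

# The dump folder can then be loaded like any other Transformers Whisper checkpoint.
model = WhisperForConditionalGeneration.from_pretrained("whisper-medium-espnet-hf")
processor = WhisperProcessor.from_pretrained("whisper-medium-espnet-hf")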