Skip to content

Instantly share code, notes, and snippets.

@K024
Last active April 29, 2023 09:29
Show Gist options
  • Star 6 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save K024/4a100a0f4f4b07208958e0f3244da6ad to your computer and use it in GitHub Desktop.
Save K024/4a100a0f4f4b07208958e0f3244da6ad to your computer and use it in GitHub Desktop.
# trim.py
# trim the vocabulary of mt5 model in huggingface.co
# MIT License
# Copyright (c) 2022 K024
# %%
import torch
from tqdm.auto import tqdm
# %%
# Load the source mT5 checkpoint and its tokenizer from a local directory.
# `target` is where the trimmed model/tokenizer will be written at the end.
from transformers import T5Tokenizer
local = "./mt5-small"
target = "./mt5-trimmed"
# NOTE(review): `tokenizer` is loaded but never used below — presumably kept
# as a sanity check that `local` is a valid T5 checkpoint directory.
tokenizer = T5Tokenizer.from_pretrained(local)
# Raw state dict of the model weights (loaded directly, not via from_pretrained,
# so individual embedding rows can be sliced and re-stacked).
state_dict = torch.load(local + "/pytorch_model.bin")
# most_common = counter.most_common()[:85000]
# # 3 special tokens and 256 byte fallback
# keep_ids = sorted(set(range(259)) | set(x[0] for x in most_common))
# Sorted list of vocabulary ids to keep, precomputed elsewhere (the commented
# code above shows how: special tokens + byte fallback + most frequent pieces).
keep_ids = torch.load("./keep_ids.pth")
# %%
# Parse the original SentencePiece model protobuf, then make a working copy
# whose `pieces` list is emptied so it can be rebuilt with only the kept ids.
from sentencepiece import sentencepiece_model_pb2 as spm

proto = spm.ModelProto()
with open(local + "/spiece.model", 'rb') as f:
    proto.ParseFromString(f.read())

# %%
# Copy all model settings (normalizer, trainer spec, ...) from the original
# instead of re-reading and re-parsing the file a second time.
sp_target = spm.ModelProto()
sp_target.CopyFrom(proto)
del sp_target.pieces[:]
# %%
# Embedding matrix rows are indexed by vocabulary id; collect the rows to keep.
shared_weight = state_dict['shared.weight']
lm_head = state_dict['lm_head.weight']
shared_weight_target = []
lm_head_target = []

# %%
# 1) Keep the selected pieces together with their embedding / lm_head rows.
#    Appending in keep_ids order keeps piece index i aligned with weight row i.
for idx in tqdm(keep_ids):
    sp_target.pieces.append(proto.pieces[idx])
    shared_weight_target.append(shared_weight[idx])
    lm_head_target.append(lm_head[idx])

# 2) <extra_id_xx> sentinel tokens: mT5 places them at ids >= 250000 — keep all.
for idx in range(250000, len(proto.pieces)):
    sp_target.pieces.append(proto.pieces[idx])
    shared_weight_target.append(shared_weight[idx])
    lm_head_target.append(lm_head[idx])

# 3) Embedding rows beyond the sentencepiece vocab size are reserved for
#    additional_special_tokens — they have weights but no piece entry.
for idx in range(len(proto.pieces), len(shared_weight)):
    shared_weight_target.append(shared_weight[idx])
    lm_head_target.append(lm_head[idx])

# Re-assemble the trimmed embedding matrices (new_vocab_size x hidden_dim).
shared_weight_target = torch.stack(shared_weight_target)
lm_head_target = torch.stack(lm_head_target)
# %%
# Write the trimmed weights and tokenizer model into the target directory.
import os
os.makedirs(target, exist_ok=True)

# In (m)T5 the encoder/decoder token embeddings are tied to `shared.weight`,
# so all three entries must point at the same trimmed matrix.
state_dict['shared.weight'] = shared_weight_target
state_dict['encoder.embed_tokens.weight'] = shared_weight_target
state_dict['decoder.embed_tokens.weight'] = shared_weight_target
state_dict['lm_head.weight'] = lm_head_target
torch.save(state_dict, target + "/pytorch_model.bin")

with open(target + "/spiece.model", 'wb') as f:
    f.write(sp_target.SerializeToString())

# %%
# config.json etc. are not copied automatically — remind the user to do it.
print(f"INFO: Copy config files into '{target}' dir and change vocab_size to {len(shared_weight_target)}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment