A script to sequentially convert any `t5x` checkpoint into sharded PyTorch weights, one slice of layers at a time, so the full model never has to be held in memory.
import argparse
import ctypes
import gc
import json
import os

import tensorstore as ts
import torch
from flax import serialization
from flax.traverse_util import flatten_dict, unflatten_dict
from tensorflow.io import gfile

from transformers.modeling_utils import dtype_byte_size
from transformers.models.switch_transformers.convert_switch_transformers_original_flax_checkpoint_to_pytorch import (
    rename_keys,
)
from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME
from transformers.utils.hub import convert_file_size_to_int
def rename_base_flax_keys(flax_key_tuple, flax_tensor):
    """
    Post-renaming of basic JAX keys to PyTorch.
    """
    if flax_key_tuple[-1] == "kernel" and flax_tensor.ndim == 3:
        # expert layer
        flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
        flax_tensor = torch.permute(flax_tensor, (0, 2, 1))
    elif flax_key_tuple[-1] == "kernel":
        # linear layer
        flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
        flax_tensor = flax_tensor.T
    elif flax_key_tuple[-1] in ["scale", "embedding"]:
        flax_key_tuple = flax_key_tuple[:-1] + ("weight",)

    return flax_key_tuple, flax_tensor
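
# Illustrative example (made-up key, for clarity only): a 2-D kernel keyed
# ("encoder", "mlp", "wi", "kernel") becomes ("encoder", "mlp", "wi", "weight")
# with the tensor transposed; a 3-D expert kernel keeps its leading expert
# dimension and has its last two dimensions swapped instead.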
def get_key_and_tensorstore_dict(layer, checkpoint_info, switch_checkpoint_path):
    # A flattened key looks like "<parameter path>/<spec field path>"; recover the
    # parameter name and the (nested) tensorstore spec field it belongs to.
    if "metadata" in layer:
        split_layer = layer.split("metadata")
        curr_real_layer_name = split_layer[0][:-1]
        split_layer = [tuple(("metadata" + split_layer[1]).split("/"))]
    elif "kvstore" in layer:
        split_layer = layer.split("kvstore")
        curr_real_layer_name = split_layer[0][:-1]
        split_layer = [tuple(("kvstore" + split_layer[1]).split("/"))]
    else:
        split_layer = layer.split("/")
        curr_real_layer_name = "/".join(split_layer[:-1])
        split_layer[-1] = (split_layer[-1],)

    if "kvstore/path" in layer:
        # point the spec at the actual checkpoint files on disk
        content = f"{switch_checkpoint_path}/{checkpoint_info[layer]}"
    elif "kvstore/driver" in layer:
        content = "file"
    else:
        content = checkpoint_info[layer]

    return curr_real_layer_name, split_layer, content
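
# Illustrative example (hypothetical key name): a flattened entry such as
#   "encoder/.../kernel/kvstore/path"
# yields curr_real_layer_name = "encoder/.../kernel" and
# split_layer = [("kvstore", "path")]. All entries sharing the same parameter
# name are collected in `all_layers` and later unflattened back into a
# tensorstore spec that ts.open() can read.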
def rename_and_save_block(current_block, save_path):
    # Map the flax parameter names to the PyTorch naming scheme and save the shard.
    current_block = rename_keys(current_block)
    new_current_block = {}
    for k, v in current_block.items():
        new_current_block[k.replace("/", ".")] = v
    current_block = new_current_block

    torch.save(current_block, save_path)
    return current_block
def shard_on_the_fly(switch_checkpoint_path, dump_path, max_shard_size, dtype, index_start, index_end, weights_name: str = WEIGHTS_NAME):
    max_shard_size = convert_file_size_to_int(max_shard_size)

    sharded_state_dicts = []
    current_block = {}
    current_block_size = 0
    total_size = 0

    os.makedirs(dump_path, exist_ok=True)

    # Read the msgpack checkpoint index; it maps every parameter to a
    # (flattened) tensorstore spec rather than to the weights themselves.
    with gfile.GFile(switch_checkpoint_path + "/checkpoint", "rb") as fp:
        checkpoint_info = serialization.msgpack_restore(fp.read())["optimizer"]["target"]
        checkpoint_info = flatten_dict(checkpoint_info, sep="/")

    all_layers = {}
    for layer in checkpoint_info.keys():
        curr_real_layer_name, split_layer, content = get_key_and_tensorstore_dict(
            layer, checkpoint_info, switch_checkpoint_path
        )
        if curr_real_layer_name in all_layers:
            all_layers[curr_real_layer_name][split_layer[-1]] = content
        else:
            all_layers[curr_real_layer_name] = {split_layer[-1]: content}

    # If a previous chunk already produced an index, resume from it so the
    # shard numbering and the weight map stay consistent across runs.
    if WEIGHTS_INDEX_NAME in os.listdir(dump_path):
        with open(os.path.join(dump_path, WEIGHTS_INDEX_NAME), "r") as f:
            weight_map = json.load(f)
        offset_index = sum("???" in file_name for file_name in os.listdir(dump_path)) + 1
    else:
        weight_map = {}
        offset_index = 1

    # Only convert the [index_start, index_end) fraction of the layers in this run.
    keys_to_consider = list(all_layers.keys())[int(index_start * len(all_layers)) : int(index_end * len(all_layers))]
    list_modules_to_convert = {key: value for key, value in all_layers.items() if key in keys_to_consider}

    del checkpoint_info
    gc.collect()

    dtype = getattr(torch, dtype)
    for idx, key in enumerate(list_modules_to_convert.keys()):
        # Open the tensorstore file for this parameter only and materialize it.
        args = unflatten_dict(list_modules_to_convert[key])
        raw_weights = ts.open(args, read=True).result()
        raw_weights = raw_weights.__array__()
        weight_tensor = torch.tensor(raw_weights).to(dtype)

        del args, raw_weights
        gc.collect()
        libc = ctypes.CDLL("libc.so.6")
        libc.malloc_trim(0)  # ask glibc to return freed memory to the OS

        weight_size = weight_tensor.numel() * dtype_byte_size(weight_tensor.dtype)

        # use the renaming pattern from the small conversion script
        key, weight_tensor = rename_base_flax_keys(tuple(key.split("/")), weight_tensor)
        key = "/".join(key)

        # If this weight would tip the current shard over the maximal size, we split.
        if current_block_size + weight_size > max_shard_size:
            save_path = os.path.join(
                dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts)+offset_index:05d}-of-???.bin")
            )
            current_block = rename_and_save_block(current_block, save_path)
            sharded_state_dicts.append(current_block.keys())
            libc.malloc_trim(0)
            del current_block
            gc.collect()
            current_block = {}
            current_block_size = 0

        current_block[key] = weight_tensor.clone()
        current_block_size += weight_size
        total_size += weight_size

        del weight_tensor
        gc.collect()
        libc.malloc_trim(0)
    # Add the last block
    save_path = os.path.join(
        dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts)+offset_index:05d}-of-???.bin")
    )
    rename_and_save_block(current_block, save_path)
    sharded_state_dicts.append(current_block.keys())

    # Build (or extend) the weight map: PyTorch parameter name -> shard file name.
    for idx, shard in enumerate(sharded_state_dicts):
        for key in shard:
            pytorch_key = key.replace("/", ".")
            if "weight_map" in weight_map.keys():
                weight_map["weight_map"][pytorch_key] = weights_name.replace(".bin", f"-{idx+offset_index:05d}-of-???.bin")
            else:
                weight_map[pytorch_key] = weights_name.replace(".bin", f"-{idx+offset_index:05d}-of-???.bin")

    # Add the metadata
    if "metadata" in weight_map.keys():
        # Resuming from a previous chunk: accumulate the total size.
        weight_map["metadata"]["total_size"] += total_size
        index = weight_map
    else:
        metadata = {"total_size": total_size}
        index = {"metadata": metadata, "weight_map": weight_map}

    with open(os.path.join(dump_path, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
        content = json.dumps(index, indent=2, sort_keys=True) + "\n"
        f.write(content)
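
# Once every chunk has been converted, the index file (WEIGHTS_INDEX_NAME, i.e.
# pytorch_model.bin.index.json) looks roughly like (illustrative values only):
#   {"metadata": {"total_size": 123456789},
#    "weight_map": {"encoder. ... .weight": "pytorch_model-00001-of-???.bin", ...}}
# The "???" placeholders stand in for the total shard count, which is unknown
# while converting chunk by chunk, so the shard file names have to be patched
# once the last chunk has been written.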
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--switch_t5x_checkpoint_path",
        default="/mnt/disks/disk_switch/original_checkpoints/switch-xxl-128/checkpoint_634600",
        type=str,
        required=False,
        help="Path to a directory containing a folder per layer. Follows the original Google format.",
    )
    parser.add_argument("--max_shard_size", default="10GB", required=False, help="Max shard size")
    parser.add_argument("--index_start", default=0.0, required=True, type=float, help="Start index (fraction of the layers to begin at)")
    parser.add_argument("--index_end", default=0.2, required=True, type=float, help="End index (fraction of the layers to stop at)")
    parser.add_argument("--dtype", default="bfloat16", type=str, required=False, help="dtype of the saved model")
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default="/mnt/disks/disk_switch/original_checkpoints/switch-xxl-128-converted",
        type=str,
        required=False,
        help="Path to the output PyTorch model.",
    )
    args = parser.parse_args()
    shard_on_the_fly(
        args.switch_t5x_checkpoint_path,
        args.pytorch_dump_folder_path,
        args.max_shard_size,
        args.dtype,
        args.index_start,
        args.index_end,
    )
def sanity_check():
    from transformers import SwitchTransformersForConditionalGeneration, SwitchTransformersConfig, T5Tokenizer

    config = SwitchTransformersConfig.from_pretrained("google/switch-base-8")
    config.save_pretrained("/home/younes_huggingface_co/test")
    model = SwitchTransformersForConditionalGeneration.from_pretrained(
        "/home/younes_huggingface_co/test", device_map="auto"
    )
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    text = "A <extra_id_0> walks into a bar and orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    out = model.generate(input_ids, decoder_start_token_id=0)
    print(tokenizer.decode(out[0]))
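
The conversion is meant to be run in several passes, each covering a consecutive fraction of the layers, so that only a slice of the checkpoint is held in memory at any time. A minimal sketch of such a driver, assuming the script above is saved as `convert_big_switch.py` (paths are placeholders, not the author's actual setup):

# Hypothetical driver: convert the checkpoint in five 20% slices.
from convert_big_switch import shard_on_the_fly

for start, end in [(0.0, 0.2), (0.2, 0.4), (0.4, 0.6), (0.6, 0.8), (0.8, 1.0)]:
    shard_on_the_fly(
        switch_checkpoint_path="/path/to/t5x/checkpoint",  # placeholder path
        dump_path="/path/to/pytorch_dump",                 # placeholder path
        max_shard_size="10GB",
        dtype="bfloat16",
        index_start=start,
        index_end=end,
    )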