@younesbelkada
Created November 8, 2022 15:04
A script to sequentially save any `t5x` checkpoint
import argparse
import ctypes
import gc
import json
import os

import tensorstore as ts
import torch
from flax import serialization
from flax.traverse_util import flatten_dict, unflatten_dict
from tensorflow.io import gfile

from transformers.modeling_utils import dtype_byte_size
from transformers.models.switch_transformers.convert_switch_transformers_original_flax_checkpoint_to_pytorch import (
    rename_keys,
)
from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME
from transformers.utils.hub import convert_file_size_to_int


def rename_base_flax_keys(flax_key_tuple, flax_tensor):
    """
    Post renaming of basic JAX keys to pytorch.
    """
    if flax_key_tuple[-1] == "kernel" and flax_tensor.ndim == 3:
        # expert layer
        flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
        flax_tensor = torch.permute(flax_tensor, (0, 2, 1))
    elif flax_key_tuple[-1] == "kernel" and ".".join(flax_key_tuple):
        # linear layer
        flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
        flax_tensor = flax_tensor.T
    elif flax_key_tuple[-1] in ["scale", "embedding"]:
        flax_key_tuple = flax_key_tuple[:-1] + ("weight",)

    return flax_key_tuple, flax_tensor
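
# Illustrative behaviour of `rename_base_flax_keys` (shapes are assumptions, not taken
# from a real checkpoint): a 3D "kernel" is treated as a stack of expert weights, e.g.
#   rename_base_flax_keys(("mlp", "kernel"), torch.zeros(8, 512, 2048))
# returns the key ("mlp", "weight") and a tensor permuted to shape (8, 2048, 512).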


def get_key_and_tensorstore_dict(layer, checkpoint_info, switch_checkpoint_path):
    if "metadata" in layer:
        split_layer = layer.split("metadata")
        curr_real_layer_name = "".join(split_layer[0])[:-1]
        split_layer = [tuple(("metadata" + split_layer[1]).split("/"))]
    elif "kvstore" in layer:
        split_layer = layer.split("kvstore")
        curr_real_layer_name = "".join(split_layer[0])[:-1]
        split_layer = [tuple(("kvstore" + split_layer[1]).split("/"))]
    else:
        split_layer = layer.split("/")
        curr_real_layer_name = "/".join(split_layer[:-1])
        split_layer[-1] = (split_layer[-1],)

    if "kvstore/path" in layer:
        content = f"{switch_checkpoint_path}/{checkpoint_info[layer]}"
    elif "kvstore/driver" in layer:
        content = "file"
    else:
        content = checkpoint_info[layer]

    return curr_real_layer_name, split_layer, content


def rename_and_save_block(current_block, save_path):
    current_block = rename_keys(current_block)
    new_current_block = {}
    for k, v in current_block.items():
        new_current_block[k.replace("/", ".")] = v
    current_block = new_current_block

    torch.save(current_block, save_path)
    return current_block


def shard_on_the_fly(
    switch_checkpoint_path, dump_path, max_shard_size, dtype, index_start, index_end, weights_name: str = WEIGHTS_NAME
):
    max_shard_size = convert_file_size_to_int(max_shard_size)

    sharded_state_dicts = []
    current_block = {}
    current_block_size = 0
    total_size = 0

    os.makedirs(dump_path, exist_ok=True)

    with gfile.GFile(switch_checkpoint_path + "/checkpoint", "rb") as fp:
        checkpoint_info = serialization.msgpack_restore(fp.read())["optimizer"]["target"]
        checkpoint_info = flatten_dict(checkpoint_info, sep="/")

    all_layers = {}
    for layer in checkpoint_info.keys():
        curr_real_layer_name, split_layer, content = get_key_and_tensorstore_dict(
            layer, checkpoint_info, switch_checkpoint_path
        )
        if curr_real_layer_name in all_layers:
            all_layers[curr_real_layer_name][split_layer[-1]] = content
        else:
            all_layers[curr_real_layer_name] = {split_layer[-1]: content}

    # Resume from a previous run: re-use the existing index and count the shards already written.
    if WEIGHTS_INDEX_NAME in os.listdir(dump_path):
        weight_map = json.load(open(os.path.join(dump_path, WEIGHTS_INDEX_NAME), "r"))
        offset_index = sum(["???" in file_name for file_name in os.listdir(dump_path)]) + 1
    else:
        weight_map = {}
        offset_index = 1

    # Only convert the slice of layers selected by `index_start`/`index_end`.
    keys_to_consider = list(all_layers.keys())[int(index_start * len(all_layers)) : int(index_end * len(all_layers))]
    list_modules_to_convert = {key: value for key, value in all_layers.items() if key in keys_to_consider}

    del checkpoint_info
    gc.collect()

    dtype = getattr(torch, dtype)
    libc = ctypes.CDLL("libc.so.6")

    for key in list_modules_to_convert.keys():
        # open tensorstore file
        args = unflatten_dict(list_modules_to_convert[key])
        raw_weights = ts.open(args, read=True).result()
        raw_weights = raw_weights.__array__()
        weight_tensor = torch.tensor(raw_weights).to(dtype)

        del args, raw_weights
        gc.collect()
        libc.malloc_trim(0)

        weight_size = weight_tensor.numel() * dtype_byte_size(weight_tensor.dtype)

        # use the renaming pattern from the small conversion scripts
        key, weight_tensor = rename_base_flax_keys(tuple(key.split("/")), weight_tensor)
        key = "/".join(key)

        # If this weight is going to tip over the maximal size, we split.
        if current_block_size + weight_size > max_shard_size:
            save_path = os.path.join(
                dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts) + offset_index:05d}-of-???.bin")
            )
            current_block = rename_and_save_block(current_block, save_path)
            sharded_state_dicts.append(current_block.keys())
            libc.malloc_trim(0)

            del current_block
            gc.collect()

            current_block = {}
            current_block_size = 0

        current_block[key] = weight_tensor.clone()
        current_block_size += weight_size
        total_size += weight_size

        del weight_tensor
        gc.collect()
        libc.malloc_trim(0)

    # Add the last block
    save_path = os.path.join(
        dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts) + offset_index:05d}-of-???.bin")
    )
    current_block = rename_and_save_block(current_block, save_path)
    sharded_state_dicts.append(current_block.keys())

    for idx, shard in enumerate(sharded_state_dicts):
        for key in shard:
            pytorch_key = key.replace("/", ".")
            if "weight_map" in weight_map.keys():
                weight_map["weight_map"][pytorch_key] = weights_name.replace(
                    ".bin", f"-{idx + offset_index:05d}-of-???.bin"
                )
            else:
                weight_map[pytorch_key] = weights_name.replace(".bin", f"-{idx + offset_index:05d}-of-???.bin")

    # Add the metadata
    if "metadata" in weight_map.keys():
        weight_map["metadata"]["total_size"] += total_size
        index = weight_map
    else:
        metadata = {"total_size": total_size}
        index = {"metadata": metadata, "weight_map": weight_map}

    with open(os.path.join(dump_path, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
        content = json.dumps(index, indent=2, sort_keys=True) + "\n"
        f.write(content)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--switch_t5x_checkpoint_path",
        default="/mnt/disks/disk_switch/original_checkpoints/switch-xxl-128/checkpoint_634600",
        type=str,
        required=False,
        help="Path to a directory containing a folder per layer. Follows the original Google format.",
    )
    parser.add_argument("--max_shard_size", default="10GB", required=False, help="Max shard size")
    parser.add_argument("--index_start", default=0.0, required=True, help="Start index", type=float)
    parser.add_argument("--index_end", default=0.2, required=True, help="End index", type=float)
    parser.add_argument("--dtype", default="bfloat16", type=str, required=False, help="dtype of the saved model")
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default="/mnt/disks/disk_switch/original_checkpoints/switch-xxl-128-converted",
        type=str,
        required=False,
        help="Path to the output pytorch model.",
    )
    args = parser.parse_args()
    shard_on_the_fly(
        args.switch_t5x_checkpoint_path,
        args.pytorch_dump_folder_path,
        args.max_shard_size,
        args.dtype,
        args.index_start,
        args.index_end,
    )


def sanity_check():
    from transformers import SwitchTransformersConfig, SwitchTransformersForConditionalGeneration, T5Tokenizer

    config = SwitchTransformersConfig.from_pretrained("google/switch-base-8")
    config.save_pretrained("/home/younes_huggingface_co/test")
    model = SwitchTransformersForConditionalGeneration.from_pretrained(
        "/home/younes_huggingface_co/test", device_map="auto"
    )

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    text = "A <extra_id_0> walks into a bar a orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    out = model.generate(input_ids, decoder_start_token_id=0)
    print(tokenizer.decode(out[0]))
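

# Minimal usage sketch (not part of the original gist): the paths, shard size, and
# index ranges below are illustrative assumptions. The conversion is meant to be run
# in successive slices so that each pass only holds a fraction of the layers in memory.
def _example_chunked_conversion():
    for start, end in [(0.0, 0.2), (0.2, 0.4), (0.4, 0.6), (0.6, 0.8), (0.8, 1.0)]:
        shard_on_the_fly(
            "/path/to/t5x/checkpoint_634600",  # hypothetical t5x checkpoint directory
            "/path/to/converted",  # hypothetical output directory
            "10GB",
            "bfloat16",
            start,
            end,
        )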