Created
December 22, 2022 22:07
-
-
Save younesbelkada/e0bbf0da613e496925baf85294622da1 to your computer and use it in GitHub Desktop.
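Handy script that loads any `t5x` checkpoint and flattens its nested parameter tree into a single-level (un-nested) dictionary, renaming `kernel` keys to `weight` along the way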
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from t5x import checkpoints
from flax.traverse_util import flatten_dict, unflatten_dict
def rename_keys(key: str) -> str:
    """Return ``key`` with every occurrence of "kernel" replaced by "weight".

    Args:
        key: A flattened parameter name (a "/"-joined path in this script).

    Returns:
        The renamed key; a key that does not contain "kernel" is returned
        unchanged.
    """
    # str.replace is already a no-op when the substring is absent, so no
    # membership check is needed beforehand.
    return key.replace("kernel", "weight")
# Location of the t5x checkpoint to load (hard-coded, machine-specific path).
flax_checkpoint_path = "/home/younes_huggingface_co/code/pix2struct/pix2struct_base"

# Load the checkpoint and flatten its 'target' subtree so that every entry is
# addressed by a single "/"-separated string key.
flax_params = checkpoints.load_t5x_checkpoint(flax_checkpoint_path)
flattened_dict = flatten_dict(flax_params["target"], sep="/")

# Snapshot the keys first: the dict is mutated while we walk over them.
for old_name in list(flattened_dict):
    value = flattened_dict.pop(old_name)
    # Re-insert under the renamed key ("kernel" -> "weight").
    flattened_dict[rename_keys(old_name)] = value

print(flattened_dict)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment