@younesbelkada
Created December 22, 2022 22:07
Handy script to convert a `t5x` checkpoint into a flat (un-nested) dictionary of parameters, with `kernel` renamed to `weight`.
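```python
# Flatten a t5x checkpoint into a single-level dict with "/"-joined keys,
# renaming Flax "kernel" parameters to "weight".
from t5x import checkpoints
from flax.traverse_util import flatten_dict, unflatten_dict  # unflatten_dict can re-nest the result if needed


def rename_keys(key):
    # Flax dense layers name their matrix "kernel"; rename it to "weight".
    # Only the name changes: the tensor itself is not transposed.
    if "kernel" in key:
        key = key.replace("kernel", "weight")
    return key


# Local t5x checkpoint directory; replace with your own path.
flax_checkpoint_path = "/home/younes_huggingface_co/code/pix2struct/pix2struct_base"
flax_params = checkpoints.load_t5x_checkpoint(flax_checkpoint_path)

# Model parameters live under the "target" key; join nested keys with "/".
flattened_dict = flatten_dict(flax_params["target"], sep="/")

# Iterate over a snapshot of the keys, since the dict is modified in the loop.
keys = list(flattened_dict.keys())
for key in keys:
    new_key = rename_keys(key)
    tensor = flattened_dict.pop(key)
    flattened_dict[new_key] = tensor

print(flattened_dict)
```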
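To see what the flatten-and-rename step does without a real checkpoint, here is a minimal, dependency-free sketch on a toy parameter tree. The `flatten` helper is hypothetical: it only mimics `flax.traverse_util.flatten_dict(tree, sep="/")` (including dropping empty sub-dicts) and is not Flax itself. The toy tree, its key names and the shape tuples standing in for tensors are made up for illustration.

```python
from typing import Any, Dict


def flatten(tree: Dict[str, Any], sep: str = "/", prefix: str = "") -> Dict[str, Any]:
    """Hypothetical stand-in for flatten_dict: join nested keys with `sep`."""
    flat: Dict[str, Any] = {}
    for name, value in tree.items():
        path = f"{prefix}{sep}{name}" if prefix else name
        if isinstance(value, dict):
            # Recurse into sub-dicts; empty ones contribute nothing.
            flat.update(flatten(value, sep=sep, prefix=path))
        else:
            flat[path] = value
    return flat


def rename_keys(key: str) -> str:
    # Same rule as the script: any "kernel" substring becomes "weight".
    return key.replace("kernel", "weight")


# Toy tree standing in for flax_params["target"]; shapes replace real tensors.
toy_params = {
    "encoder": {
        "layers_0": {
            "attention": {"query": {"kernel": (8, 8)}},
            "pre_attention_layer_norm": {"scale": (8,)},
        }
    },
    "decoder": {"logits_dense": {"kernel": (8, 16)}},
}

renamed = {rename_keys(k): v for k, v in flatten(toy_params).items()}
for key, shape in renamed.items():
    print(key, shape)
```

A dict comprehension is used here instead of the pop-and-reinsert loop; both produce the same keys, but the comprehension builds a new dict rather than mutating the one being iterated.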