Last active
August 25, 2024 09:56
-
-
Save MatthewK78/6d946ed5736f3222603411fb80108c41 to your computer and use it in GitHub Desktop.
Converts CLIP text encoder models to HuggingFace format. It supports merging from donor models, type conversion, and dimension mismatch handling.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
CLIP Text Encoder Converter | |
Converts CLIP text encoder models to HuggingFace format. It supports merging from donor models, type conversion, and dimension mismatch handling. | |
Key Features: | |
- SafeTensors file format support | |
- Key renaming for HuggingFace compatibility | |
- Optional donor model merging | |
- Customizable output data type and device | |
- Dimension mismatch handling (replace or interpolate) | |
Author: Matthew E. Kieren | |
Created: August 24th, 2024 | |
Version: 0.1 | |
Usage: python clip_text_converter.py [--help] [--dtype] [--device] [--overwrite] [--match] [--replace] [--interpolate] [--no_metadata] [--print_metadata] input.safetensors output.safetensors [donor.safetensors]
""" | |
from sys import exit | |
import argparse | |
import torch | |
from safetensors.torch import safe_open, save | |
import torch.nn.functional as F | |
def load_model_keys(model_path, device=None):
    """Return a mapping of tensor name -> empty meta-device tensor capturing
    each tensor's shape and dtype, without retaining the tensor data.

    Args:
        model_path: Path to a ``.safetensors`` checkpoint.
        device: Device used while reading the file. Defaults to the
            module-global CLI ``args.device`` for backward compatibility
            with existing callers that pass only ``model_path``.

    Returns:
        dict[str, torch.Tensor]: meta tensors usable for shape/dtype checks.
    """
    if device is None:
        # Preserve the original behavior of reading the argparse namespace
        # created in the __main__ block.
        device = args.device
    shapes = {}
    with safe_open(model_path, framework='pt', device=device) as f:
        for key in f.keys():
            # Fetch each tensor exactly once (the original called
            # get_tensor() twice per key) and keep only its metadata via an
            # allocation-free tensor on the 'meta' device.
            tensor = f.get_tensor(key)
            shapes[key] = torch.empty(tensor.shape, dtype=tensor.dtype, device='meta')
    return shapes
def rename_keys(input_keys):
    """Translate original CLIP text-encoder tensor names to HuggingFace names.

    Keys that match neither a transformer-layer pattern nor one of the known
    top-level names are dropped from the result.

    Args:
        input_keys: Iterable of source tensor names.

    Returns:
        dict[str, str]: {source key: HuggingFace key}.
    """
    # Exact-name translations for the non-layer tensors.
    simple_map = {
        "ln_final.weight": "text_model.final_layer_norm.weight",
        "ln_final.bias": "text_model.final_layer_norm.bias",
        "token_embedding.weight": "text_model.embeddings.token_embedding.weight",
        "positional_embedding": "text_model.embeddings.position_embedding.weight",
        "text_projection": "text_projection.weight",
        "logit_scale": "logit_scale",
    }
    # Substring rewrites applied, in order, to every transformer-layer key.
    substitutions = (
        ("attn.in_proj", "self_attn"),
        ("attn.out_proj", "self_attn.out_proj"),
        ("mlp.c_fc", "mlp.fc1"),
        ("mlp.c_proj", "mlp.fc2"),
        ("ln_1", "layer_norm1"),
        ("ln_2", "layer_norm2"),
    )
    renamed = {}
    for src in input_keys:
        if src.startswith("transformer.resblocks."):
            pieces = src.split(".")
            hf_key = f"text_model.encoder.layers.{int(pieces[2])}.{'.'.join(pieces[3:])}"
            for old, new in substitutions:
                hf_key = hf_key.replace(old, new)
            # NOTE(review): OpenAI CLIP checkpoints store the fused QKV
            # projection as "attn.in_proj_weight" (underscore), which the
            # substitution above turns into "self_attn_weight" — so these
            # dotted checks appear never to fire for such checkpoints.
            # Behavior is deliberately preserved as-is; confirm against a
            # real checkpoint before changing.
            split_spec = None
            if "self_attn.weight" in hf_key:
                split_spec = ("self_attn.weight", "weight")
            elif "self_attn.bias" in hf_key:
                split_spec = ("self_attn.bias", "bias")
            if split_spec:
                needle, suffix = split_spec
                for proj in ("q", "k", "v"):
                    renamed[f"{src}_{proj}"] = hf_key.replace(needle, f"self_attn.{proj}_proj.{suffix}")
            else:
                renamed[src] = hf_key
        elif src in simple_map:
            renamed[src] = simple_map[src]
    return renamed
def main(args):
    """Convert a CLIP text-encoder safetensors checkpoint to HuggingFace key
    layout, optionally merging missing keys from a donor checkpoint.

    Args:
        args: argparse namespace with input/output/donor paths and the
            dtype/device/overwrite/match/replace/interpolate/metadata flags.

    Raises:
        FileExistsError: if the output file exists and --overwrite not given.
    """
    # Appends '.safetensors' if missing in filename
    if not args.input.lower().endswith('.safetensors'):
        args.input += '.safetensors'
    if args.donor and not args.donor.lower().endswith('.safetensors'):
        args.donor += '.safetensors'
    if not args.output.lower().endswith('.safetensors'):
        args.output += '.safetensors'

    # Resolve the requested output dtype once; None means "keep source dtype".
    target_dtype = getattr(torch, args.dtype) if args.dtype else None

    def _fetch(f, key):
        # Load one tensor and move it to the requested device (and dtype,
        # when a conversion was requested).
        tensor = f.get_tensor(key)
        if target_dtype:
            return tensor.to(dtype=target_dtype, device=args.device)
        return tensor.to(device=args.device)

    print(f"\nDevice: {args.device}")
    if args.dtype:
        print(f" Type: {target_dtype}")
    print(f" Input: {args.input}")
    if args.donor:
        print(f" Donor: {args.donor}")
    print(f"Output: {'[OVERWRITE] ' if args.overwrite else ''}{args.output}\n")

    # Shape/dtype-only views of both checkpoints (no tensor data retained).
    input_keys = load_model_keys(args.input)
    if args.donor:
        donor_keys = load_model_keys(args.donor)

    output_tensors = {}
    metadata = {}
    with safe_open(args.input, framework='pt', device=args.device) as f:
        metadata = f.metadata()
        for input_key, new_key in rename_keys(input_keys).items():
            output_tensors[new_key] = _fetch(f, input_key)
    if args.print_metadata:
        print(f"Metadata: {metadata}\n")

    if args.donor:
        with safe_open(args.donor, framework='pt', device=args.device) as f:
            for donor_key in donor_keys:
                if donor_key in output_tensors and donor_keys[donor_key].shape != output_tensors[donor_key].shape:
                    print(f"⚠️ WARNING: Dimension mismatch for {donor_key}, {'replacing ' if args.replace else 'resizing ' if args.interpolate else ''}Input: {list(output_tensors[donor_key].shape)}{' with' if args.replace else ' to' if args.interpolate else ','} Donor: {list(donor_keys[donor_key].shape)}")
                    if args.replace:
                        # Discard the converted input tensor in favor of the donor's.
                        output_tensors[donor_key] = _fetch(f, donor_key)
                        continue
                    elif args.interpolate:
                        # Resize the input tensor to the donor's shape.
                        # Interpolation runs in float64 for precision, then
                        # casts back to the tensor's original dtype/device.
                        output_tensors[donor_key] = F.interpolate(
                            output_tensors[donor_key].unsqueeze(0).unsqueeze(0).to(dtype=torch.float64),
                            size=donor_keys[donor_key].shape,
                            mode='nearest-exact',
                        ).squeeze(0).squeeze(0).to(dtype=output_tensors[donor_key].dtype, device=output_tensors[donor_key].device)
                        continue
                if donor_key not in output_tensors:
                    # Fill in keys the input checkpoint is missing entirely.
                    output_tensors[donor_key] = _fetch(f, donor_key)

    if args.match and args.donor:
        # Drop any converted keys the donor model does not contain.
        output_tensors = {k: v for k, v in output_tensors.items() if k in donor_keys}

    print(f"\n Keys: {len(output_tensors)} total")
    # 'xb' refuses to clobber an existing file unless --overwrite was given;
    # the context manager guarantees the handle is closed (the original
    # leaked it).
    with open(args.output, 'wb' if args.overwrite else 'xb') as out_file:
        out_file.write(save(output_tensors, metadata=None if args.no_metadata else metadata))
    print(f" Saved: {args.output}\n")
if __name__ == "__main__":
    # Command-line interface: three positionals (donor optional) plus a
    # handful of switches, declared as data and registered in one loop.
    cli_spec = [
        (("input",), dict(help="Input filename [appends '.safetensors' if missing]")),
        (("output",), dict(help="Output filename [appends '.safetensors' if missing]")),
        (("donor",), dict(nargs='?', default=None, help="Donor safetensors file, for adding missing keys [optional]")),
        (("-t", "--dtype"), dict(help="Output data type [default keeps existing input/donor dtype]")),
        (("-d", "--device"), dict(default="cpu", help="Device to use for tensor operations [default is cpu]")),
        (("-ow", "--overwrite"), dict(action="store_true", help="Force overwriting output file if it already exists")),
        (("-m", "--match"), dict(action="store_true", help="Ensure output only has keys matching the donor model")),
        (("-r", "--replace"), dict(action="store_true", help="Replace input weights with donor weights when dimensions mismatch")),
        (("-i", "--interpolate"), dict(action="store_true", help="Resize mismatched dimensions of input with dimensions of donor")),
        (("-no_md", "--no_metadata"), dict(action="store_true", help="Skip copying metadata from input model")),
        (("-p", "--print_metadata"), dict(action="store_true", help="Print metadata from input file")),
    ]
    parser = argparse.ArgumentParser(description="Convert CLIP text encoder keys to HuggingFace format, discards all other keys")
    for flags, options in cli_spec:
        parser.add_argument(*flags, **options)
    # NOTE: args is intentionally module-global — load_model_keys() reads
    # args.device when called without an explicit device.
    args = parser.parse_args()
    main(args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment