"""
CLIP Text Encoder Converter
Converts OpenAI-format CLIP text encoder weights to the HuggingFace key layout. Supports merging missing keys from a donor model, dtype conversion, and dimension-mismatch handling.
Key Features:
- SafeTensors file format support
- Key renaming for HuggingFace compatibility
- Optional donor model merging
- Customizable output data type and device
- Dimension mismatch handling (replace or interpolate)
Author: Matthew E. Kieren
Created: August 24th, 2024
Version: 0.1
Usage: python clip_text_converter.py [--help] [--dtype DTYPE] [--device DEVICE] [--overwrite] [--match] [--replace] [--interpolate] [--no_metadata] [--print_metadata] input.safetensors output.safetensors [donor.safetensors]
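Example (illustrative invocation; filenames are placeholders):
    python clip_text_converter.py clip_l.safetensors clip_l_hf.safetensors donor_hf.safetensors --dtype float16 --match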
"""
import argparse

import torch
import torch.nn.functional as F
from safetensors.torch import safe_open, save


def load_model_keys(model_path, device):
    # Map each key to a meta tensor recording its shape and dtype, so models
    # can be compared later without keeping the actual weights in memory.
    with safe_open(model_path, framework='pt', device=device) as f:
        return {key: torch.empty((t := f.get_tensor(key)).shape, dtype=t.dtype, device='meta') for key in f.keys()}


def rename_keys(input_keys):
    # Map OpenAI CLIP state-dict keys to their HuggingFace equivalents. Fused
    # attention in_proj weights/biases map to three synthetic source keys
    # ("<key>_q"/"_k"/"_v") that main() resolves by slicing the fused tensor.
    renamed_keys = {}
    for input_key in input_keys:
        if input_key.startswith("transformer.resblocks."):
            parts = input_key.split(".")
            layer_num = int(parts[2])
            renamed_key = f"text_model.encoder.layers.{layer_num}.{'.'.join(parts[3:])}"
            renamed_key = renamed_key.replace("attn.in_proj_weight", "self_attn.weight")
            renamed_key = renamed_key.replace("attn.in_proj_bias", "self_attn.bias")
            renamed_key = renamed_key.replace("attn.out_proj", "self_attn.out_proj")
            renamed_key = renamed_key.replace("mlp.c_fc", "mlp.fc1")
            renamed_key = renamed_key.replace("mlp.c_proj", "mlp.fc2")
            renamed_key = renamed_key.replace("ln_1", "layer_norm1")
            renamed_key = renamed_key.replace("ln_2", "layer_norm2")
            if "self_attn.weight" in renamed_key:
                renamed_keys[f"{input_key}_q"] = renamed_key.replace("self_attn.weight", "self_attn.q_proj.weight")
                renamed_keys[f"{input_key}_k"] = renamed_key.replace("self_attn.weight", "self_attn.k_proj.weight")
                renamed_keys[f"{input_key}_v"] = renamed_key.replace("self_attn.weight", "self_attn.v_proj.weight")
            elif "self_attn.bias" in renamed_key:
                renamed_keys[f"{input_key}_q"] = renamed_key.replace("self_attn.bias", "self_attn.q_proj.bias")
                renamed_keys[f"{input_key}_k"] = renamed_key.replace("self_attn.bias", "self_attn.k_proj.bias")
                renamed_keys[f"{input_key}_v"] = renamed_key.replace("self_attn.bias", "self_attn.v_proj.bias")
            else:
                renamed_keys[input_key] = renamed_key
        elif input_key == "ln_final.weight":
            renamed_keys[input_key] = "text_model.final_layer_norm.weight"
        elif input_key == "ln_final.bias":
            renamed_keys[input_key] = "text_model.final_layer_norm.bias"
        elif input_key == "token_embedding.weight":
            renamed_keys[input_key] = "text_model.embeddings.token_embedding.weight"
        elif input_key == "positional_embedding":
            renamed_keys[input_key] = "text_model.embeddings.position_embedding.weight"
        elif input_key == "text_projection":
            renamed_keys[input_key] = "text_projection.weight"
        elif input_key == "logit_scale":
            renamed_keys[input_key] = "logit_scale"
    return renamed_keys


def main(args):
    # Append '.safetensors' if missing from a filename.
    if not args.input.lower().endswith('.safetensors'):
        args.input += '.safetensors'
    if args.donor and not args.donor.lower().endswith('.safetensors'):
        args.donor += '.safetensors'
    if not args.output.lower().endswith('.safetensors'):
        args.output += '.safetensors'

    print(f"\nDevice: {args.device}")
    if args.dtype:
        print(f"  Type: {getattr(torch, args.dtype)}")
    print(f" Input: {args.input}")
    if args.donor:
        print(f" Donor: {args.donor}")
    print(f"Output: {'[OVERWRITE] ' if args.overwrite else ''}{args.output}\n")

    input_keys = load_model_keys(args.input, args.device)
    donor_keys = load_model_keys(args.donor, args.device) if args.donor else {}

    def cast(tensor):
        # Apply the requested dtype/device; keep the source dtype if none was given.
        if args.dtype:
            return tensor.to(dtype=getattr(torch, args.dtype), device=args.device)
        return tensor.to(device=args.device)

    output_tensors = {}
    with safe_open(args.input, framework='pt', device=args.device) as f:
        metadata = f.metadata()
        for input_key, new_key in rename_keys(input_keys).items():
            if input_key not in input_keys and input_key.endswith(("_q", "_k", "_v")):
                # Synthetic key from rename_keys(): slice the fused in_proj
                # tensor into its query/key/value third.
                part = {"q": 0, "k": 1, "v": 2}[input_key[-1]]
                tensor = f.get_tensor(input_key[:-2]).chunk(3, dim=0)[part]
            elif input_key == "text_projection":
                # OpenAI stores text_projection as [embed_dim, proj_dim];
                # HuggingFace's Linear weight is the transpose.
                tensor = f.get_tensor(input_key).t().contiguous()
            else:
                tensor = f.get_tensor(input_key)
            output_tensors[new_key] = cast(tensor)

    if args.print_metadata:
        print(f"Metadata: {metadata}\n")

    if args.donor:
        with safe_open(args.donor, framework='pt', device=args.device) as f:
            for donor_key in donor_keys:
                if donor_key in output_tensors and donor_keys[donor_key].shape != output_tensors[donor_key].shape:
                    action = 'replacing ' if args.replace else 'resizing ' if args.interpolate else ''
                    joiner = ' with' if args.replace else ' to' if args.interpolate else ','
                    print(f"⚠️ WARNING: Dimension mismatch for {donor_key}, {action}Input: {list(output_tensors[donor_key].shape)}{joiner} Donor: {list(donor_keys[donor_key].shape)}")
                    if args.replace:
                        output_tensors[donor_key] = cast(f.get_tensor(donor_key))
                        continue
                    if args.interpolate:
                        # Resize in float64 with nearest-exact sampling, then
                        # cast back to the tensor's original dtype/device.
                        output_tensors[donor_key] = F.interpolate(
                            output_tensors[donor_key].unsqueeze(0).unsqueeze(0).to(dtype=torch.float64),
                            size=donor_keys[donor_key].shape,
                            mode='nearest-exact',
                        ).squeeze(0).squeeze(0).to(dtype=output_tensors[donor_key].dtype, device=output_tensors[donor_key].device)
                        continue
                if donor_key not in output_tensors:
                    output_tensors[donor_key] = cast(f.get_tensor(donor_key))

    if args.match and args.donor:
        # Keep only keys that the donor model also has.
        output_tensors = {k: v for k, v in output_tensors.items() if k in donor_keys}

    print(f"\n  Keys: {len(output_tensors)} total")
    with open(args.output, 'wb' if args.overwrite else 'xb') as out:
        out.write(save(output_tensors, metadata=None if args.no_metadata else metadata))
    print(f" Saved: {args.output}\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert CLIP text encoder keys to HuggingFace format; all other keys are discarded")
    parser.add_argument("input", help="Input filename [appends '.safetensors' if missing]")
    parser.add_argument("output", help="Output filename [appends '.safetensors' if missing]")
    parser.add_argument("donor", nargs='?', default=None, help="Donor safetensors file, for adding missing keys [optional]")
    parser.add_argument("-t", "--dtype", help="Output data type [default keeps existing input/donor dtype]")
    parser.add_argument("-d", "--device", default="cpu", help="Device to use for tensor operations [default is cpu]")
    parser.add_argument("-ow", "--overwrite", action="store_true", help="Force overwriting output file if it already exists")
    parser.add_argument("-m", "--match", action="store_true", help="Ensure output only has keys matching the donor model")
    parser.add_argument("-r", "--replace", action="store_true", help="Replace input weights with donor weights when dimensions mismatch")
    parser.add_argument("-i", "--interpolate", action="store_true", help="Resize mismatched dimensions of input to dimensions of donor")
    parser.add_argument("-no_md", "--no_metadata", action="store_true", help="Skip copying metadata from input model")
    parser.add_argument("-p", "--print_metadata", action="store_true", help="Print metadata from input file")
    args = parser.parse_args()
    main(args)
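
# Optional sanity check (illustrative sketch, not part of the script; assumes
# the `transformers` package and a ViT-L/14-sized text encoder -- adjust the
# reference model for other CLIP variants). `text_projection.weight` and
# `logit_scale` will appear as unexpected keys, since CLIPTextModel has no
# projection head:
#
#   from transformers import CLIPTextModel
#   from safetensors.torch import load_file
#
#   model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
#   result = model.load_state_dict(load_file("output.safetensors"), strict=False)
#   print("missing:", result.missing_keys)
#   print("unexpected:", result.unexpected_keys)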