Skip to content

Instantly share code, notes, and snippets.

@UniDyne
Created June 13, 2023 01:10
Show Gist options
  • Save UniDyne/acd70e52e91472753cbd7a23611f44d9 to your computer and use it in GitHub Desktop.
Save UniDyne/acd70e52e91472753cbd7a23611f44d9 to your computer and use it in GitHub Desktop.
"""
# Stable Diffusion Embedding Converter
This is a simple script that converts a Textual Inversion embedding from PyTorch's pickle-based `.pt` format to the `.safetensors` format, storing the embedding tensor under the `emb_params` key. Nothing more, nothing less.
## To Use
```
$ python convert_embedding.py embeddings/myembed.pt embeddings/myembed.safetensors
Trained on v1-5-pruned-emaonly.
Trained for 6808 steps.
Dimensions of embedding: torch.Size([12, 768])
```
"""
import os
import argparse
from pathlib import Path
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
import torch
from safetensors.torch import safe_open
from safetensors.torch import save_file as safe_save
def convert(path, outpath, overwrite=False):
    """Convert a Textual Inversion ``.pt`` embedding to ``.safetensors``.

    The embedding tensor is read from ``model["string_to_param"]["*"]`` and
    written to *outpath* under the single key ``"emb_params"``. The checkpoint
    name, training step count and tensor shape are printed when available.

    Args:
        path: Path of the input ``.pt`` embedding file.
        outpath: Path of the ``.safetensors`` file to write.
        overwrite: If False (the default), refuse to replace an existing
            *outpath*.

    Raises:
        ValueError: If *outpath* already exists and *overwrite* is False, or
            if the loaded file has no tensor at ``string_to_param["*"]``.
    """
    # Check the destination before doing any (potentially slow) loading.
    if os.path.exists(outpath) and not overwrite:
        raise ValueError(
            f"Output path {outpath} already exists, and overwrite is not True"
        )

    # map_location="cpu" lets embeddings saved from a GPU session load on a
    # CPU-only machine. NOTE(review): torch.load unpickles the file, so only
    # convert embeddings obtained from sources you trust.
    model = torch.load(path, map_location="cpu")

    # Validate the expected layout up front so a wrong file fails with a clear
    # message instead of an AttributeError on None.
    params = model.get("string_to_param") if isinstance(model, dict) else None
    model_tensors = params["*"] if params is not None and "*" in params else None
    if not isinstance(model_tensors, torch.Tensor):
        raise ValueError(
            f"{path} does not look like a Textual Inversion embedding: "
            "no tensor found at string_to_param['*']"
        )

    # safetensors serializes plain contiguous tensors; detach() drops any
    # autograd / Parameter wrapper carried over from the pickle.
    s_model = {
        'emb_params': model_tensors.detach().contiguous()
    }

    # Print the checkpoint name, if defined
    checkpoint_name = model.get('sd_checkpoint_name')
    if checkpoint_name is not None:
        print(f"Trained on {checkpoint_name}.")
    else:
        print("Checkpoint name not found in the model.")

    # Print the number of training steps, if defined
    step = model.get('step')
    if step is not None:
        print(f"Trained for {step} steps.")
    else:
        print("Step not found in the model.")

    # Display the tensor shape
    print(f"Dimensions of embedding: {model_tensors.shape}")
    print()

    safe_save(s_model, outpath)
def main(args_in: Optional[List[str]] = None) -> None:
    """Command-line entry point: convert one ``.pt`` embedding to ``.safetensors``.

    Args:
        args_in: Argument list to parse instead of the real command line;
            ``None`` (the default) makes argparse read ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="Convert embedding to safetensor.")
    parser.add_argument("model", type=Path, help="Embedding .pt file input")
    # Optional so the common case only needs the input path; the output then
    # sits next to the input with a .safetensors suffix.
    parser.add_argument(
        "output",
        type=Path,
        nargs="?",
        default=None,
        help="Embedding .safetensors file output "
        "(default: input path with a .safetensors suffix)",
    )
    # Exposes convert()'s overwrite parameter, previously unreachable from the CLI.
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Replace the output file if it already exists",
    )
    args = parser.parse_args(args_in)

    output = args.output
    if output is None:
        output = args.model.with_suffix(".safetensors")
    convert(args.model, output, overwrite=args.overwrite)
if "__main__" == __name__:
    # Run the CLI only when executed as a script, never on import.
    main(args_in=None)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment