@miabrahams
Created August 19, 2023 21:35
Jupyter notebook to convert .safetensors to diffusers
{
"cells": [
{
"cell_type": "markdown",
"id": "ddb4c3fd",
"metadata": {},
"source": [
"## Convert .safetensors to Diffusers"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "08597be1",
"metadata": {},
"outputs": [],
"source": [
"import argparse\n",
"import importlib\n",
"\n",
"import torch\n",
"from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt\n",
"\n",
"\n",
"class Args:\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "599eb851",
"metadata": {},
"outputs": [],
"source": [
"args = Args()\n",
"\n",
"args.checkpoint_path = r\"D:\\AI\\automatic\\models\\Stable-diffusion\\realisticVisionV51.safetensors\"\n",
"args.original_config_file = r\"D:\\AI\\automatic\\configs\\v1-inference.yaml\"\n",
"args.dump_path= r\"D:\\AI\\invokeai\\models\\realisticVisionv51\"\n",
"args.image_size = None\n",
"args.prediction_type=None\n",
"args.pipeline_type=None\n",
"args.extract_ema=True\n",
"args.scheduler_type=\"ddim\"\n",
"args.num_in_channels=None\n",
"args.upcast_attention=True\n",
"args.from_safetensors=True\n",
"args.to_safetensors=True\n",
"args.device=\"cuda:0\"\n",
"args.stable_unclip=None\n",
"args.stable_unclip_prior=None\n",
"args.clip_stats_path=None\n",
"args.controlnet=False\n",
"args.vae_path=None\n",
"args.pipeline_class_name = None\n",
"args.half = True\n",
"\n",
"\n",
"if args.pipeline_class_name:\n",
" library = importlib.import_module(\"diffusers\")\n",
" # Show all pipelines\n",
" for p in dir(library):\n",
" print(p) if \"Pipeline\" in p else None\n",
" class_obj = getattr(library, args.pipeline_class_name)\n",
" pipeline_class = class_obj\n",
"else:\n",
" pipeline_class = None\n",
"\n",
"# pipeline_class=library.StableDiffusionPipeline\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "087c4b1e",
"metadata": {},
"outputs": [],
"source": [
"pipe = download_from_original_stable_diffusion_ckpt(\n",
" checkpoint_path=args.checkpoint_path,\n",
" original_config_file=args.original_config_file,\n",
" image_size=args.image_size,\n",
" prediction_type=args.prediction_type,\n",
" model_type=args.pipeline_type,\n",
" extract_ema=args.extract_ema,\n",
" scheduler_type=args.scheduler_type,\n",
" num_in_channels=args.num_in_channels,\n",
" upcast_attention=args.upcast_attention,\n",
" from_safetensors=args.from_safetensors,\n",
" device=args.device,\n",
" stable_unclip=args.stable_unclip,\n",
" stable_unclip_prior=args.stable_unclip_prior,\n",
" clip_stats_path=args.clip_stats_path,\n",
" controlnet=args.controlnet,\n",
" vae_path=args.vae_path,\n",
" pipeline_class=pipeline_class,\n",
")\n",
"\n",
"if args.half:\n",
" pipe.to(torch_dtype=torch.float16)\n",
"\n",
"if args.controlnet:\n",
" # only save the controlnet model\n",
" pipe.controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)\n",
"else:\n",
" pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)\n"
]
},
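{
"cell_type": "markdown",
"id": "f2a41c9b",
"metadata": {},
"source": [
"### Sanity-check the converted pipeline\n",
"\n",
"A minimal sketch to confirm the folder written above loads back into diffusers: reload it with `StableDiffusionPipeline.from_pretrained` and run one short generation. It assumes the plain SD v1 case (`pipeline_class_name` left as `None`) and float16 weights (`args.half = True`); the prompt and step count are only placeholders."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4de8f10",
"metadata": {},
"outputs": [],
"source": [
"from diffusers import StableDiffusionPipeline\n",
"\n",
"# Reload the converted folder; float16 matches args.half = True above\n",
"test_pipe = StableDiffusionPipeline.from_pretrained(args.dump_path, torch_dtype=torch.float16)\n",
"test_pipe.to(\"cuda\")\n",
"\n",
"# One quick generation as a smoke test (placeholder prompt)\n",
"image = test_pipe(\"a photo of an astronaut riding a horse\", num_inference_steps=20).images[0]\n",
"image"
]
},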
{
"cell_type": "markdown",
"id": "7308aa50",
"metadata": {},
"source": [
"### Check available pipelines"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dc2e94c5",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"# Show all pipelines\n",
"library = importlib.import_module(\"diffusers\")\n",
"\n",
"for p in dir(library):\n",
" print(p) if \"Pipeline\" in p else None\n"
]
},
{
"cell_type": "markdown",
"id": "a0065561",
"metadata": {},
"source": [
"## Convert VAE"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d30b82d9",
"metadata": {},
"outputs": [],
"source": [
"import argparse\n",
"import io\n",
"\n",
"import requests\n",
"import torch\n",
"from omegaconf import OmegaConf\n",
"\n",
"from diffusers import AutoencoderKL\n",
"from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (\n",
" assign_to_checkpoint,\n",
" conv_attn_to_linear,\n",
" create_vae_diffusers_config,\n",
" renew_vae_attention_paths,\n",
" renew_vae_resnet_paths,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "27fcefd6",
"metadata": {},
"outputs": [],
"source": [
"# Function definitions\n",
"\n",
"def custom_convert_ldm_vae_checkpoint(checkpoint, config):\n",
" vae_state_dict = checkpoint\n",
"\n",
" new_checkpoint = {}\n",
"\n",
" new_checkpoint[\"encoder.conv_in.weight\"] = vae_state_dict[\"encoder.conv_in.weight\"]\n",
" new_checkpoint[\"encoder.conv_in.bias\"] = vae_state_dict[\"encoder.conv_in.bias\"]\n",
" new_checkpoint[\"encoder.conv_out.weight\"] = vae_state_dict[\"encoder.conv_out.weight\"]\n",
" new_checkpoint[\"encoder.conv_out.bias\"] = vae_state_dict[\"encoder.conv_out.bias\"]\n",
" new_checkpoint[\"encoder.conv_norm_out.weight\"] = vae_state_dict[\"encoder.norm_out.weight\"]\n",
" new_checkpoint[\"encoder.conv_norm_out.bias\"] = vae_state_dict[\"encoder.norm_out.bias\"]\n",
"\n",
" new_checkpoint[\"decoder.conv_in.weight\"] = vae_state_dict[\"decoder.conv_in.weight\"]\n",
" new_checkpoint[\"decoder.conv_in.bias\"] = vae_state_dict[\"decoder.conv_in.bias\"]\n",
" new_checkpoint[\"decoder.conv_out.weight\"] = vae_state_dict[\"decoder.conv_out.weight\"]\n",
" new_checkpoint[\"decoder.conv_out.bias\"] = vae_state_dict[\"decoder.conv_out.bias\"]\n",
" new_checkpoint[\"decoder.conv_norm_out.weight\"] = vae_state_dict[\"decoder.norm_out.weight\"]\n",
" new_checkpoint[\"decoder.conv_norm_out.bias\"] = vae_state_dict[\"decoder.norm_out.bias\"]\n",
"\n",
" new_checkpoint[\"quant_conv.weight\"] = vae_state_dict[\"quant_conv.weight\"]\n",
" new_checkpoint[\"quant_conv.bias\"] = vae_state_dict[\"quant_conv.bias\"]\n",
" new_checkpoint[\"post_quant_conv.weight\"] = vae_state_dict[\"post_quant_conv.weight\"]\n",
" new_checkpoint[\"post_quant_conv.bias\"] = vae_state_dict[\"post_quant_conv.bias\"]\n",
"\n",
" # Retrieves the keys for the encoder down blocks only\n",
" num_down_blocks = len({\".\".join(layer.split(\".\")[:3]) for layer in vae_state_dict if \"encoder.down\" in layer})\n",
" down_blocks = {\n",
" layer_id: [key for key in vae_state_dict if f\"down.{layer_id}\" in key] for layer_id in range(num_down_blocks)\n",
" }\n",
"\n",
" # Retrieves the keys for the decoder up blocks only\n",
" num_up_blocks = len({\".\".join(layer.split(\".\")[:3]) for layer in vae_state_dict if \"decoder.up\" in layer})\n",
" up_blocks = {\n",
" layer_id: [key for key in vae_state_dict if f\"up.{layer_id}\" in key] for layer_id in range(num_up_blocks)\n",
" }\n",
"\n",
" for i in range(num_down_blocks):\n",
" resnets = [key for key in down_blocks[i] if f\"down.{i}\" in key and f\"down.{i}.downsample\" not in key]\n",
"\n",
" if f\"encoder.down.{i}.downsample.conv.weight\" in vae_state_dict:\n",
" new_checkpoint[f\"encoder.down_blocks.{i}.downsamplers.0.conv.weight\"] = vae_state_dict.pop(\n",
" f\"encoder.down.{i}.downsample.conv.weight\"\n",
" )\n",
" new_checkpoint[f\"encoder.down_blocks.{i}.downsamplers.0.conv.bias\"] = vae_state_dict.pop(\n",
" f\"encoder.down.{i}.downsample.conv.bias\"\n",
" )\n",
"\n",
" paths = renew_vae_resnet_paths(resnets)\n",
" meta_path = {\"old\": f\"down.{i}.block\", \"new\": f\"down_blocks.{i}.resnets\"}\n",
" assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n",
"\n",
" mid_resnets = [key for key in vae_state_dict if \"encoder.mid.block\" in key]\n",
" num_mid_res_blocks = 2\n",
" for i in range(1, num_mid_res_blocks + 1):\n",
" resnets = [key for key in mid_resnets if f\"encoder.mid.block_{i}\" in key]\n",
"\n",
" paths = renew_vae_resnet_paths(resnets)\n",
" meta_path = {\"old\": f\"mid.block_{i}\", \"new\": f\"mid_block.resnets.{i - 1}\"}\n",
" assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n",
"\n",
" mid_attentions = [key for key in vae_state_dict if \"encoder.mid.attn\" in key]\n",
" paths = renew_vae_attention_paths(mid_attentions)\n",
" meta_path = {\"old\": \"mid.attn_1\", \"new\": \"mid_block.attentions.0\"}\n",
" assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n",
" conv_attn_to_linear(new_checkpoint)\n",
"\n",
" for i in range(num_up_blocks):\n",
" block_id = num_up_blocks - 1 - i\n",
" resnets = [\n",
" key for key in up_blocks[block_id] if f\"up.{block_id}\" in key and f\"up.{block_id}.upsample\" not in key\n",
" ]\n",
"\n",
" if f\"decoder.up.{block_id}.upsample.conv.weight\" in vae_state_dict:\n",
" new_checkpoint[f\"decoder.up_blocks.{i}.upsamplers.0.conv.weight\"] = vae_state_dict[\n",
" f\"decoder.up.{block_id}.upsample.conv.weight\"\n",
" ]\n",
" new_checkpoint[f\"decoder.up_blocks.{i}.upsamplers.0.conv.bias\"] = vae_state_dict[\n",
" f\"decoder.up.{block_id}.upsample.conv.bias\"\n",
" ]\n",
"\n",
" paths = renew_vae_resnet_paths(resnets)\n",
" meta_path = {\"old\": f\"up.{block_id}.block\", \"new\": f\"up_blocks.{i}.resnets\"}\n",
" assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n",
"\n",
" mid_resnets = [key for key in vae_state_dict if \"decoder.mid.block\" in key]\n",
" num_mid_res_blocks = 2\n",
" for i in range(1, num_mid_res_blocks + 1):\n",
" resnets = [key for key in mid_resnets if f\"decoder.mid.block_{i}\" in key]\n",
"\n",
" paths = renew_vae_resnet_paths(resnets)\n",
" meta_path = {\"old\": f\"mid.block_{i}\", \"new\": f\"mid_block.resnets.{i - 1}\"}\n",
" assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n",
"\n",
" mid_attentions = [key for key in vae_state_dict if \"decoder.mid.attn\" in key]\n",
" paths = renew_vae_attention_paths(mid_attentions)\n",
" meta_path = {\"old\": \"mid.attn_1\", \"new\": \"mid_block.attentions.0\"}\n",
" assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n",
" conv_attn_to_linear(new_checkpoint)\n",
" return new_checkpoint\n",
"\n",
"\n",
"def vae_pt_to_vae_diffuser(\n",
" checkpoint_path: str,\n",
" output_path: str,\n",
"):\n",
" # Only support V1\n",
" r = requests.get(\n",
" \" https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml\"\n",
" )\n",
" io_obj = io.BytesIO(r.content)\n",
"\n",
" original_config = OmegaConf.load(io_obj)\n",
" image_size = 512\n",
" device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
" if checkpoint_path.endswith(\"safetensors\"):\n",
" from safetensors import safe_open\n",
"\n",
" checkpoint = {}\n",
" with safe_open(checkpoint_path, framework=\"pt\", device=\"cpu\") as f:\n",
" for key in f.keys():\n",
" checkpoint[key] = f.get_tensor(key)\n",
" else:\n",
" checkpoint = torch.load(checkpoint_path, map_location=device)[\"state_dict\"]\n",
"\n",
" # Convert the VAE model.\n",
" vae_config = create_vae_diffusers_config(original_config, image_size=image_size)\n",
" converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint, vae_config)\n",
"\n",
" vae = AutoencoderKL(**vae_config)\n",
" vae.load_state_dict(converted_vae_checkpoint)\n",
" vae.save_pretrained(output_path)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "bb30dc66",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"vae_pt_path = r\"D:\\AI\\automatic\\models\\VAE\\vae-ft-ema-560000-ema-pruned.safetensors\"\n",
"dump_path = r\"D:\\AI\\invokeai\\models\\vae-ft-ema-560000-ema-pruned\"\n",
"\n",
"vae_pt_to_vae_diffuser(vae_pt_path, dump_path)"
]
}
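,
{
"cell_type": "markdown",
"id": "9a73d2ce",
"metadata": {},
"source": [
"As a quick check, the converted VAE folder can be reloaded with `AutoencoderKL.from_pretrained` (a minimal sketch reusing `dump_path` from the cell above)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5b1e7fd4",
"metadata": {},
"outputs": [],
"source": [
"from diffusers import AutoencoderKL\n",
"\n",
"# Reload the converted VAE folder as a smoke test\n",
"vae_check = AutoencoderKL.from_pretrained(dump_path)\n",
"vae_check.config"
]
}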
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}