{
"cells": [
{
"cell_type": "markdown",
"id": "7d7c7c97",
"metadata": {},
"source": [
"## Initialization"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "66b41c23",
"metadata": {},
"outputs": [],
"source": [
"from diffusers.loaders import LoraLoaderMixin\n",
"from diffusers import UNet2DConditionModel\n",
"from transformers import CLIPTextModel\n",
"\n",
"def get_text_encoder():\n",
" return CLIPTextModel.from_pretrained(\n",
" \"runwayml/stable-diffusion-v1-5\", subfolder=\"text_encoder\"\n",
" )\n",
"\n",
"def get_unet():\n",
" return UNet2DConditionModel.from_pretrained(\n",
" \"runwayml/stable-diffusion-v1-5\", subfolder=\"unet\"\n",
" )\n",
"\n",
"text_encoder = get_text_encoder()\n",
"unet = get_unet()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "10ee7af2",
"metadata": {},
"outputs": [],
"source": [
"# UNet LoRA layers. \n",
"from diffusers.loaders import AttnProcsLayers\n",
"from diffusers.models.attention_processor import LoRAAttnProcessor\n",
"\n",
"lora_attn_procs = {}\n",
"for name in unet.attn_processors.keys():\n",
" cross_attention_dim = (\n",
" None if name.endswith(\"attn1.processor\") else unet.config.cross_attention_dim\n",
" )\n",
" if name.startswith(\"mid_block\"):\n",
" hidden_size = unet.config.block_out_channels[-1]\n",
" elif name.startswith(\"up_blocks\"):\n",
" block_id = int(name[len(\"up_blocks.\")])\n",
" hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n",
" elif name.startswith(\"down_blocks\"):\n",
" block_id = int(name[len(\"down_blocks.\")])\n",
" hidden_size = unet.config.block_out_channels[block_id]\n",
"\n",
" lora_attn_procs[name] = LoRAAttnProcessor(\n",
" hidden_size=hidden_size, cross_attention_dim=cross_attention_dim\n",
" )\n",
" \n",
"unet.set_attn_processor(lora_attn_procs)\n",
"lora_layers = AttnProcsLayers(unet.attn_processors)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "0ecc82cc",
"metadata": {},
"outputs": [],
"source": [
"# Text encoder LoRA layers.\n",
"from diffusers.utils import TEXT_ENCODER_TARGET_MODULES\n",
"\n",
"text_lora_attn_procs = {}\n",
"for name, module in text_encoder.named_modules():\n",
" if any([x in name for x in TEXT_ENCODER_TARGET_MODULES]):\n",
" text_lora_attn_procs[name] = LoRAAttnProcessor(\n",
" hidden_size=module.out_features, cross_attention_dim=None\n",
" )\n",
"\n",
"text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)"
]
},
{
"cell_type": "markdown",
"id": "ae763f96",
"metadata": {},
"source": [
"## Perform optimization (training of the LoRA layers)"
]
},
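{
"cell_type": "markdown",
"id": "9f3d21aa",
"metadata": {},
"source": [
"This section is intentionally left without code in the gist. The cell below is a minimal, hypothetical sketch of a DreamBooth-style training step in the spirit of `diffusers`' `train_dreambooth_lora.py`: only the LoRA parameters are optimized while the frozen base weights stay untouched. The `train_dataloader` and the VAE/noise-scheduler setup are illustrative assumptions, and the extra plumbing needed to route `text_encoder_lora_layers` through the text encoder's forward pass is glossed over."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0e51c77",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical training sketch (not executed in the original gist).\n",
"import itertools\n",
"\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from diffusers import AutoencoderKL, DDPMScheduler\n",
"\n",
"vae = AutoencoderKL.from_pretrained(\n",
"    \"runwayml/stable-diffusion-v1-5\", subfolder=\"vae\"\n",
")\n",
"noise_scheduler = DDPMScheduler.from_pretrained(\n",
"    \"runwayml/stable-diffusion-v1-5\", subfolder=\"scheduler\"\n",
")\n",
"\n",
"# Only the LoRA parameters are optimized; the base weights stay frozen.\n",
"optimizer = torch.optim.AdamW(\n",
"    itertools.chain(lora_layers.parameters(), text_encoder_lora_layers.parameters()),\n",
"    lr=1e-4,\n",
")\n",
"\n",
"for batch in train_dataloader:  # hypothetical dataloader\n",
"    # Encode images to latents and perturb them at a random timestep.\n",
"    latents = vae.encode(batch[\"pixel_values\"]).latent_dist.sample()\n",
"    latents = latents * vae.config.scaling_factor\n",
"    noise = torch.randn_like(latents)\n",
"    timesteps = torch.randint(\n",
"        0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],),\n",
"        device=latents.device,\n",
"    ).long()\n",
"    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n",
"\n",
"    # Predict the added noise and regress against the true noise.\n",
"    encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\n",
"    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample\n",
"    loss = F.mse_loss(model_pred.float(), noise.float())\n",
"\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    optimizer.zero_grad()"
]
},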
{
"cell_type": "markdown",
"id": "5b7a060c",
"metadata": {},
"source": [
"## Load into pipeline"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "5cbda038",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/sayakpaul/.local/bin/.virtualenvs/diffusers-dev/lib/python3.8/site-packages/transformers/models/clip/feature_extraction_clip.py:28: FutureWarning: The class CLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use CLIPImageProcessor instead.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from diffusers import StableDiffusionPipeline\n",
"\n",
"model_id = \"runwayml/stable-diffusion-v1-5\"\n",
"pipeline = StableDiffusionPipeline.from_pretrained(model_id)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "3ffc063b",
"metadata": {},
"outputs": [],
"source": [
"# Local works.\n",
"pipeline.load_lora_weights(\".\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "d748e196",
"metadata": {},
"outputs": [],
"source": [
"# Remote also works. \n",
"pipeline.load_lora_weights(\"sayakpaul/test-lora-diffusers\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "d84471d2",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b3b4aa286f924a3e8f03d49e9116054b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading (…)h_lora_weights.bin\";: 0%| | 0.00/3.29M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"You have likely saved the LoRA weights using the old format. This will be deprecated soon. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `{new_dictionary.update('unet': old_dictionary)}`\n"
]
}
],
"source": [
"# Old format also works.\n",
"pipeline.load_lora_weights(\"patrickvonplaten/lora_dreambooth_dog_example\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2c0c4561",
"metadata": {},
"outputs": [],
"source": [
"# Doing this also works :)\n",
"pipeline.unet.load_attn_procs(\"patrickvonplaten/lora_dreambooth_dog_example\")"
]
},
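{
"cell_type": "markdown",
"id": "c7a90d13",
"metadata": {},
"source": [
"With LoRA weights loaded, the pipeline can be used for inference as usual. The cell below is an illustrative sketch (not executed in the gist); the prompt is hypothetical, and the `scale` entry in `cross_attention_kwargs` controls how strongly the LoRA layers contribute relative to the frozen base weights."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f2b6e4a8",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical inference sketch; prompt and scale are illustrative.\n",
"image = pipeline(\n",
"    \"A photo of sks dog in a bucket\",\n",
"    num_inference_steps=25,\n",
"    cross_attention_kwargs={\"scale\": 0.5},\n",
").images[0]\n",
"image"
]
},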
{
"cell_type": "markdown",
"id": "bd2b3645",
"metadata": {},
"source": [
"## Serialization semantics"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "255e8579",
"metadata": {},
"outputs": [],
"source": [
"LoraLoaderMixin.save_lora_weights(\n",
" save_directory=\".\",\n",
" unet_lora_layers=lora_layers,\n",
" text_encoder_lora_layers=text_encoder_lora_layers\n",
")"
]
}
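,
{
"cell_type": "markdown",
"id": "8d5f0b2e",
"metadata": {},
"source": [
"`save_lora_weights` writes a single `pytorch_lora_weights.bin` file to `save_directory`, holding both the UNet and the text-encoder LoRA parameters. That file is what `load_lora_weights` above consumes, whether from a local path or from a Hub repository."
]
}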
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}