Created
April 13, 2023 07:16
-
-
Save sayakpaul/fd3c0d826705f0f412d6a19164d2ab1a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "ef865d1c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from diffusers.models.attention_processor import (\n", | |
" CustomDiffusionAttnProcessor,\n", | |
" AttnProcessor,\n", | |
")\n", | |
"from diffusers.loaders import AttnProcsLayers\n", | |
"from diffusers import UNet2DConditionModel" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "3effa492", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"unet = UNet2DConditionModel.from_pretrained(\n", | |
" \"runwayml/stable-diffusion-v1-5\", subfolder=\"unet\"\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "1e898585", | |
"metadata": {}, | |
"source": [ | |
"## Initialize the custom attention processor" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "67f88165", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Safe to assume that there will be always cross attention dim for Custom Diffusion.\n", | |
"# The code snippet below is taken from `examples/dreambooth/train_dreambooth_lora.py`.\n", | |
"\n", | |
"custom_diffusion_attn_procs = {}\n", | |
"for name in unet.attn_processors.keys():\n", | |
" if \"attn2\" in name:\n", | |
" cross_attention_dim = unet.config.cross_attention_dim\n", | |
" if name.startswith(\"mid_block\"):\n", | |
" hidden_size = unet.config.block_out_channels[-1]\n", | |
" elif name.startswith(\"up_blocks\"):\n", | |
" block_id = int(name[len(\"up_blocks.\")])\n", | |
" hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n", | |
" elif name.startswith(\"down_blocks\"):\n", | |
" block_id = int(name[len(\"down_blocks.\")])\n", | |
" hidden_size = unet.config.block_out_channels[block_id]\n", | |
"\n", | |
" custom_diffusion_attn_procs[name] = CustomDiffusionAttnProcessor(\n", | |
" hidden_dim=hidden_size, cross_attention_dim=cross_attention_dim\n", | |
" )\n", | |
" else:\n", | |
" custom_diffusion_attn_procs[name] = AttnProcessor()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "aea77274", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"32" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"len(custom_diffusion_attn_procs)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "28d08480", | |
"metadata": {}, | |
"source": [ | |
"## Set the processor" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "9913eedf", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import copy \n", | |
"\n", | |
"# Since keys get popped up with `set_attn_processor`.\n", | |
"copy_custom_diffusion_attn_procs = copy.deepcopy(custom_diffusion_attn_procs)\n", | |
"unet.set_attn_processor(custom_diffusion_attn_procs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "06c8b565", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(0, 32)" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"len(custom_diffusion_attn_procs), len(copy_custom_diffusion_attn_procs)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "a0ed15a0", | |
"metadata": {}, | |
"source": [ | |
"## Serialization mechanics" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "100107ea", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"non_attn1_procs = {}\n", | |
"for k in copy_custom_diffusion_attn_procs:\n", | |
" if \"attn2\" in k:\n", | |
" non_attn1_procs.update({k: copy_custom_diffusion_attn_procs[k]})\n", | |
" \n", | |
"custom_diffusion_layers = AttnProcsLayers(non_attn1_procs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "38e6b38b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch \n", | |
"\n", | |
"torch.save(custom_diffusion_layers.state_dict(), \"custom_diffusion_layers.bin\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "a0828ce9", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"-rw-r--r-- 1 sayakpaul staff 110M Apr 13 12:43 custom_diffusion_layers.bin\r\n" | |
] | |
} | |
], | |
"source": [ | |
"!ls -lh custom_diffusion_layers.bin" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "7233cf5e", | |
"metadata": {}, | |
"source": [ | |
"## Loading" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "f2305dd8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Needs to be implemented." | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment