Skip to content

Instantly share code, notes, and snippets.

@sayakpaul
Created April 13, 2023 07:16
Show Gist options
  • Save sayakpaul/fd3c0d826705f0f412d6a19164d2ab1a to your computer and use it in GitHub Desktop.
Save sayakpaul/fd3c0d826705f0f412d6a19164d2ab1a to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "ef865d1c",
"metadata": {},
"outputs": [],
"source": [
"from diffusers.models.attention_processor import (\n",
" CustomDiffusionAttnProcessor,\n",
" AttnProcessor,\n",
")\n",
"from diffusers.loaders import AttnProcsLayers\n",
"from diffusers import UNet2DConditionModel"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3effa492",
"metadata": {},
"outputs": [],
"source": [
"unet = UNet2DConditionModel.from_pretrained(\n",
" \"runwayml/stable-diffusion-v1-5\", subfolder=\"unet\"\n",
")"
]
},
{
"cell_type": "markdown",
"id": "1e898585",
"metadata": {},
"source": [
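"## Initialize the custom attention processors\n",
"\n",
"Cross-attention (`attn2`) layers get a `CustomDiffusionAttnProcessor`; every other attention layer gets the default `AttnProcessor`."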
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "67f88165",
"metadata": {},
"outputs": [],
"source": [
"# It is safe to assume there will always be a cross-attention dim for Custom Diffusion.\n",
"# The code snippet below is taken from `examples/dreambooth/train_dreambooth_lora.py`.\n",
"\n",
"custom_diffusion_attn_procs = {}\n",
"for name in unet.attn_processors.keys():\n",
" if \"attn2\" in name:\n",
" cross_attention_dim = unet.config.cross_attention_dim\n",
" if name.startswith(\"mid_block\"):\n",
" hidden_size = unet.config.block_out_channels[-1]\n",
" elif name.startswith(\"up_blocks\"):\n",
" block_id = int(name[len(\"up_blocks.\")])\n",
" hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n",
" elif name.startswith(\"down_blocks\"):\n",
" block_id = int(name[len(\"down_blocks.\")])\n",
" hidden_size = unet.config.block_out_channels[block_id]\n",
"\n",
" custom_diffusion_attn_procs[name] = CustomDiffusionAttnProcessor(\n",
" hidden_dim=hidden_size, cross_attention_dim=cross_attention_dim\n",
" )\n",
" else:\n",
" custom_diffusion_attn_procs[name] = AttnProcessor()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "aea77274",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"32"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(custom_diffusion_attn_procs)"
]
},
{
"cell_type": "markdown",
"id": "28d08480",
"metadata": {},
"source": [
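"## Set the processors\n",
"\n",
"`set_attn_processor` consumes the dict it is given (its keys are popped), so a deep copy is kept for the serialization step below."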
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9913eedf",
"metadata": {},
"outputs": [],
"source": [
"import copy \n",
"\n",
"# Keep a deep copy: `set_attn_processor` pops keys from the dict it is given.\n",
"copy_custom_diffusion_attn_procs = copy.deepcopy(custom_diffusion_attn_procs)\n",
"unet.set_attn_processor(custom_diffusion_attn_procs)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "06c8b565",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0, 32)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(custom_diffusion_attn_procs), len(copy_custom_diffusion_attn_procs)"
]
},
{
"cell_type": "markdown",
"id": "a0ed15a0",
"metadata": {},
"source": [
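"## Serialization mechanics\n",
"\n",
"Only the cross-attention (`attn2`) processors, i.e. the `CustomDiffusionAttnProcessor` instances, are wrapped in `AttnProcsLayers` and saved."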
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "100107ea",
"metadata": {},
"outputs": [],
"source": [
"non_attn1_procs = {}\n",
"for k in copy_custom_diffusion_attn_procs:\n",
" if \"attn2\" in k:\n",
" non_attn1_procs.update({k: copy_custom_diffusion_attn_procs[k]})\n",
" \n",
"custom_diffusion_layers = AttnProcsLayers(non_attn1_procs)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "38e6b38b",
"metadata": {},
"outputs": [],
"source": [
"import torch \n",
"\n",
"torch.save(custom_diffusion_layers.state_dict(), \"custom_diffusion_layers.bin\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "a0828ce9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-rw-r--r-- 1 sayakpaul staff 110M Apr 13 12:43 custom_diffusion_layers.bin\r\n"
]
}
],
"source": [
"!ls -lh custom_diffusion_layers.bin"
]
},
{
"cell_type": "markdown",
"id": "7233cf5e",
"metadata": {},
"source": [
"## Loading"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "f2305dd8",
"metadata": {},
"outputs": [],
"source": [
"# TODO: loading the saved `custom_diffusion_layers.bin` back into the UNet still needs to be implemented."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment