{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/takuma104/3365461ee92d7fe6489598278e293239/scratchpad.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"!pip install transformers accelerate -q\n",
"!pip install git+https://github.com/takuma104/diffusers@kohya-lora-loader -q"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "QkFfAW1xrmZ8",
"outputId": "db009dab-fe4b-4ddd-8228-42b7df265999"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import torch \n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"print(torch.__version__)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wxllxzuH5Mdt",
"outputId": "b397288f-a8ea-48f1-fbe0-99a887f61e90"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"2.0.0+cu118\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"!nvidia-smi"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "P0x681nm5Ono",
"outputId": "66ac77ef-1107-47bb-9370-1385be2169e5"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Thu May 18 15:48:02 2023 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 66C P8 10W / 70W | 0MiB / 15360MiB | 0% Default |\n",
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from diffusers.models.attention_processor import LoRAAttnProcessor\n",
"from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin\n",
"\n",
"def get_lora_layers(text_encoder):\n",
" text_lora_attn_procs = {}\n",
" for name, module in text_encoder.named_modules():\n",
" # if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):\n",
" if name.endswith(\"self_attn\"):\n",
" # print(name)\n",
" text_lora_attn_procs[name] = LoRAAttnProcessor(\n",
" hidden_size=module.out_proj.out_features, cross_attention_dim=None\n",
" )\n",
" text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)\n",
" return text_lora_attn_procs, text_encoder_lora_layers"
],
"metadata": {
"id": "CP1xF4oAtRE_"
},
"execution_count": 4,
"outputs": []
},
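{
"cell_type": "markdown",
"metadata": {},
"source": [
"Sanity check (sketch): `LoRAAttnProcessor` is expected to zero-initialize the `up` projection of each LoRA layer (standard LoRA initialization), so a freshly created processor should be a no-op. A correct patch of the text encoder with these fresh layers should therefore leave its outputs unchanged."
]
},
{
"cell_type": "code",
"source": [
"# Quick check (assumption: the `up` weight is zero-initialized, so a fresh\n",
"# LoRA layer is a no-op). hidden_size=768 matches the SD 1.5 text encoder.\n",
"proc = LoRAAttnProcessor(hidden_size=768, cross_attention_dim=None)\n",
"x = torch.randn(1, 77, 768)\n",
"print(proc.to_q_lora.up.weight.abs().max())  # expected: tensor(0.)\n",
"print(proc.to_q_lora(x).abs().max())         # expected: tensor(0.)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},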
{
"cell_type": "code",
"source": [
"from collections import defaultdict\n",
"\n",
"def load_lora_attn_procs_te(text_encoder, state_dict, network_alpha=None, device=\"cuda\"):\n",
" attn_processors = {}\n",
" is_lora = all(\"lora\" in k for k in state_dict.keys())\n",
"\n",
" if is_lora:\n",
" lora_grouped_dict = defaultdict(dict)\n",
" for key, value in state_dict.items():\n",
" attn_processor_key, sub_key = \".\".join(key.split(\".\")[:-3]), \".\".join(key.split(\".\")[-3:])\n",
" lora_grouped_dict[attn_processor_key][sub_key] = value\n",
"\n",
" for key, value_dict in lora_grouped_dict.items():\n",
" rank = value_dict[\"to_k_lora.down.weight\"].shape[0]\n",
" cross_attention_dim = value_dict[\"to_k_lora.down.weight\"].shape[1]\n",
" hidden_size = value_dict[\"to_k_lora.up.weight\"].shape[0]\n",
"\n",
" attn_processors[key] = LoRAAttnProcessor(\n",
" hidden_size=hidden_size,\n",
" cross_attention_dim=cross_attention_dim,\n",
" rank=rank,\n",
" network_alpha=network_alpha,\n",
" )\n",
" attn_processors[key].load_state_dict(value_dict)\n",
"\n",
" # else:\n",
" # raise ValueError(f\"{model_file} does not seem to be in the correct format expected by LoRA training.\")\n",
"\n",
" # set correct dtype & device\n",
" attn_processors = {\n",
" k: v.to(device=device, dtype=text_encoder.dtype) for k, v in attn_processors.items()\n",
" }\n",
" return attn_processors"
],
"metadata": {
"id": "LJeTzlwM5Yt8"
},
"execution_count": 5,
"outputs": []
},
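{
"cell_type": "markdown",
"metadata": {},
"source": [
"For illustration, this is how `load_lora_attn_procs_te` splits each flat state-dict key: the last three components name the LoRA sub-layer and parameter, and the prefix identifies the attention module the processor belongs to (compare the `attn_procs.keys()` printout near the end of the notebook). The key below is a hypothetical example shaped like those keys."
]
},
{
"cell_type": "code",
"source": [
"# Hypothetical key, shaped like the ones grouped in load_lora_attn_procs_te above.\n",
"key = \"text_model.encoder.layers.0.self_attn.to_k_lora.down.weight\"\n",
"attn_processor_key = \".\".join(key.split(\".\")[:-3])\n",
"sub_key = \".\".join(key.split(\".\")[-3:])\n",
"print(attn_processor_key)  # text_model.encoder.layers.0.self_attn\n",
"print(sub_key)             # to_k_lora.down.weight"
],
"metadata": {},
"execution_count": null,
"outputs": []
},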
{
"cell_type": "code",
"source": [
"from diffusers.utils import TEXT_ENCODER_TARGET_MODULES\n",
"\n",
"\n",
"def get_lora_layer_attribute(name: str) -> str:\n",
" if \"q_proj\" in name:\n",
" return \"to_q_lora\"\n",
" elif \"v_proj\" in name:\n",
" return \"to_v_lora\"\n",
" elif \"k_proj\" in name:\n",
" return \"to_k_lora\"\n",
" else:\n",
" return \"to_out_lora\"\n",
"\n",
"\n",
"def modify_text_encoder_instance(text_encoder, attn_processors, inputs):\n",
" outputs = text_encoder(**inputs)[0]\n",
" print(\"Original\")\n",
" print(f\"Sample outputs: {outputs[:, :5, :5].flatten()}\")\n",
" print(f\"Mean: {outputs.mean()}, Std: {outputs.std()}\")\n",
" print(\"=\"*80)\n",
" \n",
" # Loop over the original attention modules.\n",
" for name, _ in text_encoder.named_modules():\n",
" if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):\n",
" # Retrieve the module and its corresponding LoRA processor.\n",
" module = text_encoder.get_submodule(name)\n",
" # Construct a new function that performs the LoRA merging. We will monkey patch\n",
" # this forward pass.\n",
" attn_processor_name = '.'.join(name.split('.')[:-1]) # modified\n",
" if attn_processor_name in attn_processors: # modified\n",
" print(\"Entered in the monkey patching loop\")\n",
" module.lora_layer = getattr(attn_processors[attn_processor_name], get_lora_layer_attribute(name)) # modified\n",
" module.old_forward = module.forward\n",
"\n",
" def new_forward(self, x):\n",
" # return self.old_forward(x) * 2.0\n",
" return self.old_forward(x) + self.lora_layer(x)\n",
"\n",
" # Monkey-patch.\n",
" module.forward = new_forward.__get__(module)\n",
"\n",
" print(\"=\"*80)\n",
" print(\"LoRA\")\n",
" outputs = text_encoder(**inputs)[0]\n",
" print(f\"Sample outputs: {outputs[:, :5, :5].flatten()}\")\n",
" print(f\"Mean: {outputs.mean()}, Std: {outputs.std()}\")"
],
"metadata": {
"id": "ELB_6wJk6A_Z"
},
"execution_count": 6,
"outputs": []
},
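{
"cell_type": "markdown",
"metadata": {},
"source": [
"Mini-example of the binding trick used above: `function.__get__(obj)` turns a plain function into a method bound to `obj`. Inside `new_forward`, `self` is then the patched module itself, and the per-module state (`self.lora_layer`, `self.old_forward`) lives on that module rather than in a shared closure."
]
},
{
"cell_type": "code",
"source": [
"# Minimal illustration of the descriptor binding used by modify_text_encoder_instance.\n",
"class Box:\n",
"    value = 41\n",
"\n",
"def plus_one(self):\n",
"    return self.value + 1\n",
"\n",
"b = Box()\n",
"b.get = plus_one.__get__(b)  # bind the free function to this instance\n",
"print(b.get())  # 42"
],
"metadata": {},
"execution_count": null,
"outputs": []
},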
{
"cell_type": "code",
"source": [
"def modify_text_encoder(text_encoder, attn_processors, inputs):\n",
" outputs = text_encoder(**inputs)[0]\n",
" print(\"Original\")\n",
" print(f\"Sample outputs: {outputs[:, :5, :5].flatten()}\")\n",
" print(f\"Mean: {outputs.mean()}, Std: {outputs.std()}\")\n",
" print(\"=\"*80)\n",
" # Loop over the original attention modules.\n",
" for name, _ in text_encoder.named_modules():\n",
" if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):\n",
" # Retrieve the module and its corresponding LoRA processor.\n",
" module = text_encoder.get_submodule(name)\n",
" # Construct a new function that performs the LoRA merging. We will monkey patch\n",
" # this forward pass.\n",
" attn_processor_name = '.'.join(name.split('.')[:-1]) # modified\n",
" if attn_processor_name in attn_processors: # modified\n",
" print(\"Entered in the monkey patching loop\")\n",
" lora_layer = getattr(attn_processors[attn_processor_name], get_lora_layer_attribute(name)) # modified\n",
" old_forward = module.forward\n",
"\n",
" def new_forward(x):\n",
" return old_forward(x) + lora_layer(x)\n",
"\n",
" # Monkey-patch.\n",
" module.forward = new_forward\n",
"\n",
" print(\"=\"*80)\n",
" print(\"LoRA\")\n",
" outputs = text_encoder(**inputs)[0]\n",
" print(f\"Sample outputs: {outputs[:, :5, :5].flatten()}\")\n",
" print(f\"Mean: {outputs.mean()}, Std: {outputs.std()}\")"
],
"metadata": {
"id": "ZtcsqHKetwrq"
},
"execution_count": 7,
"outputs": []
},
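{
"cell_type": "markdown",
"metadata": {},
"source": [
"Standalone illustration of the pitfall in the closure version: `old_forward` and `lora_layer` are captured by name, not by value, and the loop keeps rebinding those names, so every patched `forward` ends up delegating to whatever module was processed last (Python's late-binding closure behaviour). This is the likely reason the \"vanilla\" run below produces different outputs while the instance-attribute version does not."
]
},
{
"cell_type": "code",
"source": [
"# Closures capture variables, not values: every f() returns the final x.\n",
"fns = []\n",
"for x in [1, 2, 3]:\n",
"    def f():\n",
"        return x\n",
"    fns.append(f)\n",
"print([fn() for fn in fns])  # [3, 3, 3], not [1, 2, 3]\n",
"\n",
"# Capturing the current value explicitly (default argument) avoids the problem.\n",
"fns = [lambda x=x: x for x in [1, 2, 3]]\n",
"print([fn() for fn in fns])  # [1, 2, 3]"
],
"metadata": {},
"execution_count": null,
"outputs": []
},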
{
"cell_type": "code",
"source": [
"batch_size = 1\n",
"max_seq_length = 77\n",
"\n",
"inputs = torch.randint(\n",
" 2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0)\n",
").to(\"cuda\")\n",
"\n",
"prepared_inputs = {}\n",
"prepared_inputs[\"input_ids\"] = inputs\n",
"prepared_inputs"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "37X0RZ9X85Qf",
"outputId": "c781245b-cf47-4a06-e4ae-a58eeb444fcf"
},
"execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'input_ids': tensor([[28, 11, 49, 26, 27, 53, 27, 39, 39, 11, 55, 22, 8, 41, 12, 40, 20, 38,\n",
" 16, 23, 46, 13, 37, 22, 28, 43, 9, 37, 35, 38, 21, 2, 19, 44, 3, 17,\n",
" 30, 55, 19, 54, 40, 23, 18, 9, 39, 25, 15, 27, 34, 27, 6, 13, 43, 51,\n",
" 37, 16, 43, 6, 4, 31, 15, 42, 29, 8, 51, 34, 42, 54, 28, 42, 37, 31,\n",
" 26, 36, 38, 7, 5]], device='cuda:0')}"
]
},
"metadata": {},
"execution_count": 8
}
]
},
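{
"cell_type": "markdown",
"metadata": {},
"source": [
"Aside: random token ids keep the comparison self-contained; a real prompt would instead go through the matching CLIP tokenizer, roughly as sketched below (assumes the standard runwayml/stable-diffusion-v1-5 repo layout)."
]
},
{
"cell_type": "code",
"source": [
"# Alternative input preparation with the actual tokenizer (sketch).\n",
"from transformers import CLIPTokenizer\n",
"\n",
"tokenizer = CLIPTokenizer.from_pretrained(\n",
"    \"runwayml/stable-diffusion-v1-5\", subfolder=\"tokenizer\"\n",
")\n",
"text_inputs = tokenizer(\n",
"    \"a photo of an astronaut riding a horse\",\n",
"    padding=\"max_length\", max_length=max_seq_length, return_tensors=\"pt\",\n",
")\n",
"prepared_inputs_real = {\"input_ids\": text_inputs.input_ids.to(\"cuda\")}"
],
"metadata": {},
"execution_count": null,
"outputs": []
},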
{
"cell_type": "code",
"source": [
"from transformers import CLIPTextModel\n",
"\n",
"# Recreating the flow from diffusers as closely as possible.\n",
"def apply_and_test(inputs, instance=True):\n",
" # Load text encoder.\n",
" text_encoder = CLIPTextModel.from_pretrained(\n",
" \"runwayml/stable-diffusion-v1-5\", subfolder=\"text_encoder\"\n",
" ).to(\"cuda\")\n",
"\n",
" # Initialize LoRA layers and determine the attention processors.\n",
" _, text_encoder_lora_layers = get_lora_layers(text_encoder)\n",
" attn_procs = load_lora_attn_procs_te(\n",
" text_encoder, text_encoder_lora_layers.state_dict()\n",
" )\n",
"\n",
" # Apply pseudo-LoRA.\n",
" if instance:\n",
" modify_text_encoder_instance(text_encoder, attn_procs, inputs)\n",
" else:\n",
" modify_text_encoder(text_encoder, attn_procs, inputs)"
],
"metadata": {
"id": "3dI5AuLx9UMD"
},
"execution_count": 9,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Instance case. \n",
"apply_and_test(prepared_inputs)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "o_Vmoq3h_3OE",
"outputId": "80d437b0-61d3-4314-b1ae-9516af749be7"
},
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Original\n",
"Sample outputs: tensor([-0.3475, 0.0037, -0.0443, -0.1530, -0.0716, 0.8796, -1.9955, 1.3066,\n",
" 0.4707, -0.0186, -0.0431, -1.5263, 1.8151, 1.4243, 2.1641, -0.8808,\n",
" 0.1435, -0.7003, 0.9275, 1.0764, 0.1422, 0.7853, 0.5531, 0.9211,\n",
" 0.5118], device='cuda:0', grad_fn=<UnsafeViewBackward0>)\n",
"Mean: -0.10286597162485123, Std: 1.0281846523284912\n",
"================================================================================\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"================================================================================\n",
"LoRA\n",
"Sample outputs: tensor([-0.3475, 0.0037, -0.0443, -0.1530, -0.0716, 0.8796, -1.9955, 1.3066,\n",
" 0.4707, -0.0186, -0.0431, -1.5263, 1.8151, 1.4243, 2.1641, -0.8808,\n",
" 0.1435, -0.7003, 0.9275, 1.0764, 0.1422, 0.7853, 0.5531, 0.9211,\n",
" 0.5118], device='cuda:0', grad_fn=<UnsafeViewBackward0>)\n",
"Mean: -0.10286597162485123, Std: 1.0281846523284912\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Vanilla case. \n",
"apply_and_test(prepared_inputs, instance=False)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "jrrutVEu_7O4",
"outputId": "0ea3f7cd-cd7f-4a68-fb36-3b9db90df3e8"
},
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Original\n",
"Sample outputs: tensor([-0.3475, 0.0037, -0.0443, -0.1530, -0.0716, 0.8796, -1.9955, 1.3066,\n",
" 0.4707, -0.0186, -0.0431, -1.5263, 1.8151, 1.4243, 2.1641, -0.8808,\n",
" 0.1435, -0.7003, 0.9275, 1.0764, 0.1422, 0.7853, 0.5531, 0.9211,\n",
" 0.5118], device='cuda:0', grad_fn=<UnsafeViewBackward0>)\n",
"Mean: -0.10286597162485123, Std: 1.0281846523284912\n",
"================================================================================\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"Entered in the monkey patching loop\n",
"================================================================================\n",
"LoRA\n",
"Sample outputs: tensor([-0.3968, 0.6696, -0.7769, -0.0285, -1.0446, -1.5564, 0.8757, 0.0335,\n",
" 0.1260, -0.8673, -1.7689, 0.9730, -0.1255, 0.2257, -0.4124, -1.6684,\n",
" 0.6530, -0.2975, -0.0976, -0.8665, -1.6959, 0.7818, -0.4936, -0.1924,\n",
" -1.1919], device='cuda:0', grad_fn=<UnsafeViewBackward0>)\n",
"Mean: -0.10978495329618454, Std: 1.000146746635437\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Load text encoder.\n",
"text_encoder = CLIPTextModel.from_pretrained(\n",
" \"runwayml/stable-diffusion-v1-5\", subfolder=\"text_encoder\"\n",
").to(\"cuda\")\n",
"\n",
"# Initialize LoRA layers and determine the attention processors.\n",
"_, text_encoder_lora_layers = get_lora_layers(text_encoder)\n",
"attn_procs = load_lora_attn_procs_te(\n",
" text_encoder, text_encoder_lora_layers.state_dict()\n",
")\n",
"print(attn_procs.keys())"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "VUYG-RpBEJRu",
"outputId": "4b31e0a5-e159-4ff7-cfa8-09507e90d69a"
},
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"dict_keys(['text_model.encoder.layers.0.self_attn', 'text_model.encoder.layers.1.self_attn', 'text_model.encoder.layers.2.self_attn', 'text_model.encoder.layers.3.self_attn', 'text_model.encoder.layers.4.self_attn', 'text_model.encoder.layers.5.self_attn', 'text_model.encoder.layers.6.self_attn', 'text_model.encoder.layers.7.self_attn', 'text_model.encoder.layers.8.self_attn', 'text_model.encoder.layers.9.self_attn', 'text_model.encoder.layers.10.self_attn', 'text_model.encoder.layers.11.self_attn'])\n"
]
}
]
}
],
"metadata": {
"colab": {
"name": "scratchpad",
"provenance": [],
"gpuType": "T4",
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"accelerator": "GPU",
"gpuClass": "standard"
},
"nbformat": 4,
"nbformat_minor": 0
}