@takuma104
Created May 8, 2023 15:49
import math

import safetensors.torch  # needed for safetensors.torch.load_file
import torch
from diffusers import DiffusionPipeline
"""
Kohya's LoRA format Loader for Diffusers
Usage:
```py
# An usual Diffusers' setup
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained('...',
torch_dtype=torch.float16).to('cuda')
# Import this module
import kohya_lora_loader
# Install LoRA hook. This append apply_loar and remove_loar methods to the pipe.
kohya_lora_loader.install_lora_hook(pipe)
# Load 'lora1.safetensors' file and apply
lora1 = pipe.apply_lora('lora1.safetensors', 1.0)
# You can change alpha
lora1.alpha = 0.5
# Load 'lora2.safetensors' file and apply
lora2 = pipe.apply_lora('lora2.safetensors', 1.0)
# Generate image with lora1 and lora2 applied
pipe(...).images[0]
# Remove lora2
pipe.remove_lora(lora2)
# Generate image with lora1 applied
pipe(...).images[0]
# Uninstall LoRA hook
kohya_lora_loader.uninstall_lora_hook(pipe)
# Generate image with none LoRA applied
pipe(...).images[0]
```
"""


# modified from https://github.com/kohya-ss/sd-scripts/blob/ad5f318d066c52e5b27306b399bc87e41f2eef2b/networks/lora.py#L17
class LoRAModule(torch.nn.Module):
    def __init__(
        self, org_module: torch.nn.Module, lora_dim=4, alpha=1.0, multiplier=1.0
    ):
        """If alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()

        if org_module.__class__.__name__ == "Conv2d":
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

        self.lora_dim = lora_dim

        if org_module.__class__.__name__ == "Conv2d":
            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = torch.nn.Conv2d(
                in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
            )
            self.lora_up = torch.nn.Conv2d(
                self.lora_dim, out_dim, (1, 1), (1, 1), bias=False
            )
        else:
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if alpha is None or alpha == 0:
            self.alpha = self.lora_dim
        else:
            if type(alpha) == torch.Tensor:
                alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
            self.register_buffer("alpha", torch.tensor(alpha))  # Treatable as a constant.

        # same as microsoft's
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier

    def forward(self, x):
        scale = self.alpha / self.lora_dim
        return self.multiplier * scale * self.lora_up(self.lora_down(x))


class LoRAModuleContainer(torch.nn.Module):
    def __init__(self, hooks, state_dict, multiplier):
        super().__init__()
        self.multiplier = multiplier

        # Create LoRAModule from state_dict information
        for key, value in state_dict.items():
            if "lora_down" in key:
                lora_name = key.split(".")[0]
                lora_dim = value.size()[0]
                lora_name_alpha = key.split(".")[0] + ".alpha"
                alpha = None
                if lora_name_alpha in state_dict:
                    alpha = state_dict[lora_name_alpha].item()
                hook = hooks[lora_name]
                lora_module = LoRAModule(
                    hook.orig_module, lora_dim=lora_dim, alpha=alpha, multiplier=multiplier
                )
                self.register_module(lora_name, lora_module)

        # Load whole LoRA weights
        self.load_state_dict(state_dict)

        # Register LoRAModule to LoRAHook
        for name, module in self.named_modules():
            if module.__class__.__name__ == "LoRAModule":
                hook = hooks[name]
                hook.append_lora(module)

    @property
    def alpha(self):
        return self.multiplier

    @alpha.setter
    def alpha(self, multiplier):
        self.multiplier = multiplier
        for name, module in self.named_modules():
            if module.__class__.__name__ == "LoRAModule":
                module.multiplier = multiplier

    def remove_from_hooks(self, hooks):
        for name, module in self.named_modules():
            if module.__class__.__name__ == "LoRAModule":
                hook = hooks[name]
                hook.remove_lora(module)
                del module


class LoRAHook(torch.nn.Module):
    """
    Replaces the forward method of the original Linear (or Conv2d) module
    instead of replacing the module itself.
    """

    def __init__(self):
        super().__init__()
        self.lora_modules = []

    def install(self, orig_module):
        assert not hasattr(self, "orig_module")
        self.orig_module = orig_module
        self.orig_forward = self.orig_module.forward
        self.orig_module.forward = self.forward

    def uninstall(self):
        assert hasattr(self, "orig_module")
        self.orig_module.forward = self.orig_forward
        del self.orig_forward
        del self.orig_module

    def append_lora(self, lora_module):
        self.lora_modules.append(lora_module)

    def remove_lora(self, lora_module):
        self.lora_modules.remove(lora_module)

    def forward(self, x):
        if len(self.lora_modules) == 0:
            return self.orig_forward(x)
        lora = torch.sum(torch.stack([lora(x) for lora in self.lora_modules]), dim=0)
        return self.orig_forward(x) + lora


class LoRAHookInjector(object):
    def __init__(self):
        super().__init__()
        self.hooks = {}
        self.device = None
        self.dtype = None

    def _get_target_modules(self, root_module, prefix, target_replace_modules):
        target_modules = []
        for name, module in root_module.named_modules():
            if (
                module.__class__.__name__ in target_replace_modules
                and "transformer_blocks" not in name
            ):  # to adapt to the latest diffusers
                for child_name, child_module in module.named_modules():
                    is_linear = child_module.__class__.__name__ == "Linear"
                    is_conv2d = child_module.__class__.__name__ == "Conv2d"
                    if is_linear or is_conv2d:
                        lora_name = prefix + "." + name + "." + child_name
                        lora_name = lora_name.replace(".", "_")
                        target_modules.append((lora_name, child_module))
        return target_modules

    def install_hooks(self, pipe):
        """Install LoRAHook to the pipe."""
        assert len(self.hooks) == 0
        text_encoder_targets = self._get_target_modules(
            pipe.text_encoder, "lora_te", ["CLIPAttention", "CLIPMLP"]
        )
        unet_targets = self._get_target_modules(
            pipe.unet, "lora_unet", ["Transformer2DModel", "Attention"]
        )
        for name, target_module in text_encoder_targets + unet_targets:
            hook = LoRAHook()
            hook.install(target_module)
            self.hooks[name] = hook

        self.device = pipe.device
        self.dtype = pipe.unet.dtype

    def uninstall_hooks(self):
        """Uninstall LoRAHook from the pipe."""
        for k, v in self.hooks.items():
            v.uninstall()
        self.hooks = {}

    def apply_lora(self, filename, alpha=1.0):
        """Load LoRA weights and apply LoRA to the pipe."""
        assert len(self.hooks) != 0
        state_dict = safetensors.torch.load_file(filename)
        container = LoRAModuleContainer(self.hooks, state_dict, alpha)
        container.to(self.device, self.dtype)
        return container

    def remove_lora(self, container):
        """Remove the individual LoRA from the pipe."""
        container.remove_from_hooks(self.hooks)


def install_lora_hook(pipe: DiffusionPipeline):
    """Install LoRAHook to the pipe."""
    assert not hasattr(pipe, "lora_injector")
    assert not hasattr(pipe, "apply_lora")
    assert not hasattr(pipe, "remove_lora")
    injector = LoRAHookInjector()
    injector.install_hooks(pipe)
    pipe.lora_injector = injector
    pipe.apply_lora = injector.apply_lora
    pipe.remove_lora = injector.remove_lora


def uninstall_lora_hook(pipe: DiffusionPipeline):
    """Uninstall LoRAHook from the pipe."""
    pipe.lora_injector.uninstall_hooks()
    del pipe.lora_injector
    del pipe.apply_lora
    del pipe.remove_lora

alexblattner commented May 16, 2023

Hey, I tried your code with a safetensors checkpoint downloaded from Civitai and got this error:
    Traceback (condensed; repeated torch.nn.Module._call_impl frames omitted):
      /mecomics-api/tester.py:36
          image = pipe(prompt=None, negative_prompt=None, prompt_embeds=promptE, ...)
      /mecomics-api/multiDiffusion.py:1064 in __call__
          noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings[i], ...)
      diffusers/models/unet_2d_condition.py:724 in forward
          sample, res_samples = downsample_block(hidden_states=sample, temb=emb, ...)
      diffusers/models/unet_2d_blocks.py:868 in forward
          hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, ...)
      diffusers/models/transformer_2d.py:251 in forward
          hidden_states = self.proj_in(hidden_states)
      /mecomics-api/kohya_lora_loader.py:179 in LoRAHook.forward
          lora = torch.sum(torch.stack([lora(x) for lora in self.lora_modules]), dim=0)
      /mecomics-api/kohya_lora_loader.py:98 in LoRAModule.forward
          return self.multiplier * scale * self.lora_up(self.lora_down(x))
      torch/nn/modules/conv.py:459 in Conv2d._conv_forward
          return F.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)

    RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same

This is how I loaded the checkpoint:

    pipe = MultiStableDiffusion.from_ckpt(
        "deliberate.safetensors",
        local_files_only=True,
        torch_dtype=torch.float16,
        safety_checker=None,
        requires_safety_checker=False,
    ).to("cuda")

@takuma104

@alexblattner

I found out why it didn't work: it doesn't take the model's dtype into account, so the LoRA weights stay float32 from what I've seen so far.
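
If that's the cause, one possible workaround (untested sketch; assumes the pipe runs in float16) is to cast the container returned by apply_lora explicitly:

    lora1 = pipe.apply_lora('lora1.safetensors', 1.0)
    # LoRAModuleContainer is a regular torch.nn.Module, so cast it to the
    # pipeline's device/dtype explicitly (fp16 here).
    lora1.to(pipe.device, pipe.unet.dtype)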

@adhikjoshi

@takuma104 Can you add support for SDXL models?

@alexblattner

@takuma104 do you think you'll be able to add LyCORIS support?


pftjoeyyang commented Nov 14, 2023

For diffusers after #4147, consider adding the following module checks if you still want to use this hook.

At line 197 (in LoRAHookInjector._get_target_modules):

for child_name, child_module in module.named_modules():
    is_linear = child_module.__class__.__name__ in ["Linear", "LoRACompatibleLinear"]
    is_conv2d = child_module.__class__.__name__ in ["Conv2d", "LoRACompatibleConv"]

and at lines 60 and 69 (the two Conv2d checks in LoRAModule.__init__):

if org_module.__class__.__name__ in ["Conv2d", "LoRACompatibleConv"]:

Also, add a default scale argument to LoRAHook.forward at line 176:

def forward(self, x, scale=1.0):
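
For context, a minimal sketch of the patched LoRAHook.forward under that suggestion (assumption: the scale argument newer diffusers passes to its LoRACompatible* modules is accepted but ignored here, since the hook applies each LoRAModule's own multiplier):

    def forward(self, x, scale=1.0):
        # Accept `scale` so call sites in newer diffusers (which pass it to
        # LoRACompatibleLinear/LoRACompatibleConv) don't break; the hook keeps
        # using each LoRAModule's own multiplier instead.
        if len(self.lora_modules) == 0:
            return self.orig_forward(x)
        lora = torch.sum(torch.stack([lora(x) for lora in self.lora_modules]), dim=0)
        return self.orig_forward(x) + lora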
