@takuma104
Last active February 8, 2024 01:42
clip_text_custom_embedder

Usage

from clip_text_custom_embedder import text_embeddings
from diffusers import StableDiffusionPipeline
import torch

model_id = "runwayml/stable-diffusion-v1-5"  # example checkpoint; any SD 1.x model works
seed = 42                                    # example seed

pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to('cuda')

prompt = "((masterpiece, best quality)), white background, close-up, 1girl, little smile"
negative_prompt = ("(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), "
                   "bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2), partial face, "
                   "partial head, cropped head")
cond, uncond = text_embeddings(pipe, prompt, negative_prompt, clip_stop_at_last_layers=2)

image = pipe(prompt_embeds=cond,
             negative_prompt_embeds=uncond,
             generator=torch.manual_seed(seed)).images[0]

clip_text_custom_embedder.py

import torch
import math
import re
# copied and customized from automatic1111 sd_hijack.py & prompt_parser.py
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/ec1924ee5789b72c31c65932b549c59ccae0cdd6/modules/sd_hijack.py#L113
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/ec1924ee5789b72c31c65932b549c59ccae0cdd6/modules/prompt_parser.py#L259
re_attention = re.compile(r"""
\\\(|
\\\{|
\\\)|
\\\}|
\\\[|
\\]|
\\\\|
\\|
\(|
\{|
\[|
:([+-]?[.\d]+)\)|
\)|
\}|
]|
[^\\()\\{}\[\]:]+|
:
""", re.X)
def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith('\\'):
            res.append([text[1:], 1.0])
        elif text == '(' or text == '{':
            round_brackets.append(len(res))
        elif text == '[':
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif (text == ')' or text == '}') and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == ']' and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([text, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res
class CLIPTextCustomEmbedder(object):
    def __init__(self, tokenizer, text_encoder, device,
                 clip_stop_at_last_layers=1):
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.token_mults = {}
        self.device = device
        self.clip_stop_at_last_layers = clip_stop_at_last_layers

    def tokenize_line(self, line):
        def get_target_prompt_token_count(token_count):
            return math.ceil(max(token_count, 1) / 75) * 75

        id_end = self.tokenizer.eos_token_id
        parsed = parse_prompt_attention(line)
        tokenized = self.tokenizer(
            [text for text, _ in parsed], truncation=False,
            add_special_tokens=False)["input_ids"]

        fixes = []
        remade_tokens = []
        multipliers = []

        for tokens, (text, weight) in zip(tokenized, parsed):
            i = 0
            while i < len(tokens):
                token = tokens[i]
                remade_tokens.append(token)
                multipliers.append(weight)
                i += 1

        token_count = len(remade_tokens)
        prompt_target_length = get_target_prompt_token_count(token_count)
        tokens_to_add = prompt_target_length - len(remade_tokens)
        remade_tokens = remade_tokens + [id_end] * tokens_to_add
        multipliers = multipliers + [1.0] * tokens_to_add
        return remade_tokens, fixes, multipliers, token_count

    def process_text(self, texts):
        if isinstance(texts, str):
            texts = [texts]

        remade_batch_tokens = []
        cache = {}
        batch_multipliers = []
        for line in texts:
            if line in cache:
                remade_tokens, fixes, multipliers = cache[line]
            else:
                remade_tokens, fixes, multipliers, _ = self.tokenize_line(line)
                cache[line] = (remade_tokens, fixes, multipliers)

            remade_batch_tokens.append(remade_tokens)
            batch_multipliers.append(multipliers)

        return batch_multipliers, remade_batch_tokens

    def __call__(self, text):
        batch_multipliers, remade_batch_tokens = self.process_text(text)

        z = None
        i = 0
        while max(map(len, remade_batch_tokens)) != 0:
            rem_tokens = [x[75:] for x in remade_batch_tokens]
            rem_multipliers = [x[75:] for x in batch_multipliers]

            tokens = []
            multipliers = []
            for j in range(len(remade_batch_tokens)):
                if len(remade_batch_tokens[j]) > 0:
                    tokens.append(remade_batch_tokens[j][:75])
                    multipliers.append(batch_multipliers[j][:75])
                else:
                    tokens.append([self.tokenizer.eos_token_id] * 75)
                    multipliers.append([1.0] * 75)

            z1 = self.process_tokens(tokens, multipliers)
            z = z1 if z is None else torch.cat((z, z1), axis=-2)

            remade_batch_tokens = rem_tokens
            batch_multipliers = rem_multipliers
            i += 1

        return z

    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        remade_batch_tokens = [[self.tokenizer.bos_token_id] + x[:75] +
                               [self.tokenizer.eos_token_id] for x in remade_batch_tokens]
        batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]

        tokens = torch.asarray(remade_batch_tokens).to(self.device)
        # print(tokens.shape)
        # print(tokens)
        outputs = self.text_encoder(
            input_ids=tokens, output_hidden_states=True)

        if self.clip_stop_at_last_layers > 1:
            z = self.text_encoder.text_model.final_layer_norm(
                outputs.hidden_states[-self.clip_stop_at_last_layers])
        else:
            z = outputs.last_hidden_state

        # restoring original mean is likely not correct, but it seems to work well
        # to prevent artifacts that happen otherwise
        batch_multipliers_of_same_length = [
            x + [1.0] * (75 - len(x)) for x in batch_multipliers]
        batch_multipliers = torch.asarray(
            batch_multipliers_of_same_length).to(self.device)
        # print(batch_multipliers.shape)
        # print(batch_multipliers)

        original_mean = z.mean()
        z *= batch_multipliers.reshape(batch_multipliers.shape +
                                       (1,)).expand(z.shape)
        new_mean = z.mean()
        z *= original_mean / new_mean
        return z

    def get_text_tokens(self, text):
        batch_multipliers, remade_batch_tokens = self.process_text(text)
        return [[self.tokenizer.bos_token_id] + remade_batch_tokens[0]], \
            [[1.0] + batch_multipliers[0]]
def text_embeddings_equal_len(text_embedder, prompt, negative_prompt):
    cond_embeddings = text_embedder(prompt)
    uncond_embeddings = text_embedder(negative_prompt)

    cond_len = cond_embeddings.shape[1]
    uncond_len = uncond_embeddings.shape[1]
    if cond_len == uncond_len:
        return cond_embeddings, uncond_embeddings
    else:
        if cond_len > uncond_len:
            n = (cond_len - uncond_len) // 77
            return cond_embeddings, torch.cat([uncond_embeddings] + [text_embedder("")] * n, dim=1)
        else:
            n = (uncond_len - cond_len) // 77
            return torch.cat([cond_embeddings] + [text_embedder("")] * n, dim=1), uncond_embeddings


def text_embeddings(pipe, prompt, negative_prompt, clip_stop_at_last_layers=1):
    text_embedder = CLIPTextCustomEmbedder(tokenizer=pipe.tokenizer,
                                           text_encoder=pipe.text_encoder,
                                           device=pipe.text_encoder.device,
                                           clip_stop_at_last_layers=clip_stop_at_last_layers)
    cond_embeddings, uncond_embeddings = text_embeddings_equal_len(text_embedder, prompt, negative_prompt)
    return cond_embeddings, uncond_embeddings
@takuma104 (Author) commented May 8, 2023

It seems that development of the training script (Kohya's implementation) has progressed faster, and the LoRA files created with it have rapidly become widespread on CivitAI, making it difficult for Diffusers' LoRA implementation to catch up. The script I wrote uses hooks, so it's not likely to be merged into Diffusers. I'm considering creating a non-hook version as well.

Since most of the LoRA files on CivitAI are applied to both the TextEncoder and the UNet, it might get quite complicated, and you might need to redo the TextEncoder processing inside the denoising loop. If you want to keep it simple, you could write code that only expects changes to the UNet, as at least those should be effective. For example, it should be possible to change the LoRA state around here (before the UNet inference) in the StableDiffusionPipeline. You should be able to use apply_lora/remove_lora without any restrictions, but since they are not very fast, I recommend dynamically adjusting the alpha values instead.

In terms of code, you can execute the following before the denoising loop:

lora1 = self.apply_lora('lora1.safetensors', 0.0)
lora2 = self.apply_lora('lora2.safetensors', 0.0)

And within the denoising loop, dynamically perform operations like lora1.alpha = 1.0.
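
For illustration, a rough sketch of such a loop (an assumption-laden example, not code from the gist: it assumes install_lora_hook(pipe) has already been called, that latents, prompt_embeds and guidance_scale are prepared as in the standard diffusers txt2img loop, and that the alpha schedules are arbitrary):

lora1 = pipe.apply_lora('lora1.safetensors', 0.0)
lora2 = pipe.apply_lora('lora2.safetensors', 0.0)

num_steps = 30
pipe.scheduler.set_timesteps(num_steps)
for i, t in enumerate(pipe.scheduler.timesteps):
    # arbitrary example schedules: lora1 only in the first half, lora2 fades in over time
    lora1.alpha = 1.0 if i < num_steps // 2 else 0.0
    lora2.alpha = i / num_steps

    # standard diffusers denoising step with classifier-free guidance
    latent_model_input = pipe.scheduler.scale_model_input(torch.cat([latents] * 2), t)
    noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
    uncond, cond = noise_pred.chunk(2)
    noise_pred = uncond + guidance_scale * (cond - uncond)
    latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample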

@alexblattner

@takuma104 thank you very much for your answer. I didn't know that the LoRA implementations were that different between Diffusers and A1111.

I love the fact that Diffusers is so organized and doesn't require me to install a shitty webui, so I have stuck with it. For my project I do need to use the latest tech, though; hopefully Diffusers catches up soon. I am also creating a pipeline that should make latent couple (two shots) usable with multi-ControlNet. I've made latent couples work, and I am now applying ControlNet and dynamic LoRAs on top of it.

Does the code you gave:

lora1 = self.apply_lora('lora1.safetensors', 0.0)
lora2 = self.apply_lora('lora2.safetensors', 0.0)

work with your code? If yes, then you've saved me a fuckton of time, and thanks a lot for it. If not, I'll try to do something about it, assuming I manage to make my pipeline work with multi-ControlNet.

@takuma104 (Author)

@alexblattner Indeed, the code for Diffusers is well-organized. This code should work with the StableDiffusionControlNetPipeline as well, since the LoRAs are only applied to pipe.unet and pipe.text_encoder. I hope your latent couple project goes smoothly!
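
For illustration, a minimal sketch of the same idea on a ControlNet pipeline (the checkpoint, ControlNet model, LoRA filename, and control image below are placeholders, not part of the gist):

import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16).to("cuda")

install_lora_hook(pipe)  # hooks only pipe.text_encoder and pipe.unet; the ControlNet itself is untouched
lora = pipe.apply_lora("some_lora.safetensors", alpha=1.0, dtype=torch.float16)  # placeholder filename
image = pipe("1girl, little smile", image=canny_image).images[0]  # canny_image: a control image prepared beforehand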

@alexblattner

thank you @takuma104 !

@alexblattner commented Jun 27, 2023

I fixed an issue with the torch dtype:


import math
import safetensors
import torch
from diffusers import DiffusionPipeline

# modified from https://github.com/kohya-ss/sd-scripts/blob/ad5f318d066c52e5b27306b399bc87e41f2eef2b/networks/lora.py#L17
class LoRAModule(torch.nn.Module):
    def __init__(
        self, org_module: torch.nn.Module, lora_dim=4, alpha=1.0, multiplier=1.0
    ):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()

        if org_module.__class__.__name__ == "Conv2d":
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

        self.lora_dim = lora_dim

        if org_module.__class__.__name__ == "Conv2d":
            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = torch.nn.Conv2d(
                in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
            )
            self.lora_up = torch.nn.Conv2d(
                self.lora_dim, out_dim, (1, 1), (1, 1), bias=False
            )
        else:
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if alpha is None or alpha == 0:
            self.alpha = self.lora_dim
        else:
            if type(alpha) == torch.Tensor:
                alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
            self.register_buffer("alpha", torch.tensor(alpha))  # Treatable as a constant.

        # same as microsoft's
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier

    def forward(self, x):
        scale = self.alpha / self.lora_dim
        down = self.lora_down(x)
        up = self.lora_up(down)
        return self.multiplier * scale * up


class LoRAModuleContainer(torch.nn.Module):
    def __init__(self, hooks, state_dict, multiplier):
        super().__init__()
        self.multiplier = multiplier

        # Create LoRAModule from state_dict information
        for key, value in state_dict.items():
            if "lora_down" in key:
                lora_name = key.split(".")[0]
                lora_dim = value.size()[0]
                lora_name_alpha = key.split(".")[0] + '.alpha'
                alpha = None
                if lora_name_alpha in state_dict:
                    alpha = state_dict[lora_name_alpha].item()
                hook = hooks[lora_name]
                lora_module = LoRAModule(
                    hook.orig_module, lora_dim=lora_dim, alpha=alpha, multiplier=multiplier
                )
                self.register_module(lora_name, lora_module)

        # Load whole LoRA weights
        self.load_state_dict(state_dict)

        # Register LoRAModule to LoRAHook
        for name, module in self.named_modules():
            if module.__class__.__name__ == "LoRAModule":
                hook = hooks[name]
                hook.append_lora(module)
    @property
    def alpha(self):
        return self.multiplier

    @alpha.setter
    def alpha(self, multiplier):
        self.multiplier = multiplier
        for name, module in self.named_modules():
            if module.__class__.__name__ == "LoRAModule":
                module.multiplier = multiplier

    def remove_from_hooks(self, hooks):
        for name, module in self.named_modules():
            if module.__class__.__name__ == "LoRAModule":
                hook = hooks[name]
                hook.remove_lora(module)
                del module


class LoRAHook(torch.nn.Module):
    """
    replaces forward method of the original Linear,
    instead of replacing the original Linear module.
    """

    def __init__(self):
        super().__init__()
        self.lora_modules = []

    def install(self, orig_module):
        assert not hasattr(self, "orig_module")
        self.orig_module = orig_module
        self.orig_forward = self.orig_module.forward
        self.orig_module.forward = self.forward

    def uninstall(self):
        assert hasattr(self, "orig_module")
        self.orig_module.forward = self.orig_forward
        del self.orig_forward
        del self.orig_module

    def append_lora(self, lora_module):
        self.lora_modules.append(lora_module)

    def remove_lora(self, lora_module):
        self.lora_modules.remove(lora_module)

    def forward(self, x):
        if len(self.lora_modules) == 0:
            return self.orig_forward(x)
        lora = torch.sum(torch.stack([lora(x) for lora in self.lora_modules]), dim=0)
        return self.orig_forward(x) + lora


class LoRAHookInjector(object):
    def __init__(self):
        super().__init__()
        self.hooks = {}
        self.device = None
        self.dtype = None

    def _get_target_modules(self, root_module, prefix, target_replace_modules):
        target_modules = []
        for name, module in root_module.named_modules():
            if (
                module.__class__.__name__ in target_replace_modules
                and not "transformer_blocks" in name
            ):  # to adapt latest diffusers:
                for child_name, child_module in module.named_modules():
                    is_linear = child_module.__class__.__name__ == "Linear"
                    is_conv2d = child_module.__class__.__name__ == "Conv2d"
                    if is_linear or is_conv2d:
                        lora_name = prefix + "." + name + "." + child_name
                        lora_name = lora_name.replace(".", "_")
                        target_modules.append((lora_name, child_module))
        return target_modules

    def install_hooks(self, pipe):
        """Install LoRAHook to the pipe."""
        assert len(self.hooks) == 0
        text_encoder_targets = self._get_target_modules(
            pipe.text_encoder, "lora_te", ["CLIPAttention", "CLIPMLP"]
        )
        unet_targets = self._get_target_modules(
            pipe.unet, "lora_unet", ["Transformer2DModel", "Attention"]
        )
        for name, target_module in text_encoder_targets + unet_targets:
            hook = LoRAHook()
            hook.install(target_module)
            self.hooks[name] = hook

        self.device = pipe.device
        self.dtype = pipe.unet.dtype

    def uninstall_hooks(self):
        """Uninstall LoRAHook from the pipe."""
        for k, v in self.hooks.items():
            v.uninstall()
        self.hooks = {}

    def apply_lora(self, filename, alpha=1.0, dtype=torch.float32):
        """Load LoRA weights and apply LoRA to the pipe."""
        assert len(self.hooks) != 0
        self.dtype = dtype
        state_dict = safetensors.torch.load_file(filename)
        container = LoRAModuleContainer(self.hooks, state_dict, alpha)
        container.to(self.device, self.dtype)
        return container

    def remove_lora(self, container):
        """Remove the individual LoRA from the pipe."""
        container.remove_from_hooks(self.hooks)


def install_lora_hook(pipe: DiffusionPipeline):
    """Install LoRAHook to the pipe."""
    assert not hasattr(pipe, "lora_injector")
    assert not hasattr(pipe, "apply_lora")
    assert not hasattr(pipe, "remove_lora")
    injector = LoRAHookInjector()
    injector.install_hooks(pipe)
    pipe.lora_injector = injector
    pipe.apply_lora = injector.apply_lora
    pipe.remove_lora = injector.remove_lora


def uninstall_lora_hook(pipe: DiffusionPipeline):
    """Uninstall LoRAHook from the pipe."""
    pipe.lora_injector.uninstall_hooks()
    del pipe.lora_injector
    del pipe.apply_lora
    del pipe.remove_lora

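For reference, a minimal usage sketch for the script above (the checkpoint, LoRA filenames, and prompt are placeholders):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")

install_lora_hook(pipe)
lora1 = pipe.apply_lora("lora1.safetensors", alpha=1.0, dtype=torch.float16)
lora2 = pipe.apply_lora("lora2.safetensors", alpha=0.5, dtype=torch.float16)

image = pipe("1girl, little smile", generator=torch.manual_seed(0)).images[0]

lora1.alpha = 0.2        # rescale an applied LoRA without reloading it
pipe.remove_lora(lora2)  # detach one LoRA entirely
uninstall_lora_hook(pipe)
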
@adhikjoshi

@takuma104 can you update clip_text_custom_embedder for SDXL?

@zoezhu commented Oct 8, 2023

@takuma104 can you update clip_text_custom_embedder for SDXL please?
