@takuma104 · Last active February 8, 2024
clip_text_custom_embedder

Usage

from clip_text_custom_embedder import text_embeddings
from diffusers import StableDiffusionPipeline
import torch

model_id = "runwayml/stable-diffusion-v1-5"  # placeholder: any SD 1.x checkpoint
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to('cuda')

prompt = "((masterpiece, best quality)), white background, close-up, 1girl, little smile"
negative_prompt = ("(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), "
                   "bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2), partial face, "
                   "partial head, cropped head")
cond, uncond = text_embeddings(pipe, prompt, negative_prompt, clip_stop_at_last_layers=2)

seed = 0  # placeholder seed
image = pipe(prompt_embeds=cond,
             negative_prompt_embeds=uncond,
             generator=torch.manual_seed(seed)).images[0]
import torch
import math
import re
# copied and customized from automatic1111 sd_hijack.py & prompt_parser.py
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/ec1924ee5789b72c31c65932b549c59ccae0cdd6/modules/sd_hijack.py#L113
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/ec1924ee5789b72c31c65932b549c59ccae0cdd6/modules/prompt_parser.py#L259
re_attention = re.compile(r"""
\\\(|
\\\{|
\\\)|
\\\}|
\\\[|
\\]|
\\\\|
\\|
\(|
\{|
\[|
:([+-]?[.\d]+)\)|
\)|
\}|
]|
[^\\()\\{}\[\]:]+|
:
""", re.X)


def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """
    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith('\\'):
            res.append([text[1:], 1.0])
        elif text == '(' or text == '{':
            round_brackets.append(len(res))
        elif text == '[':
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif (text == ')' or text == '}') and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == ']' and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([text, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res


class CLIPTextCustomEmbedder(object):
    def __init__(self, tokenizer, text_encoder, device,
                 clip_stop_at_last_layers=1):
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.token_mults = {}
        self.device = device
        self.clip_stop_at_last_layers = clip_stop_at_last_layers

    def tokenize_line(self, line):
        def get_target_prompt_token_count(token_count):
            return math.ceil(max(token_count, 1) / 75) * 75

        id_end = self.tokenizer.eos_token_id
        parsed = parse_prompt_attention(line)
        tokenized = self.tokenizer(
            [text for text, _ in parsed], truncation=False,
            add_special_tokens=False)["input_ids"]

        fixes = []
        remade_tokens = []
        multipliers = []

        for tokens, (text, weight) in zip(tokenized, parsed):
            i = 0
            while i < len(tokens):
                token = tokens[i]
                remade_tokens.append(token)
                multipliers.append(weight)
                i += 1

        token_count = len(remade_tokens)
        prompt_target_length = get_target_prompt_token_count(token_count)
        tokens_to_add = prompt_target_length - len(remade_tokens)
        remade_tokens = remade_tokens + [id_end] * tokens_to_add
        multipliers = multipliers + [1.0] * tokens_to_add
        return remade_tokens, fixes, multipliers, token_count

    def process_text(self, texts):
        if isinstance(texts, str):
            texts = [texts]

        remade_batch_tokens = []
        cache = {}
        batch_multipliers = []
        for line in texts:
            if line in cache:
                remade_tokens, fixes, multipliers = cache[line]
            else:
                remade_tokens, fixes, multipliers, _ = self.tokenize_line(line)
                cache[line] = (remade_tokens, fixes, multipliers)

            remade_batch_tokens.append(remade_tokens)
            batch_multipliers.append(multipliers)

        return batch_multipliers, remade_batch_tokens

    def __call__(self, text):
        batch_multipliers, remade_batch_tokens = self.process_text(text)

        z = None
        i = 0
        while max(map(len, remade_batch_tokens)) != 0:
            rem_tokens = [x[75:] for x in remade_batch_tokens]
            rem_multipliers = [x[75:] for x in batch_multipliers]

            tokens = []
            multipliers = []
            for j in range(len(remade_batch_tokens)):
                if len(remade_batch_tokens[j]) > 0:
                    tokens.append(remade_batch_tokens[j][:75])
                    multipliers.append(batch_multipliers[j][:75])
                else:
                    tokens.append([self.tokenizer.eos_token_id] * 75)
                    multipliers.append([1.0] * 75)

            z1 = self.process_tokens(tokens, multipliers)
            z = z1 if z is None else torch.cat((z, z1), axis=-2)

            remade_batch_tokens = rem_tokens
            batch_multipliers = rem_multipliers
            i += 1

        return z

    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        remade_batch_tokens = [[self.tokenizer.bos_token_id] + x[:75] +
                               [self.tokenizer.eos_token_id] for x in remade_batch_tokens]
        batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]

        tokens = torch.asarray(remade_batch_tokens).to(self.device)
        # print(tokens.shape)
        # print(tokens)
        outputs = self.text_encoder(
            input_ids=tokens, output_hidden_states=True)

        if self.clip_stop_at_last_layers > 1:
            z = self.text_encoder.text_model.final_layer_norm(
                outputs.hidden_states[-self.clip_stop_at_last_layers])
        else:
            z = outputs.last_hidden_state

        # restoring original mean is likely not correct, but it seems to work well
        # to prevent artifacts that happen otherwise
        batch_multipliers_of_same_length = [
            x + [1.0] * (75 - len(x)) for x in batch_multipliers]
        batch_multipliers = torch.asarray(
            batch_multipliers_of_same_length).to(self.device)
        # print(batch_multipliers.shape)
        # print(batch_multipliers)

        original_mean = z.mean()
        z *= batch_multipliers.reshape(batch_multipliers.shape +
                                       (1,)).expand(z.shape)
        new_mean = z.mean()
        z *= original_mean / new_mean
        return z

    def get_text_tokens(self, text):
        batch_multipliers, remade_batch_tokens = self.process_text(text)
        return [[self.tokenizer.bos_token_id] + remade_batch_tokens[0]], \
            [[1.0] + batch_multipliers[0]]


def text_embeddings_equal_len(text_embedder, prompt, negative_prompt):
    cond_embeddings = text_embedder(prompt)
    uncond_embeddings = text_embedder(negative_prompt)

    cond_len = cond_embeddings.shape[1]
    uncond_len = uncond_embeddings.shape[1]
    if cond_len == uncond_len:
        return cond_embeddings, uncond_embeddings
    else:
        if cond_len > uncond_len:
            n = (cond_len - uncond_len) // 77
            return cond_embeddings, torch.cat([uncond_embeddings] + [text_embedder("")] * n, dim=1)
        else:
            n = (uncond_len - cond_len) // 77
            return torch.cat([cond_embeddings] + [text_embedder("")] * n, dim=1), uncond_embeddings


def text_embeddings(pipe, prompt, negative_prompt, clip_stop_at_last_layers=1):
    text_embedder = CLIPTextCustomEmbedder(tokenizer=pipe.tokenizer,
                                           text_encoder=pipe.text_encoder,
                                           device=pipe.text_encoder.device,
                                           clip_stop_at_last_layers=clip_stop_at_last_layers)
    cond_embeddings, uncond_embeddings = text_embeddings_equal_len(text_embedder, prompt, negative_prompt)
    return cond_embeddings, uncond_embeddings
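
Note on output shape: prompts are encoded in 75-token chunks (each wrapped with BOS/EOS), so the returned embeddings have a sequence length that is a multiple of 77, and text_embeddings_equal_len pads the shorter side with empty-prompt chunks. A quick sanity check, assuming the pipe and prompts from the Usage section above:

cond, uncond = text_embeddings(pipe, prompt, negative_prompt, clip_stop_at_last_layers=2)
print(cond.shape, uncond.shape)  # e.g. torch.Size([1, 77, 768]) for short prompts; a prompt
                                 # over 75 tokens yields [1, 154, 768], and the other side is padded to match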
@takuma104 (Author)

@alexblattner Indeed, the Diffusers code is well-organized. This code should work with StableDiffusionControlNetPipeline as well, since LoRAs are only applied to pipe.unet and pipe.text_encoder. I hope your latent couple project goes smoothly!
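
For reference, a minimal sketch of passing these embeddings to a ControlNet pipeline; the checkpoints and prompts below are placeholders, and canny_image is assumed to be a precomputed control image (e.g. a PIL image of Canny edges):

from clip_text_custom_embedder import text_embeddings
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
import torch

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny",
                                             torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet,
    torch_dtype=torch.float16).to("cuda")

prompt = "((masterpiece, best quality)), 1girl, little smile"  # weighting syntax works as usual
negative_prompt = "(low quality, worst quality:1.4)"
cond, uncond = text_embeddings(pipe, prompt, negative_prompt, clip_stop_at_last_layers=2)

image = pipe(image=canny_image,  # assumed: precomputed control image
             prompt_embeds=cond,
             negative_prompt_embeds=uncond,
             generator=torch.manual_seed(0)).images[0]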

@alexblattner

Thank you @takuma104!

@alexblattner commented Jun 27, 2023

I fixed an issue with the torch dtype:

import math
import safetensors
import torch
from diffusers import DiffusionPipeline

# modified from https://github.com/kohya-ss/sd-scripts/blob/ad5f318d066c52e5b27306b399bc87e41f2eef2b/networks/lora.py#L17
class LoRAModule(torch.nn.Module):
    def __init__(
        self, org_module: torch.nn.Module, lora_dim=4, alpha=1.0, multiplier=1.0
    ):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()

        if org_module.__class__.__name__ == "Conv2d":
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

        self.lora_dim = lora_dim

        if org_module.__class__.__name__ == "Conv2d":
            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = torch.nn.Conv2d(
                in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
            )
            self.lora_up = torch.nn.Conv2d(
                self.lora_dim, out_dim, (1, 1), (1, 1), bias=False
            )
        else:
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if alpha is None or alpha == 0:
            self.alpha = self.lora_dim
        else:
            if type(alpha) == torch.Tensor:
                alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
            self.register_buffer("alpha", torch.tensor(alpha))  # Treatable as a constant.

        # same as microsoft's
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier

    def forward(self, x):
        scale = self.alpha / self.lora_dim
        down = self.lora_down(x)
        up = self.lora_up(down)
        return self.multiplier * scale * up


class LoRAModuleContainer(torch.nn.Module):
    def __init__(self, hooks, state_dict, multiplier):
        super().__init__()
        self.multiplier = multiplier

        # Create LoRAModule from state_dict information
        for key, value in state_dict.items():
            if "lora_down" in key:
                lora_name = key.split(".")[0]
                lora_dim = value.size()[0]
                lora_name_alpha = key.split(".")[0] + '.alpha'
                alpha = None
                if lora_name_alpha in state_dict:
                    alpha = state_dict[lora_name_alpha].item()
                hook = hooks[lora_name]
                lora_module = LoRAModule(
                    hook.orig_module, lora_dim=lora_dim, alpha=alpha, multiplier=multiplier
                )
                self.register_module(lora_name, lora_module)

        # Load whole LoRA weights
        self.load_state_dict(state_dict)

        # Register LoRAModule to LoRAHook
        for name, module in self.named_modules():
            if module.__class__.__name__ == "LoRAModule":
                hook = hooks[name]
                hook.append_lora(module)
    @property
    def alpha(self):
        return self.multiplier

    @alpha.setter
    def alpha(self, multiplier):
        self.multiplier = multiplier
        for name, module in self.named_modules():
            if module.__class__.__name__ == "LoRAModule":
                module.multiplier = multiplier

    def remove_from_hooks(self, hooks):
        for name, module in self.named_modules():
            if module.__class__.__name__ == "LoRAModule":
                hook = hooks[name]
                hook.remove_lora(module)
                del module


class LoRAHook(torch.nn.Module):
    """
    replaces forward method of the original Linear,
    instead of replacing the original Linear module.
    """

    def __init__(self):
        super().__init__()
        self.lora_modules = []

    def install(self, orig_module):
        assert not hasattr(self, "orig_module")
        self.orig_module = orig_module
        self.orig_forward = self.orig_module.forward
        self.orig_module.forward = self.forward

    def uninstall(self):
        assert hasattr(self, "orig_module")
        self.orig_module.forward = self.orig_forward
        del self.orig_forward
        del self.orig_module

    def append_lora(self, lora_module):
        self.lora_modules.append(lora_module)

    def remove_lora(self, lora_module):
        self.lora_modules.remove(lora_module)

    def forward(self, x):
        if len(self.lora_modules) == 0:
            return self.orig_forward(x)
        lora = torch.sum(torch.stack([lora(x) for lora in self.lora_modules]), dim=0)
        return self.orig_forward(x) + lora


class LoRAHookInjector(object):
    def __init__(self):
        super().__init__()
        self.hooks = {}
        self.device = None
        self.dtype = None

    def _get_target_modules(self, root_module, prefix, target_replace_modules):
        target_modules = []
        for name, module in root_module.named_modules():
            if (
                module.__class__.__name__ in target_replace_modules
                and "transformer_blocks" not in name
            ):  # to adapt to latest diffusers
                for child_name, child_module in module.named_modules():
                    is_linear = child_module.__class__.__name__ == "Linear"
                    is_conv2d = child_module.__class__.__name__ == "Conv2d"
                    if is_linear or is_conv2d:
                        lora_name = prefix + "." + name + "." + child_name
                        lora_name = lora_name.replace(".", "_")
                        target_modules.append((lora_name, child_module))
        return target_modules

    def install_hooks(self, pipe):
        """Install LoRAHook to the pipe."""
        assert len(self.hooks) == 0
        text_encoder_targets = self._get_target_modules(
            pipe.text_encoder, "lora_te", ["CLIPAttention", "CLIPMLP"]
        )
        unet_targets = self._get_target_modules(
            pipe.unet, "lora_unet", ["Transformer2DModel", "Attention"]
        )
        for name, target_module in text_encoder_targets + unet_targets:
            hook = LoRAHook()
            hook.install(target_module)
            self.hooks[name] = hook

        self.device = pipe.device
        self.dtype = pipe.unet.dtype

    def uninstall_hooks(self):
        """Uninstall LoRAHook from the pipe."""
        for k, v in self.hooks.items():
            v.uninstall()
        self.hooks = {}

    def apply_lora(self, filename, alpha=1.0, dtype=torch.float32):
        """Load LoRA weights and apply LoRA to the pipe."""
        assert len(self.hooks) != 0
        self.dtype = dtype
        state_dict = safetensors.torch.load_file(filename)
        container = LoRAModuleContainer(self.hooks, state_dict, alpha)
        container.to(self.device, self.dtype)
        return container

    def remove_lora(self, container):
        """Remove the individual LoRA from the pipe."""
        container.remove_from_hooks(self.hooks)


def install_lora_hook(pipe: DiffusionPipeline):
    """Install LoRAHook to the pipe."""
    assert not hasattr(pipe, "lora_injector")
    assert not hasattr(pipe, "apply_lora")
    assert not hasattr(pipe, "remove_lora")
    injector = LoRAHookInjector()
    injector.install_hooks(pipe)
    pipe.lora_injector = injector
    pipe.apply_lora = injector.apply_lora
    pipe.remove_lora = injector.remove_lora


def uninstall_lora_hook(pipe: DiffusionPipeline):
    """Uninstall LoRAHook from the pipe."""
    pipe.lora_injector.uninstall_hooks()
    del pipe.lora_injector
    del pipe.apply_lora
    del pipe.remove_lora
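
A minimal usage sketch for these hooks; the base checkpoint, LoRA file, and prompt below are placeholders:

from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")  # placeholder model

install_lora_hook(pipe)
lora = pipe.apply_lora("my_lora.safetensors", alpha=0.8, dtype=torch.float16)  # placeholder file

image = pipe("1girl, little smile", generator=torch.manual_seed(0)).images[0]

lora.alpha = 0.4           # rescale this LoRA's strength after loading
pipe.remove_lora(lora)     # detach just this LoRA
uninstall_lora_hook(pipe)  # restore the original forward methods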

@adhikjoshi

@takuma104 can you update clip_text_custom_embedder for SDXL?

@zoezhu commented Oct 8, 2023

@takuma104 can you update clip_text_custom_embedder for SDXL please?
