from clip_text_custom_embedder import text_embeddings
from diffusers import StableDiffusionPipeline
import torch

model_id = "runwayml/stable-diffusion-v1-5"  # placeholder: any SD 1.x checkpoint
seed = 0

pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to('cuda')

prompt = "((masterpiece, best quality)), white background, close-up, 1girl, little smile"
negative_prompt = ("(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), "
                   "bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2), "
                   "partial face, partial head, cropped head")

cond, uncond = text_embeddings(pipe, prompt, negative_prompt, clip_stop_at_last_layers=2)
image = pipe(prompt_embeds=cond,
             negative_prompt_embeds=uncond,
             generator=torch.manual_seed(seed)).images[0]
clip_text_custom_embedder
import torch
import math
import re

# copied and customized from automatic1111 sd_hijack.py & prompt_parser.py
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/ec1924ee5789b72c31c65932b549c59ccae0cdd6/modules/sd_hijack.py#L113
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/ec1924ee5789b72c31c65932b549c59ccae0cdd6/modules/prompt_parser.py#L259

re_attention = re.compile(r"""
\\\(|
\\\{|
\\\)|
\\\}|
\\\[|
\\]|
\\\\|
\\|
\(|
\{|
\[|
:([+-]?[.\d]+)\)|
\)|
\}|
]|
[^\\()\\{}\[\]:]+|
:
""", re.X)
def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """
    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith('\\'):
            res.append([text[1:], 1.0])
        elif text == '(' or text == '{':
            round_brackets.append(len(res))
        elif text == '[':
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif (text == ')' or text == '}') and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == ']' and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([text, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res
class CLIPTextCustomEmbedder(object):
    def __init__(self, tokenizer, text_encoder, device,
                 clip_stop_at_last_layers=1):
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.token_mults = {}
        self.device = device
        self.clip_stop_at_last_layers = clip_stop_at_last_layers

    def tokenize_line(self, line):
        def get_target_prompt_token_count(token_count):
            return math.ceil(max(token_count, 1) / 75) * 75

        id_end = self.tokenizer.eos_token_id
        parsed = parse_prompt_attention(line)
        tokenized = self.tokenizer(
            [text for text, _ in parsed], truncation=False,
            add_special_tokens=False)["input_ids"]

        fixes = []
        remade_tokens = []
        multipliers = []

        for tokens, (text, weight) in zip(tokenized, parsed):
            i = 0
            while i < len(tokens):
                token = tokens[i]
                remade_tokens.append(token)
                multipliers.append(weight)
                i += 1

        # pad the token list up to a multiple of 75 with EOS tokens
        token_count = len(remade_tokens)
        prompt_target_length = get_target_prompt_token_count(token_count)
        tokens_to_add = prompt_target_length - len(remade_tokens)
        remade_tokens = remade_tokens + [id_end] * tokens_to_add
        multipliers = multipliers + [1.0] * tokens_to_add
        return remade_tokens, fixes, multipliers, token_count

    def process_text(self, texts):
        if isinstance(texts, str):
            texts = [texts]

        remade_batch_tokens = []
        cache = {}
        batch_multipliers = []
        for line in texts:
            if line in cache:
                remade_tokens, fixes, multipliers = cache[line]
            else:
                remade_tokens, fixes, multipliers, _ = self.tokenize_line(line)
                cache[line] = (remade_tokens, fixes, multipliers)

            remade_batch_tokens.append(remade_tokens)
            batch_multipliers.append(multipliers)

        return batch_multipliers, remade_batch_tokens

    def __call__(self, text):
        batch_multipliers, remade_batch_tokens = self.process_text(text)

        # encode the prompt in chunks of 75 tokens and concatenate the chunk
        # embeddings along the sequence axis
        z = None
        i = 0
        while max(map(len, remade_batch_tokens)) != 0:
            rem_tokens = [x[75:] for x in remade_batch_tokens]
            rem_multipliers = [x[75:] for x in batch_multipliers]

            tokens = []
            multipliers = []
            for j in range(len(remade_batch_tokens)):
                if len(remade_batch_tokens[j]) > 0:
                    tokens.append(remade_batch_tokens[j][:75])
                    multipliers.append(batch_multipliers[j][:75])
                else:
                    tokens.append([self.tokenizer.eos_token_id] * 75)
                    multipliers.append([1.0] * 75)

            z1 = self.process_tokens(tokens, multipliers)
            z = z1 if z is None else torch.cat((z, z1), axis=-2)

            remade_batch_tokens = rem_tokens
            batch_multipliers = rem_multipliers
            i += 1

        return z

    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        # wrap each 75-token chunk with BOS/EOS (77 tokens, matching CLIP's context length)
        remade_batch_tokens = [[self.tokenizer.bos_token_id] + x[:75] +
                               [self.tokenizer.eos_token_id] for x in remade_batch_tokens]
        batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]

        tokens = torch.asarray(remade_batch_tokens).to(self.device)
        # print(tokens.shape)
        # print(tokens)
        outputs = self.text_encoder(
            input_ids=tokens, output_hidden_states=True)

        if self.clip_stop_at_last_layers > 1:
            z = self.text_encoder.text_model.final_layer_norm(
                outputs.hidden_states[-self.clip_stop_at_last_layers])
        else:
            z = outputs.last_hidden_state

        # restoring original mean is likely not correct, but it seems to work well
        # to prevent artifacts that happen otherwise
        batch_multipliers_of_same_length = [
            x + [1.0] * (75 - len(x)) for x in batch_multipliers]
        batch_multipliers = torch.asarray(
            batch_multipliers_of_same_length).to(self.device)
        # print(batch_multipliers.shape)
        # print(batch_multipliers)

        original_mean = z.mean()
        z *= batch_multipliers.reshape(batch_multipliers.shape +
                                       (1,)).expand(z.shape)
        new_mean = z.mean()
        z *= original_mean / new_mean
        return z

    def get_text_tokens(self, text):
        batch_multipliers, remade_batch_tokens = self.process_text(text)
        return [[self.tokenizer.bos_token_id] + remade_batch_tokens[0]], \
            [[1.0] + batch_multipliers[0]]
def text_embeddings_equal_len(text_embedder, prompt, negative_prompt):
    cond_embeddings = text_embedder(prompt)
    uncond_embeddings = text_embedder(negative_prompt)

    cond_len = cond_embeddings.shape[1]
    uncond_len = uncond_embeddings.shape[1]
    if cond_len == uncond_len:
        return cond_embeddings, uncond_embeddings
    else:
        # pad the shorter embedding with empty-prompt chunks (77 tokens each)
        # so both embeddings have the same sequence length
        if cond_len > uncond_len:
            n = (cond_len - uncond_len) // 77
            return cond_embeddings, torch.cat([uncond_embeddings] + [text_embedder("")] * n, dim=1)
        else:
            n = (uncond_len - cond_len) // 77
            return torch.cat([cond_embeddings] + [text_embedder("")] * n, dim=1), uncond_embeddings


def text_embeddings(pipe, prompt, negative_prompt, clip_stop_at_last_layers=1):
    text_embedder = CLIPTextCustomEmbedder(tokenizer=pipe.tokenizer,
                                           text_encoder=pipe.text_encoder,
                                           device=pipe.text_encoder.device,
                                           clip_stop_at_last_layers=clip_stop_at_last_layers)
    cond_embeddings, uncond_embeddings = text_embeddings_equal_len(text_embedder, prompt, negative_prompt)
    return cond_embeddings, uncond_embeddings
@alexblattner Indeed, the code for Diffusers is well-organized. This code should work with the StableDiffusionControlNetPipeline as well, since LoRAs are only applied to pipe.unet and pipe.text_encoder. I hope your latent couple project goes smoothly!
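For reference, a minimal sketch of what that might look like with a ControlNet pipeline (not from the original gist; the model IDs, ControlNet checkpoint, and LoRA filename are placeholders, and it assumes the LoRA hook code below is saved as lora_hook.py):

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from lora_hook import install_lora_hook  # hypothetical module name for the code below

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

install_lora_hook(pipe)  # hooks only pipe.text_encoder and pipe.unet, so the ControlNet itself is untouched
lora = pipe.apply_lora("my_lora.safetensors", alpha=0.8, dtype=torch.float16)  # placeholder filename
# generate as usual, e.g. pipe(prompt, image=control_image), then pipe.remove_lora(lora) to detach it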
Thank you @takuma104! I fixed an issue with the torch dtype:
import math
import safetensors.torch
import torch
from diffusers import DiffusionPipeline


# modified from https://github.com/kohya-ss/sd-scripts/blob/ad5f318d066c52e5b27306b399bc87e41f2eef2b/networks/lora.py#L17
class LoRAModule(torch.nn.Module):
    def __init__(
        self, org_module: torch.nn.Module, lora_dim=4, alpha=1.0, multiplier=1.0
    ):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()

        if org_module.__class__.__name__ == "Conv2d":
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

        self.lora_dim = lora_dim

        if org_module.__class__.__name__ == "Conv2d":
            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = torch.nn.Conv2d(
                in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
            )
            self.lora_up = torch.nn.Conv2d(
                self.lora_dim, out_dim, (1, 1), (1, 1), bias=False
            )
        else:
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if alpha is None or alpha == 0:
            self.alpha = self.lora_dim
        else:
            if type(alpha) == torch.Tensor:
                alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
            self.register_buffer("alpha", torch.tensor(alpha))  # treatable as a constant

        # same initialization as microsoft's LoRA reference implementation
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier

    def forward(self, x):
        scale = self.alpha / self.lora_dim
        down = self.lora_down(x)
        up = self.lora_up(down)
        return self.multiplier * scale * up
class LoRAModuleContainer(torch.nn.Module):
    def __init__(self, hooks, state_dict, multiplier):
        super().__init__()
        self.multiplier = multiplier

        # Create LoRAModule from state_dict information
        for key, value in state_dict.items():
            if "lora_down" in key:
                lora_name = key.split(".")[0]
                lora_dim = value.size()[0]
                lora_name_alpha = key.split(".")[0] + '.alpha'
                alpha = None
                if lora_name_alpha in state_dict:
                    alpha = state_dict[lora_name_alpha].item()
                hook = hooks[lora_name]
                lora_module = LoRAModule(
                    hook.orig_module, lora_dim=lora_dim, alpha=alpha, multiplier=multiplier
                )
                self.register_module(lora_name, lora_module)

        # Load whole LoRA weights
        self.load_state_dict(state_dict)

        # Register LoRAModule to LoRAHook
        for name, module in self.named_modules():
            if module.__class__.__name__ == "LoRAModule":
                hook = hooks[name]
                hook.append_lora(module)

    @property
    def alpha(self):
        return self.multiplier

    @alpha.setter
    def alpha(self, multiplier):
        self.multiplier = multiplier
        for name, module in self.named_modules():
            if module.__class__.__name__ == "LoRAModule":
                module.multiplier = multiplier

    def remove_from_hooks(self, hooks):
        for name, module in self.named_modules():
            if module.__class__.__name__ == "LoRAModule":
                hook = hooks[name]
                hook.remove_lora(module)
                del module
class LoRAHook(torch.nn.Module):
    """
    replaces forward method of the original Linear,
    instead of replacing the original Linear module.
    """

    def __init__(self):
        super().__init__()
        self.lora_modules = []

    def install(self, orig_module):
        assert not hasattr(self, "orig_module")
        self.orig_module = orig_module
        self.orig_forward = self.orig_module.forward
        self.orig_module.forward = self.forward

    def uninstall(self):
        assert hasattr(self, "orig_module")
        self.orig_module.forward = self.orig_forward
        del self.orig_forward
        del self.orig_module

    def append_lora(self, lora_module):
        self.lora_modules.append(lora_module)

    def remove_lora(self, lora_module):
        self.lora_modules.remove(lora_module)

    def forward(self, x):
        if len(self.lora_modules) == 0:
            return self.orig_forward(x)
        lora = torch.sum(torch.stack([lora(x) for lora in self.lora_modules]), dim=0)
        return self.orig_forward(x) + lora
class LoRAHookInjector(object):
    def __init__(self):
        super().__init__()
        self.hooks = {}
        self.device = None
        self.dtype = None

    def _get_target_modules(self, root_module, prefix, target_replace_modules):
        target_modules = []
        for name, module in root_module.named_modules():
            if (
                module.__class__.__name__ in target_replace_modules
                and not "transformer_blocks" in name
            ):  # to adapt latest diffusers
                for child_name, child_module in module.named_modules():
                    is_linear = child_module.__class__.__name__ == "Linear"
                    is_conv2d = child_module.__class__.__name__ == "Conv2d"
                    if is_linear or is_conv2d:
                        lora_name = prefix + "." + name + "." + child_name
                        lora_name = lora_name.replace(".", "_")
                        target_modules.append((lora_name, child_module))
        return target_modules

    def install_hooks(self, pipe):
        """Install LoRAHook to the pipe."""
        assert len(self.hooks) == 0
        text_encoder_targets = self._get_target_modules(
            pipe.text_encoder, "lora_te", ["CLIPAttention", "CLIPMLP"]
        )
        unet_targets = self._get_target_modules(
            pipe.unet, "lora_unet", ["Transformer2DModel", "Attention"]
        )
        for name, target_module in text_encoder_targets + unet_targets:
            hook = LoRAHook()
            hook.install(target_module)
            self.hooks[name] = hook

        self.device = pipe.device
        self.dtype = pipe.unet.dtype

    def uninstall_hooks(self):
        """Uninstall LoRAHook from the pipe."""
        for k, v in self.hooks.items():
            v.uninstall()
        self.hooks = {}

    def apply_lora(self, filename, alpha=1.0, dtype=torch.float32):
        """Load LoRA weights and apply LoRA to the pipe."""
        assert len(self.hooks) != 0
        self.dtype = dtype
        state_dict = safetensors.torch.load_file(filename)
        container = LoRAModuleContainer(self.hooks, state_dict, alpha)
        container.to(self.device, self.dtype)
        return container

    def remove_lora(self, container):
        """Remove the individual LoRA from the pipe."""
        container.remove_from_hooks(self.hooks)


def install_lora_hook(pipe: DiffusionPipeline):
    """Install LoRAHook to the pipe."""
    assert not hasattr(pipe, "lora_injector")
    assert not hasattr(pipe, "apply_lora")
    assert not hasattr(pipe, "remove_lora")
    injector = LoRAHookInjector()
    injector.install_hooks(pipe)
    pipe.lora_injector = injector
    pipe.apply_lora = injector.apply_lora
    pipe.remove_lora = injector.remove_lora


def uninstall_lora_hook(pipe: DiffusionPipeline):
    """Uninstall LoRAHook from the pipe."""
    pipe.lora_injector.uninstall_hooks()
    del pipe.lora_injector
    del pipe.apply_lora
    del pipe.remove_lora
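For reference, a minimal usage sketch of the hook above (not part of the original comment; the checkpoint and LoRA file names are placeholders, and it assumes this code is saved as lora_hook.py):

import torch
from diffusers import StableDiffusionPipeline
from lora_hook import install_lora_hook, uninstall_lora_hook  # hypothetical module name

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
                                               torch_dtype=torch.float16).to("cuda")
install_lora_hook(pipe)
lora = pipe.apply_lora("my_lora.safetensors", alpha=1.0, dtype=torch.float16)  # placeholder filename
image = pipe("1girl, white background", generator=torch.manual_seed(0)).images[0]

lora.alpha = 0.5           # the alpha setter rescales every LoRAModule's multiplier on the fly
pipe.remove_lora(lora)     # detach this particular LoRA from the hooks
uninstall_lora_hook(pipe)  # restore the original forward methods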
@takuma104 can you update clip_text_custom_embedder for SDXL?
@takuma104 can you update clip_text_custom_embedder for SDXL please?
@takuma104 thank you very much for your answer. I didn't know that the LoRA implementations were that different between diffusers and A1111.
I love the fact that diffusers is so well organized and doesn't require me to install a webui, so I have stuck with it. For my project I do need the latest tech, though; hopefully diffusers catches up soon. I am also building a pipeline that should make latent couple (two shots) usable with multi-ControlNet. I have latent couples working, and I am now applying ControlNet and dynamic LoRAs on top of it.
Does the code you gave work with your code here? If yes, then you have saved me a huge amount of time and I really appreciate it. If not, I'll try to do something about it, assuming I manage to make my pipeline work with multi-ControlNet.
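For what it's worth, a rough sketch of how the two pieces could be combined (this is an assumption, not a confirmed answer from the author; it presumes the gist is saved as clip_text_custom_embedder.py, the LoRA hook code as lora_hook.py, and the model and file names are placeholders):

import torch
from diffusers import StableDiffusionPipeline
from clip_text_custom_embedder import text_embeddings
from lora_hook import install_lora_hook  # hypothetical module name

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
                                               torch_dtype=torch.float16).to("cuda")
install_lora_hook(pipe)
lora = pipe.apply_lora("my_lora.safetensors", alpha=0.8, dtype=torch.float16)  # placeholder filename

# the hook also patches the text encoder, so build the embeddings after apply_lora()
cond, uncond = text_embeddings(pipe, "((masterpiece)), 1girl", "(low quality:1.4)",
                               clip_stop_at_last_layers=2)
image = pipe(prompt_embeds=cond, negative_prompt_embeds=uncond,
             generator=torch.manual_seed(0)).images[0]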