@takuma104
Last active February 8, 2024 01:42
clip_text_custom_embedder

Usage

from clip_text_custom_embedder import text_embeddings
from diffusers import StableDiffusionPipeline
import torch

# model_id and seed are placeholders; use any Stable Diffusion 1.x checkpoint and any integer seed
model_id = "runwayml/stable-diffusion-v1-5"
seed = 0

pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to('cuda')

prompt = "((masterpiece, best quality)), white background, close-up, 1girl, little smile"
negative_prompt = ("(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), "
                   "bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2), partial face, "
                   "partial head, cropped head")
cond, uncond = text_embeddings(pipe, prompt, negative_prompt, clip_stop_at_last_layers=2)

image = pipe(prompt_embeds=cond,
             negative_prompt_embeds=uncond,
             generator=torch.manual_seed(seed)).images[0]
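
The embeddings come back as ordinary tensors, so they can be inspected before calling the pipeline. The sequence length is always a multiple of 77 (one CLIP chunk per 75 prompt tokens plus BOS/EOS), and clip_stop_at_last_layers=2 mirrors the web UI's "Clip skip" setting. A minimal sanity check, assuming an SD 1.x checkpoint (hidden size 768):

print(cond.shape)    # e.g. torch.Size([1, 77, 768]); longer prompts grow in steps of 77
print(uncond.shape)  # text_embeddings_equal_len pads the shorter one so both always match
assert cond.shape == uncond.shape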
clip_text_custom_embedder.py

import torch
import math
import re

# copied and customized from automatic1111 sd_hijack.py & prompt_parser.py
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/ec1924ee5789b72c31c65932b549c59ccae0cdd6/modules/sd_hijack.py#L113
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/ec1924ee5789b72c31c65932b549c59ccae0cdd6/modules/prompt_parser.py#L259
re_attention = re.compile(r"""
\\\(|
\\\{|
\\\)|
\\\}|
\\\[|
\\]|
\\\\|
\\|
\(|
\{|
\[|
:([+-]?[.\d]+)\)|
\)|
\}|
]|
[^\\()\\{}\[\]:]+|
:
""", re.X)
def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith('\\'):
            res.append([text[1:], 1.0])
        elif text == '(' or text == '{':
            round_brackets.append(len(res))
        elif text == '[':
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif (text == ')' or text == '}') and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == ']' and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([text, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res
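

# For reference, nested parentheses multiply the weight by 1.1 per level, so the
# start of the Usage prompt above parses as (weights rounded):
#   parse_prompt_attention('((masterpiece, best quality)), white background')
#   -> [['masterpiece, best quality', 1.21], [', white background', 1.0]]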
class CLIPTextCustomEmbedder(object):
    def __init__(self, tokenizer, text_encoder, device,
                 clip_stop_at_last_layers=1):
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.token_mults = {}
        self.device = device
        self.clip_stop_at_last_layers = clip_stop_at_last_layers

    def tokenize_line(self, line):
        def get_target_prompt_token_count(token_count):
            return math.ceil(max(token_count, 1) / 75) * 75

        id_end = self.tokenizer.eos_token_id
        parsed = parse_prompt_attention(line)
        tokenized = self.tokenizer(
            [text for text, _ in parsed], truncation=False,
            add_special_tokens=False)["input_ids"]

        fixes = []
        remade_tokens = []
        multipliers = []

        # each token inherits the weight of the prompt fragment it came from
        for tokens, (text, weight) in zip(tokenized, parsed):
            i = 0
            while i < len(tokens):
                token = tokens[i]
                remade_tokens.append(token)
                multipliers.append(weight)
                i += 1

        # pad with EOS tokens up to a multiple of 75 (one chunk, excluding BOS/EOS)
        token_count = len(remade_tokens)
        prompt_target_length = get_target_prompt_token_count(token_count)
        tokens_to_add = prompt_target_length - len(remade_tokens)
        remade_tokens = remade_tokens + [id_end] * tokens_to_add
        multipliers = multipliers + [1.0] * tokens_to_add
        return remade_tokens, fixes, multipliers, token_count
    def process_text(self, texts):
        if isinstance(texts, str):
            texts = [texts]

        remade_batch_tokens = []
        cache = {}
        batch_multipliers = []
        for line in texts:
            if line in cache:
                remade_tokens, fixes, multipliers = cache[line]
            else:
                remade_tokens, fixes, multipliers, _ = self.tokenize_line(line)
                cache[line] = (remade_tokens, fixes, multipliers)

            remade_batch_tokens.append(remade_tokens)
            batch_multipliers.append(multipliers)
        return batch_multipliers, remade_batch_tokens

    def __call__(self, text):
        batch_multipliers, remade_batch_tokens = self.process_text(text)

        z = None
        i = 0
        # encode the prompt 75 tokens at a time and concatenate the chunk embeddings
        while max(map(len, remade_batch_tokens)) != 0:
            rem_tokens = [x[75:] for x in remade_batch_tokens]
            rem_multipliers = [x[75:] for x in batch_multipliers]

            tokens = []
            multipliers = []
            for j in range(len(remade_batch_tokens)):
                if len(remade_batch_tokens[j]) > 0:
                    tokens.append(remade_batch_tokens[j][:75])
                    multipliers.append(batch_multipliers[j][:75])
                else:
                    tokens.append([self.tokenizer.eos_token_id] * 75)
                    multipliers.append([1.0] * 75)

            z1 = self.process_tokens(tokens, multipliers)
            z = z1 if z is None else torch.cat((z, z1), axis=-2)

            remade_batch_tokens = rem_tokens
            batch_multipliers = rem_multipliers
            i += 1

        return z
    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        remade_batch_tokens = [[self.tokenizer.bos_token_id] + x[:75] +
                               [self.tokenizer.eos_token_id] for x in remade_batch_tokens]
        batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]

        tokens = torch.asarray(remade_batch_tokens).to(self.device)
        outputs = self.text_encoder(
            input_ids=tokens, output_hidden_states=True)

        if self.clip_stop_at_last_layers > 1:
            z = self.text_encoder.text_model.final_layer_norm(
                outputs.hidden_states[-self.clip_stop_at_last_layers])
        else:
            z = outputs.last_hidden_state

        # restoring original mean is likely not correct, but it seems to work well
        # to prevent artifacts that happen otherwise
        batch_multipliers_of_same_length = [
            x + [1.0] * (75 - len(x)) for x in batch_multipliers]
        batch_multipliers = torch.asarray(
            batch_multipliers_of_same_length).to(self.device)
        original_mean = z.mean()
        z *= batch_multipliers.reshape(batch_multipliers.shape +
                                       (1,)).expand(z.shape)
        new_mean = z.mean()
        z *= original_mean / new_mean
        return z

    def get_text_tokens(self, text):
        batch_multipliers, remade_batch_tokens = self.process_text(text)
        return [[self.tokenizer.bos_token_id] + remade_batch_tokens[0]], \
               [[1.0] + batch_multipliers[0]]
def text_embeddings_equal_len(text_embedder, prompt, negative_prompt):
    cond_embeddings = text_embedder(prompt)
    uncond_embeddings = text_embedder(negative_prompt)

    cond_len = cond_embeddings.shape[1]
    uncond_len = uncond_embeddings.shape[1]
    if cond_len == uncond_len:
        return cond_embeddings, uncond_embeddings
    else:
        if cond_len > uncond_len:
            n = (cond_len - uncond_len) // 77
            return cond_embeddings, torch.cat([uncond_embeddings] + [text_embedder("")] * n, dim=1)
        else:
            n = (uncond_len - cond_len) // 77
            return torch.cat([cond_embeddings] + [text_embedder("")] * n, dim=1), uncond_embeddings


def text_embeddings(pipe, prompt, negative_prompt, clip_stop_at_last_layers=1):
    text_embedder = CLIPTextCustomEmbedder(tokenizer=pipe.tokenizer,
                                           text_encoder=pipe.text_encoder,
                                           device=pipe.text_encoder.device,
                                           clip_stop_at_last_layers=clip_stop_at_last_layers)
    cond_embeddings, uncond_embeddings = text_embeddings_equal_len(text_embedder, prompt, negative_prompt)
    return cond_embeddings, uncond_embeddings
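
Because text_embeddings_equal_len pads the shorter side with empty-prompt chunks, the conditional and unconditional embeddings always share a shape, which is what prompt_embeds and negative_prompt_embeds expect. A quick sketch of the long-prompt behaviour, with hypothetical prompts and the pipe from the Usage section:

cond, uncond = text_embeddings(pipe, "a " * 120 + "cat", "blurry", clip_stop_at_last_layers=2)
assert cond.shape == uncond.shape  # short negative prompt is padded up to the same length
print(cond.shape[1])               # a multiple of 77 (154 here: 121 tokens -> 2 chunks)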
@zoezhu commented Oct 8, 2023

@takuma104 can you update clip_text_custom_embedder for SDXL please?
