from clip_text_custom_embedder import text_embeddings
from diffusers import StableDiffusionPipeline
import torch

model_id = "runwayml/stable-diffusion-v1-5"  # example checkpoint; substitute any SD 1.x model
seed = 0  # example seed

pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to('cuda')

prompt = "((masterpiece, best quality)), white background, close-up, 1girl, little smile"
negative_prompt = ("(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), "
                   "bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2), partial face, "
                   "partial head, cropped head")

cond, uncond = text_embeddings(pipe, prompt, negative_prompt, clip_stop_at_last_layers=2)
image = pipe(prompt_embeds=cond,
             negative_prompt_embeds=uncond,
             generator=torch.manual_seed(seed)).images[0]
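The point of text_embeddings is that it is not limited to CLIP's 77-token window: the prompt is split into 75-token chunks, each chunk is encoded separately, and the results are concatenated, with the shorter of cond/uncond padded to match. A minimal sketch of what to expect, continuing from the snippet above (the long prompt here is only illustrative):

# continuing from the snippet above (pipe, negative_prompt already defined)
long_prompt = ", ".join(["a highly detailed photograph of a castle on a hill at sunset"] * 10)
cond, uncond = text_embeddings(pipe, long_prompt, negative_prompt, clip_stop_at_last_layers=2)

# each 75-token chunk is wrapped with BOS/EOS, so the sequence length is a
# multiple of 77, and both tensors are padded to the same length
print(cond.shape, uncond.shape)  # e.g. torch.Size([1, 154, 768]) torch.Size([1, 154, 768])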
clip_text_custom_embedder.py
import torch
import math
import re

# copied and customized from automatic1111 sd_hijack.py & prompt_parser.py
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/ec1924ee5789b72c31c65932b549c59ccae0cdd6/modules/sd_hijack.py#L113
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/ec1924ee5789b72c31c65932b549c59ccae0cdd6/modules/prompt_parser.py#L259

re_attention = re.compile(r"""
\\\(|
\\\{|
\\\)|
\\\}|
\\\[|
\\]|
\\\\|
\\|
\(|
\{|
\[|
:([+-]?[.\d]+)\)|
\)|
\}|
]|
[^\\()\\{}\[\]:]+|
:
""", re.X)


def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith('\\'):
            res.append([text[1:], 1.0])
        elif text == '(' or text == '{':
            round_brackets.append(len(res))
        elif text == '[':
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif (text == ')' or text == '}') and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == ']' and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([text, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res

class CLIPTextCustomEmbedder(object):
    def __init__(self, tokenizer, text_encoder, device,
                 clip_stop_at_last_layers=1):
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.token_mults = {}
        self.device = device
        self.clip_stop_at_last_layers = clip_stop_at_last_layers

    def tokenize_line(self, line):
        def get_target_prompt_token_count(token_count):
            return math.ceil(max(token_count, 1) / 75) * 75

        id_end = self.tokenizer.eos_token_id
        parsed = parse_prompt_attention(line)
        tokenized = self.tokenizer(
            [text for text, _ in parsed], truncation=False,
            add_special_tokens=False)["input_ids"]

        fixes = []
        remade_tokens = []
        multipliers = []

        for tokens, (text, weight) in zip(tokenized, parsed):
            i = 0
            while i < len(tokens):
                token = tokens[i]
                remade_tokens.append(token)
                multipliers.append(weight)
                i += 1

        token_count = len(remade_tokens)
        prompt_target_length = get_target_prompt_token_count(token_count)
        tokens_to_add = prompt_target_length - len(remade_tokens)
        remade_tokens = remade_tokens + [id_end] * tokens_to_add
        multipliers = multipliers + [1.0] * tokens_to_add
        return remade_tokens, fixes, multipliers, token_count

    def process_text(self, texts):
        if isinstance(texts, str):
            texts = [texts]

        remade_batch_tokens = []
        cache = {}
        batch_multipliers = []
        for line in texts:
            if line in cache:
                remade_tokens, fixes, multipliers = cache[line]
            else:
                remade_tokens, fixes, multipliers, _ = self.tokenize_line(line)
                cache[line] = (remade_tokens, fixes, multipliers)

            remade_batch_tokens.append(remade_tokens)
            batch_multipliers.append(multipliers)

        return batch_multipliers, remade_batch_tokens
    def __call__(self, text):
        batch_multipliers, remade_batch_tokens = self.process_text(text)

        z = None
        i = 0
        # encode the prompt in 75-token chunks; each chunk becomes a 77-token
        # (BOS + chunk + EOS) embedding, and the chunks are concatenated along
        # the sequence axis
        while max(map(len, remade_batch_tokens)) != 0:
            rem_tokens = [x[75:] for x in remade_batch_tokens]
            rem_multipliers = [x[75:] for x in batch_multipliers]

            tokens = []
            multipliers = []
            for j in range(len(remade_batch_tokens)):
                if len(remade_batch_tokens[j]) > 0:
                    tokens.append(remade_batch_tokens[j][:75])
                    multipliers.append(batch_multipliers[j][:75])
                else:
                    tokens.append([self.tokenizer.eos_token_id] * 75)
                    multipliers.append([1.0] * 75)

            z1 = self.process_tokens(tokens, multipliers)
            z = z1 if z is None else torch.cat((z, z1), axis=-2)

            remade_batch_tokens = rem_tokens
            batch_multipliers = rem_multipliers
            i += 1

        return z
    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        remade_batch_tokens = [[self.tokenizer.bos_token_id] + x[:75] +
                               [self.tokenizer.eos_token_id] for x in remade_batch_tokens]
        batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]

        tokens = torch.asarray(remade_batch_tokens).to(self.device)
        outputs = self.text_encoder(
            input_ids=tokens, output_hidden_states=True)

        # clip_stop_at_last_layers > 1 mirrors webui's "Clip skip": take an earlier
        # hidden state and re-apply the final layer norm
        if self.clip_stop_at_last_layers > 1:
            z = self.text_encoder.text_model.final_layer_norm(
                outputs.hidden_states[-self.clip_stop_at_last_layers])
        else:
            z = outputs.last_hidden_state

        # scale each token embedding by its prompt weight, then rescale so the
        # global mean matches the pre-scaling mean; restoring the original mean
        # is likely not correct, but it seems to work well to prevent artifacts
        # that happen otherwise
        batch_multipliers_of_same_length = [
            x + [1.0] * (75 - len(x)) for x in batch_multipliers]
        batch_multipliers = torch.asarray(
            batch_multipliers_of_same_length).to(self.device)
        original_mean = z.mean()
        z *= batch_multipliers.reshape(batch_multipliers.shape +
                                       (1,)).expand(z.shape)
        new_mean = z.mean()
        z *= original_mean / new_mean
        return z
    def get_text_tokens(self, text):
        batch_multipliers, remade_batch_tokens = self.process_text(text)
        return [[self.tokenizer.bos_token_id] + remade_batch_tokens[0]], \
            [[1.0] + batch_multipliers[0]]


def text_embeddings_equal_len(text_embedder, prompt, negative_prompt):
    cond_embeddings = text_embedder(prompt)
    uncond_embeddings = text_embedder(negative_prompt)

    # pad the shorter of the two with empty-prompt chunks so both embeddings
    # end up with the same sequence length
    cond_len = cond_embeddings.shape[1]
    uncond_len = uncond_embeddings.shape[1]
    if cond_len == uncond_len:
        return cond_embeddings, uncond_embeddings
    else:
        if cond_len > uncond_len:
            n = (cond_len - uncond_len) // 77
            return cond_embeddings, torch.cat([uncond_embeddings] + [text_embedder("")] * n, dim=1)
        else:
            n = (uncond_len - cond_len) // 77
            return torch.cat([cond_embeddings] + [text_embedder("")] * n, dim=1), uncond_embeddings


def text_embeddings(pipe, prompt, negative_prompt, clip_stop_at_last_layers=1):
    text_embedder = CLIPTextCustomEmbedder(tokenizer=pipe.tokenizer,
                                           text_encoder=pipe.text_encoder,
                                           device=pipe.text_encoder.device,
                                           clip_stop_at_last_layers=clip_stop_at_last_layers)
    cond_embeddings, uncond_embeddings = text_embeddings_equal_len(text_embedder, prompt, negative_prompt)
    return cond_embeddings, uncond_embeddings
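For a quick look at what the emphasis syntax produces, parse_prompt_attention can also be called on its own, without loading any pipeline. A minimal sketch (the prompt string is just an example):

from clip_text_custom_embedder import parse_prompt_attention

# (red:1.3) raises the weight of "red" to 1.3; [wooden] lowers "wooden" by a factor of 1/1.1
print(parse_prompt_attention("a (red:1.3) apple on a [wooden] table"))
# -> [['a ', 1.0], ['red', 1.3], [' apple on a ', 1.0], ['wooden', 0.909...], [' table', 1.0]]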
@takuma104 can you update clip_text_custom_embedder for SDXL please?