Created November 24, 2022 16:49
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index eaedac1..67eac6c 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -4,6 +4,7 @@ import sys
 import traceback
 import torch
 import numpy as np
+import open_clip
 from torch import einsum
 from torch.nn.functional import silu
@@ -70,9 +71,8 @@ class StableDiffusionModelHijack:
     embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
     def hijack(self, m):
-        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
-
-        model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
+        model = m.cond_stage_model.model
+        model.token_embedding = EmbeddingsWithFixes(model.token_embedding, self)
         m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
         self.clip = m.cond_stage_model
@@ -92,9 +92,10 @@ class StableDiffusionModelHijack:
         if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
             m.cond_stage_model = m.cond_stage_model.wrapped
-            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
-            if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
-                model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
+            model = m.cond_stage_model.model
+
+            if type(model.token_embedding) == EmbeddingsWithFixes:
+                model.token_embedding = model.token_embedding.wrapped
         self.apply_circular(False)
         self.layers = None
@@ -122,12 +123,15 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
         super().__init__()
         self.wrapped = wrapped
         self.hijack: StableDiffusionModelHijack = hijack
-        self.tokenizer = wrapped.tokenizer
+        self.tokenizer = open_clip.tokenizer._tokenizer  # seems wrong
         self.token_mults = {}
+
+        self.id_sot = self.tokenizer.encoder['<start_of_text>']
+        self.id_eot = self.tokenizer.encoder['<end_of_text>']
-        self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
+        self.comma_token = [v for k, v in self.tokenizer.encoder.items() if k == ',</w>'][0]
-        tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
+        tokens_with_parens = [(k, v) for k, v in self.tokenizer.encoder.items() if '(' in k or ')' in k or '[' in k or ']' in k]
         for text, ident in tokens_with_parens:
             mult = 1.0
             for c in text:
@@ -144,14 +148,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                self.token_mults[ident] = mult
     def tokenize_line(self, line, used_custom_terms, hijack_comments):
-        id_end = self.wrapped.tokenizer.eos_token_id
-
         if opts.enable_emphasis:
             parsed = prompt_parser.parse_prompt_attention(line)
         else:
             parsed = [[line, 1.0]]
-        tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
+        tokenized = list(map(self.tokenizer.encode, [text for text, _ in parsed]))
         fixes = []
         remade_tokens = []
@@ -176,7 +178,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                     length = len(remade_tokens)
                     rem = int(math.ceil(length / 75)) * 75 - length
-                    remade_tokens += [id_end] * rem + reloc_tokens
+                    remade_tokens += [self.id_eot] * rem + reloc_tokens
                     multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
                 if embedding is None:
@@ -188,7 +190,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                     iteration = len(remade_tokens) // 75
                     if (len(remade_tokens) + emb_len) // 75 != iteration:
                         rem = (75 * (iteration + 1) - len(remade_tokens))
-                        remade_tokens += [id_end] * rem
+                        remade_tokens += [self.id_eot] * rem
                         multipliers += [1.0] * rem
                         iteration += 1
                     fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
@@ -201,7 +203,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
         prompt_target_length = get_target_prompt_token_count(token_count)
         tokens_to_add = prompt_target_length - len(remade_tokens)
-        remade_tokens = remade_tokens + [id_end] * tokens_to_add
+        remade_tokens = remade_tokens + [self.id_eot] * tokens_to_add
         multipliers = multipliers + [1.0] * tokens_to_add
         return remade_tokens, fixes, multipliers, token_count
@@ -348,17 +350,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
     def process_tokens(self, remade_batch_tokens, batch_multipliers):
         if not opts.use_old_emphasis_implementation:
-            remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
+            remade_batch_tokens = [[self.id_sot] + x[:75] + [self.id_eot] for x in remade_batch_tokens]
             batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
         tokens = torch.asarray(remade_batch_tokens).to(device)
-        outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
+        z = self.wrapped.encode_with_transformer(tokens)
-        if opts.CLIP_stop_at_last_layers > 1:
-            z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
-            z = self.wrapped.transformer.text_model.final_layer_norm(z)
-        else:
-            z = outputs.last_hidden_state
+        # if opts.CLIP_stop_at_last_layers > 1:
+        #     z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
+        #     z = self.wrapped.transformer.text_model.final_layer_norm(z)
+        # else:
+        #     z = outputs.last_hidden_state
         # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
         batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 4fe6785..fd5e092 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -33,6 +33,7 @@ samplers_k_diffusion = [
     ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
     ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
     ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
+    ('DPM++ stochastic', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {})
 ]
 samplers_data_k_diffusion = [
@@ -350,7 +351,7 @@ class TorchHijack:
 class KDiffusionSampler:
     def __init__(self, funcname, sd_model):
-        self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
+        self.model_wrap = k_diffusion.external.CompVisVDenoiser(sd_model, quantize=shared.opts.enable_quantization)
         self.funcname = funcname
         self.func = getattr(k_diffusion.sampling, self.funcname)
         self.extra_params = sampler_extra_params.get(funcname, [])
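For reference, a minimal sketch (not part of the patch) of the `open_clip` tokenizer interface the hijack switches to in place of the Hugging Face `CLIPTokenizer`: `open_clip.tokenizer._tokenizer` is a `SimpleTokenizer` whose `encoder` dict maps BPE tokens to ids and whose `encode()` returns raw ids without start/end markers, which is why the patch looks up `<start_of_text>` / `<end_of_text>` itself and pads with `id_eot`. Exact id values depend on the vocabulary shipped with open_clip.

```python
# Minimal sketch of the open_clip tokenizer attributes used by the patch above.
# Assumes the open_clip package is installed; the printed ids depend on its vocab.
import open_clip

tok = open_clip.tokenizer._tokenizer            # module-level SimpleTokenizer instance

id_sot = tok.encoder['<start_of_text>']         # start-of-text token id
id_eot = tok.encoder['<end_of_text>']           # end-of-text token id
comma = tok.encoder[',</w>']                    # word-final comma, used for comma padding backtrack

ids = tok.encode("a photo of a cat")            # plain BPE ids, no special tokens added
chunk = [id_sot] + ids[:75] + [id_eot]          # same framing that process_tokens() applies
print(id_sot, id_eot, comma, len(chunk))
```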