@toriato
Created November 24, 2022 16:49
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index eaedac1..67eac6c 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -4,6 +4,7 @@ import sys
import traceback
import torch
import numpy as np
+import open_clip
from torch import einsum
from torch.nn.functional import silu
@@ -70,9 +71,8 @@ class StableDiffusionModelHijack:
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
def hijack(self, m):
- model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
-
- model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
+ model = m.cond_stage_model.model
+ model.token_embedding = EmbeddingsWithFixes(model.token_embedding, self)
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
self.clip = m.cond_stage_model
@@ -92,9 +92,10 @@ class StableDiffusionModelHijack:
if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
- model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
- if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
- model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
+ model = m.cond_stage_model.model
+
+ if type(model.token_embedding) == EmbeddingsWithFixes:
+ model.token_embedding = model.token_embedding.wrapped
self.apply_circular(False)
self.layers = None
@@ -122,12 +123,15 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
super().__init__()
self.wrapped = wrapped
self.hijack: StableDiffusionModelHijack = hijack
- self.tokenizer = wrapped.tokenizer
+ self.tokenizer = open_clip.tokenizer._tokenizer # seems wrong
self.token_mults = {}
+
+ self.id_sot = self.tokenizer.encoder['<start_of_text>']
+ self.id_eot = self.tokenizer.encoder['<end_of_text>']
- self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
+ self.comma_token = [v for k, v in self.tokenizer.encoder.items() if k == ',</w>'][0]
- tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
+ tokens_with_parens = [(k, v) for k, v in self.tokenizer.encoder.items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
for c in text:
@@ -144,14 +148,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.token_mults[ident] = mult
def tokenize_line(self, line, used_custom_terms, hijack_comments):
- id_end = self.wrapped.tokenizer.eos_token_id
-
if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(line)
else:
parsed = [[line, 1.0]]
- tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
+ tokenized = list(map(self.tokenizer.encode, [text for text, _ in parsed]))
fixes = []
remade_tokens = []
@@ -176,7 +178,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
length = len(remade_tokens)
rem = int(math.ceil(length / 75)) * 75 - length
- remade_tokens += [id_end] * rem + reloc_tokens
+ remade_tokens += [self.id_eot] * rem + reloc_tokens
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
if embedding is None:
@@ -188,7 +190,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
iteration = len(remade_tokens) // 75
if (len(remade_tokens) + emb_len) // 75 != iteration:
rem = (75 * (iteration + 1) - len(remade_tokens))
- remade_tokens += [id_end] * rem
+ remade_tokens += [self.id_eot] * rem
multipliers += [1.0] * rem
iteration += 1
fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
@@ -201,7 +203,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
prompt_target_length = get_target_prompt_token_count(token_count)
tokens_to_add = prompt_target_length - len(remade_tokens)
- remade_tokens = remade_tokens + [id_end] * tokens_to_add
+ remade_tokens = remade_tokens + [self.id_eot] * tokens_to_add
multipliers = multipliers + [1.0] * tokens_to_add
return remade_tokens, fixes, multipliers, token_count
@@ -348,17 +350,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def process_tokens(self, remade_batch_tokens, batch_multipliers):
if not opts.use_old_emphasis_implementation:
- remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
+ remade_batch_tokens = [[self.id_sot] + x[:75] + [self.id_eot] for x in remade_batch_tokens]
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
tokens = torch.asarray(remade_batch_tokens).to(device)
- outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
+ z = self.wrapped.encode_with_transformer(tokens)
- if opts.CLIP_stop_at_last_layers > 1:
- z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
- z = self.wrapped.transformer.text_model.final_layer_norm(z)
- else:
- z = outputs.last_hidden_state
+ # if opts.CLIP_stop_at_last_layers > 1:
+ # z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
+ # z = self.wrapped.transformer.text_model.final_layer_norm(z)
+ # else:
+ # z = outputs.last_hidden_state
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
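
The hunks above swap the Hugging Face CLIP tokenizer for open_clip's module-level SimpleTokenizer and read the special-token ids straight out of its encoder dict. A minimal sanity-check sketch of those lookups, assuming the open_clip package the SD 2.x text encoder already depends on (the prompt string is only an example; the token names come from the patch itself):

    import open_clip

    # Module-level SimpleTokenizer instance, the same object the patch grabs.
    tokenizer = open_clip.tokenizer._tokenizer

    # Special tokens and the comma token are plain entries in the BPE vocab.
    id_sot = tokenizer.encoder['<start_of_text>']
    id_eot = tokenizer.encoder['<end_of_text>']
    comma_token = tokenizer.encoder[',</w>']

    # SimpleTokenizer.encode() returns raw BPE ids with no BOS/EOS attached,
    # which is why tokenize_line() pads and wraps each 75-token chunk itself.
    ids = tokenizer.encode('a photo of a cat, masterpiece')
    print(id_sot, id_eot, comma_token, len(ids))
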
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 4fe6785..fd5e092 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -33,6 +33,7 @@ samplers_k_diffusion = [
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
+ ('DPM++ stochastic', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {})
]
samplers_data_k_diffusion = [
@@ -350,7 +351,7 @@ class TorchHijack:
class KDiffusionSampler:
def __init__(self, funcname, sd_model):
- self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
+ self.model_wrap = k_diffusion.external.CompVisVDenoiser(sd_model, quantize=shared.opts.enable_quantization)
self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname)
self.extra_params = sampler_extra_params.get(funcname, [])
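
Note that this last hunk replaces the k-diffusion wrapper unconditionally, so eps-prediction (SD 1.x) checkpoints would no longer be wrapped correctly. A hedged sketch of how the choice could instead follow the checkpoint's parameterization attribute exposed by ldm's DDPM class ("v" is what the SD 2.x v-prediction configs set; this helper is not part of the patch):

    import k_diffusion

    def make_denoiser(sd_model, quantize=False):
        # v-prediction checkpoints (the SD 2.x 768 models) need CompVisVDenoiser;
        # everything else keeps the original eps wrapper.
        if getattr(sd_model, 'parameterization', 'eps') == 'v':
            return k_diffusion.external.CompVisVDenoiser(sd_model, quantize=quantize)
        return k_diffusion.external.CompVisDenoiser(sd_model, quantize=quantize)
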