|
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py |
|
index 17109732..3c21a362 100644 |
|
--- a/modules/script_callbacks.py |
|
+++ b/modules/script_callbacks.py |
|
@@ -32,27 +32,42 @@ class CFGDenoiserParams: |
|
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond): |
|
self.x = x |
|
"""Latent image representation in the process of being denoised""" |
|
- |
|
+ |
|
self.image_cond = image_cond |
|
"""Conditioning image""" |
|
- |
|
+ |
|
self.sigma = sigma |
|
"""Current sigma noise step value""" |
|
- |
|
+ |
|
self.sampling_step = sampling_step |
|
"""Current Sampling step number""" |
|
- |
|
+ |
|
self.total_sampling_steps = total_sampling_steps |
|
"""Total number of sampling steps planned""" |
|
- |
|
+ |
|
self.text_cond = text_cond |
|
""" Encoder hidden states of text conditioning from prompt""" |
|
- |
|
+ |
|
self.text_uncond = text_uncond |
|
""" Encoder hidden states of text conditioning from negative prompt""" |
|
|
|
|
|
class CFGDenoisedParams: |
|
+ def __init__(self, x, sampling_step, total_sampling_steps, inner_model): |
|
+ self.x = x |
|
+ """Latent image representation in the process of being denoised""" |
|
+ |
|
+ self.sampling_step = sampling_step |
|
+ """Current Sampling step number""" |
|
+ |
|
+ self.total_sampling_steps = total_sampling_steps |
|
+ """Total number of sampling steps planned""" |
|
+ |
|
+ self.inner_model = inner_model |
|
+ """Inner model reference used for denoising""" |
|
+ |
|
+ |
|
+class AfterCFGCallbackParams: |
|
def __init__(self, x, sampling_step, total_sampling_steps): |
|
self.x = x |
|
"""Latent image representation in the process of being denoised""" |
|
@@ -87,6 +102,7 @@ callback_map = dict( |
|
callbacks_image_saved=[], |
|
callbacks_cfg_denoiser=[], |
|
callbacks_cfg_denoised=[], |
|
+ callbacks_cfg_after_cfg=[], |
|
callbacks_before_component=[], |
|
callbacks_after_component=[], |
|
callbacks_image_grid=[], |
|
@@ -186,6 +202,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams): |
|
report_exception(c, 'cfg_denoised_callback') |
|
|
|
|
|
+def cfg_after_cfg_callback(params: AfterCFGCallbackParams): |
|
+ for c in callback_map['callbacks_cfg_after_cfg']: |
|
+ try: |
|
+ c.callback(params) |
|
+ except Exception: |
|
+ report_exception(c, 'cfg_after_cfg_callback') |
|
+ |
|
+ |
|
def before_component_callback(component, **kwargs): |
|
for c in callback_map['callbacks_before_component']: |
|
try: |
|
@@ -240,7 +264,7 @@ def add_callback(callbacks, fun): |
|
|
|
callbacks.append(ScriptCallback(filename, fun)) |
|
|
|
- |
|
+ |
|
def remove_current_script_callbacks(): |
|
stack = [x for x in inspect.stack() if x.filename != __file__] |
|
filename = stack[0].filename if len(stack) > 0 else 'unknown file' |
|
@@ -332,6 +356,14 @@ def on_cfg_denoised(callback): |
|
add_callback(callback_map['callbacks_cfg_denoised'], callback) |
|
|
|
|
|
+def on_cfg_after_cfg(callback): |
|
+ """register a function to be called in the kdiffusion cfg_denoiser method after cfg calculations are completed.
|
+ The callback is called with one argument: |
|
+ - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation. |
|
+ """ |
|
+ add_callback(callback_map['callbacks_cfg_after_cfg'], callback) |
|
+ |
|
+ |
|
def on_before_component(callback): |
|
"""register a function to be called before a component is created. |
|
The callback is called with arguments: |
|
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py |
|
index 0fc9f456..61f23ad7 100644 |
|
--- a/modules/sd_samplers_kdiffusion.py |
|
+++ b/modules/sd_samplers_kdiffusion.py |
|
@@ -1,7 +1,6 @@ |
|
from collections import deque |
|
import torch |
|
import inspect |
|
-import einops |
|
import k_diffusion.sampling |
|
from modules import prompt_parser, devices, sd_samplers_common |
|
|
|
@@ -9,6 +8,7 @@ from modules.shared import opts, state |
|
import modules.shared as shared |
|
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback |
|
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback |
|
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback |
|
|
|
samplers_k_diffusion = [ |
|
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), |
|
@@ -87,17 +87,17 @@ class CFGDenoiser(torch.nn.Module): |
|
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) |
|
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) |
|
|
|
- assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" |
|
+ assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" |
|
|
|
batch_size = len(conds_list) |
|
repeats = [len(conds_list[i]) for i in range(batch_size)] |
|
|
|
if shared.sd_model.model.conditioning_key == "crossattn-adm": |
|
image_uncond = torch.zeros_like(image_cond) |
|
- make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm} |
|
+ make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm} |
|
else: |
|
image_uncond = image_cond |
|
- make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]} |
|
+ make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]} |
|
|
|
if not is_edit_model: |
|
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) |
|
@@ -161,7 +161,7 @@ class CFGDenoiser(torch.nn.Module): |
|
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes]) |
|
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be |
|
|
|
- denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps) |
|
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model) |
|
cfg_denoised_callback(denoised_params) |
|
|
|
devices.test_for_nans(x_out, "unet") |
|
@@ -181,6 +181,10 @@ class CFGDenoiser(torch.nn.Module): |
|
if self.mask is not None: |
|
denoised = self.init_latent * self.mask + self.nmask * denoised |
|
|
|
+ after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps) |
|
+ cfg_after_cfg_callback(after_cfg_callback_params) |
|
+ denoised = after_cfg_callback_params.x |
|
+ |
|
self.step += 1 |
|
return denoised |
|
|
|
@@ -317,7 +321,7 @@ class KDiffusionSampler: |
|
|
|
sigma_sched = sigmas[steps - t_enc - 1:] |
|
xi = x + noise * sigma_sched[0] |
|
- |
|
+ |
|
extra_params_kwargs = self.initialize(p) |
|
parameters = inspect.signature(self.func).parameters |
|
|
|
@@ -340,9 +344,9 @@ class KDiffusionSampler: |
|
self.model_wrap_cfg.init_latent = x |
|
self.last_latent = x |
|
extra_args={ |
|
- 'cond': conditioning, |
|
- 'image_cond': image_conditioning, |
|
- 'uncond': unconditional_conditioning, |
|
+ 'cond': conditioning, |
|
+ 'image_cond': image_conditioning, |
|
+ 'uncond': unconditional_conditioning, |
|
'cond_scale': p.cfg_scale, |
|
's_min_uncond': self.s_min_uncond |
|
} |
|
@@ -375,9 +379,9 @@ class KDiffusionSampler: |
|
|
|
self.last_latent = x |
|
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ |
|
- 'cond': conditioning, |
|
- 'image_cond': image_conditioning, |
|
- 'uncond': unconditional_conditioning, |
|
+ 'cond': conditioning, |
|
+ 'image_cond': image_conditioning, |
|
+ 'uncond': unconditional_conditioning, |
|
'cond_scale': p.cfg_scale, |
|
's_min_uncond': self.s_min_uncond |
|
}, disable=False, callback=self.callback_state, **extra_params_kwargs)) |
is this working with the latest version of A1111?