|
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py |
|
index 17109732..3c21a362 100644 |
|
--- a/modules/script_callbacks.py |
|
+++ b/modules/script_callbacks.py |
|
@@ -32,27 +32,42 @@ class CFGDenoiserParams: |
|
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond): |
|
self.x = x |
|
"""Latent image representation in the process of being denoised""" |
|
- |
|
+ |
|
self.image_cond = image_cond |
|
"""Conditioning image""" |
|
- |
|
+ |
|
self.sigma = sigma |
|
"""Current sigma noise step value""" |
|
- |
|
+ |
|
self.sampling_step = sampling_step |
|
"""Current Sampling step number""" |
|
- |
|
+ |
|
self.total_sampling_steps = total_sampling_steps |
|
"""Total number of sampling steps planned""" |
|
- |
|
+ |
|
self.text_cond = text_cond |
|
""" Encoder hidden states of text conditioning from prompt""" |
|
- |
|
+ |
|
self.text_uncond = text_uncond |
|
""" Encoder hidden states of text conditioning from negative prompt""" |
|
|
|
|
|
class CFGDenoisedParams: |
|
+ def __init__(self, x, sampling_step, total_sampling_steps, inner_model): |
|
+ self.x = x |
|
+ """Latent image representation in the process of being denoised""" |
|
+ |
|
+ self.sampling_step = sampling_step |
|
+ """Current Sampling step number""" |
|
+ |
|
+ self.total_sampling_steps = total_sampling_steps |
|
+ """Total number of sampling steps planned""" |
|
+ |
|
+ self.inner_model = inner_model |
|
+ """Inner model reference used for denoising""" |
|
+ |
|
+ |
|
+class AfterCFGCallbackParams: |
|
def __init__(self, x, sampling_step, total_sampling_steps): |
|
self.x = x |
|
"""Latent image representation in the process of being denoised""" |
|
@@ -87,6 +102,7 @@ callback_map = dict( |
|
callbacks_image_saved=[], |
|
callbacks_cfg_denoiser=[], |
|
callbacks_cfg_denoised=[], |
|
+ callbacks_cfg_after_cfg=[], |
|
callbacks_before_component=[], |
|
callbacks_after_component=[], |
|
callbacks_image_grid=[], |
|
@@ -186,6 +202,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams): |
|
report_exception(c, 'cfg_denoised_callback') |
|
|
|
|
|
+def cfg_after_cfg_callback(params: AfterCFGCallbackParams): |
|
+ for c in callback_map['callbacks_cfg_after_cfg']: |
|
+ try: |
|
+ c.callback(params) |
|
+ except Exception: |
|
+ report_exception(c, 'cfg_after_cfg_callback') |
|
+ |
|
+ |
|
def before_component_callback(component, **kwargs): |
|
for c in callback_map['callbacks_before_component']: |
|
try: |
|
@@ -240,7 +264,7 @@ def add_callback(callbacks, fun): |
|
|
|
callbacks.append(ScriptCallback(filename, fun)) |
|
|
|
- |
|
+ |
|
def remove_current_script_callbacks(): |
|
stack = [x for x in inspect.stack() if x.filename != __file__] |
|
filename = stack[0].filename if len(stack) > 0 else 'unknown file' |
|
@@ -332,6 +356,14 @@ def on_cfg_denoised(callback): |
|
add_callback(callback_map['callbacks_cfg_denoised'], callback) |
|
|
|
|
|
+def on_cfg_after_cfg(callback): |
|
+ """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed. |
|
+ The callback is called with one argument: |
|
+ - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation. |
|
+ """ |
|
+ add_callback(callback_map['callbacks_cfg_after_cfg'], callback) |
|
+ |
|
+ |
|
def on_before_component(callback): |
|
"""register a function to be called before a component is created. |
|
The callback is called with arguments: |
|
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py |
|
index 0fc9f456..61f23ad7 100644 |
|
--- a/modules/sd_samplers_kdiffusion.py |
|
+++ b/modules/sd_samplers_kdiffusion.py |
|
@@ -1,7 +1,6 @@ |
|
from collections import deque |
|
import torch |
|
import inspect |
|
-import einops |
|
import k_diffusion.sampling |
|
from modules import prompt_parser, devices, sd_samplers_common |
|
|
|
@@ -9,6 +8,7 @@ from modules.shared import opts, state |
|
import modules.shared as shared |
|
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback |
|
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback |
|
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback |
|
|
|
samplers_k_diffusion = [ |
|
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), |
|
@@ -87,17 +87,17 @@ class CFGDenoiser(torch.nn.Module): |
|
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) |
|
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) |
|
|
|
- assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" |
|
+ assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" |
|
|
|
batch_size = len(conds_list) |
|
repeats = [len(conds_list[i]) for i in range(batch_size)] |
|
|
|
if shared.sd_model.model.conditioning_key == "crossattn-adm": |
|
image_uncond = torch.zeros_like(image_cond) |
|
- make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm} |
|
+ make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm} |
|
else: |
|
image_uncond = image_cond |
|
- make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]} |
|
+ make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]} |
|
|
|
if not is_edit_model: |
|
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) |
|
@@ -161,7 +161,7 @@ class CFGDenoiser(torch.nn.Module): |
|
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes]) |
|
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be |
|
|
|
- denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps) |
|
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model) |
|
cfg_denoised_callback(denoised_params) |
|
|
|
devices.test_for_nans(x_out, "unet") |
|
@@ -181,6 +181,10 @@ class CFGDenoiser(torch.nn.Module): |
|
if self.mask is not None: |
|
denoised = self.init_latent * self.mask + self.nmask * denoised |
|
|
|
+ after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps) |
|
+ cfg_after_cfg_callback(after_cfg_callback_params) |
|
+ denoised = after_cfg_callback_params.x |
|
+ |
|
self.step += 1 |
|
return denoised |
|
|
|
@@ -317,7 +321,7 @@ class KDiffusionSampler: |
|
|
|
sigma_sched = sigmas[steps - t_enc - 1:] |
|
xi = x + noise * sigma_sched[0] |
|
- |
|
+ |
|
extra_params_kwargs = self.initialize(p) |
|
parameters = inspect.signature(self.func).parameters |
|
|
|
@@ -340,9 +344,9 @@ class KDiffusionSampler: |
|
self.model_wrap_cfg.init_latent = x |
|
self.last_latent = x |
|
extra_args={ |
|
- 'cond': conditioning, |
|
- 'image_cond': image_conditioning, |
|
- 'uncond': unconditional_conditioning, |
|
+ 'cond': conditioning, |
|
+ 'image_cond': image_conditioning, |
|
+ 'uncond': unconditional_conditioning, |
|
'cond_scale': p.cfg_scale, |
|
's_min_uncond': self.s_min_uncond |
|
} |
|
@@ -375,9 +379,9 @@ class KDiffusionSampler: |
|
|
|
self.last_latent = x |
|
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ |
|
- 'cond': conditioning, |
|
- 'image_cond': image_conditioning, |
|
- 'uncond': unconditional_conditioning, |
|
+ 'cond': conditioning, |
|
+ 'image_cond': image_conditioning, |
|
+ 'uncond': unconditional_conditioning, |
|
'cond_scale': p.cfg_scale, |
|
's_min_uncond': self.s_min_uncond |
|
}, disable=False, callback=self.callback_state, **extra_params_kwargs)) |
I guess not, because these two patched files changed from v1.1 to v1.2. The patch file
5ab7f213-83d66a6c.patch
is made for the matching webui and SAG extension versions, but you may try applying the patch on the latest webui. I tried the latest A1111 (v1.2.0, commit b08500ce
) but it was buggy on my device, so I actually switched back to v1.1 and will wait for bug fixes. I also noticed that there are known issues reported by the community, and A1111 is doing hotfixes in the dev branch for this version, so I wonder how long the lifespan of v1.2.0 will be. Therefore, in my opinion, this version is not worth patching. If it becomes usable after bug fixes (probably in v1.2.1? At least not the current commit), I will try patching it at that time.