|
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py |
|
index 17109732..3c21a362 100644 |
|
--- a/modules/script_callbacks.py |
|
+++ b/modules/script_callbacks.py |
|
@@ -32,27 +32,42 @@ class CFGDenoiserParams: |
|
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond): |
|
self.x = x |
|
"""Latent image representation in the process of being denoised""" |
|
- |
|
+ |
|
self.image_cond = image_cond |
|
"""Conditioning image""" |
|
- |
|
+ |
|
self.sigma = sigma |
|
"""Current sigma noise step value""" |
|
- |
|
+ |
|
self.sampling_step = sampling_step |
|
"""Current Sampling step number""" |
|
- |
|
+ |
|
self.total_sampling_steps = total_sampling_steps |
|
"""Total number of sampling steps planned""" |
|
- |
|
+ |
|
self.text_cond = text_cond |
|
""" Encoder hidden states of text conditioning from prompt""" |
|
- |
|
+ |
|
self.text_uncond = text_uncond |
|
""" Encoder hidden states of text conditioning from negative prompt""" |
|
|
|
|
|
class CFGDenoisedParams: |
|
+ def __init__(self, x, sampling_step, total_sampling_steps, inner_model): |
|
+ self.x = x |
|
+ """Latent image representation in the process of being denoised""" |
|
+ |
|
+ self.sampling_step = sampling_step |
|
+ """Current Sampling step number""" |
|
+ |
|
+ self.total_sampling_steps = total_sampling_steps |
|
+ """Total number of sampling steps planned""" |
|
+ |
|
+ self.inner_model = inner_model |
|
+ """Inner model reference used for denoising""" |
|
+ |
|
+ |
|
+class AfterCFGCallbackParams: |
|
def __init__(self, x, sampling_step, total_sampling_steps): |
|
self.x = x |
|
"""Latent image representation in the process of being denoised""" |
|
@@ -87,6 +102,7 @@ callback_map = dict( |
|
callbacks_image_saved=[], |
|
callbacks_cfg_denoiser=[], |
|
callbacks_cfg_denoised=[], |
|
+ callbacks_cfg_after_cfg=[], |
|
callbacks_before_component=[], |
|
callbacks_after_component=[], |
|
callbacks_image_grid=[], |
|
@@ -186,6 +202,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams): |
|
report_exception(c, 'cfg_denoised_callback') |
|
|
|
|
|
+def cfg_after_cfg_callback(params: AfterCFGCallbackParams): |
|
+ for c in callback_map['callbacks_cfg_after_cfg']: |
|
+ try: |
|
+ c.callback(params) |
|
+ except Exception: |
|
+ report_exception(c, 'cfg_after_cfg_callback') |
|
+ |
|
+ |
|
def before_component_callback(component, **kwargs): |
|
for c in callback_map['callbacks_before_component']: |
|
try: |
|
@@ -240,7 +264,7 @@ def add_callback(callbacks, fun): |
|
|
|
callbacks.append(ScriptCallback(filename, fun)) |
|
|
|
- |
|
+ |
|
def remove_current_script_callbacks(): |
|
stack = [x for x in inspect.stack() if x.filename != __file__] |
|
filename = stack[0].filename if len(stack) > 0 else 'unknown file' |
|
@@ -332,6 +356,14 @@ def on_cfg_denoised(callback): |
|
add_callback(callback_map['callbacks_cfg_denoised'], callback) |
|
|
|
|
|
+def on_cfg_after_cfg(callback): |
|
+    """register a function to be called in the kdiffusion cfg_denoiser method after cfg calculations are completed.
|
+ The callback is called with one argument: |
|
+ - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation. |
|
+ """ |
|
+ add_callback(callback_map['callbacks_cfg_after_cfg'], callback) |
|
+ |
|
+ |
|
def on_before_component(callback): |
|
"""register a function to be called before a component is created. |
|
The callback is called with arguments: |
|
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py |
|
index 0fc9f456..61f23ad7 100644 |
|
--- a/modules/sd_samplers_kdiffusion.py |
|
+++ b/modules/sd_samplers_kdiffusion.py |
|
@@ -1,7 +1,6 @@ |
|
from collections import deque |
|
import torch |
|
import inspect |
|
-import einops |
|
import k_diffusion.sampling |
|
from modules import prompt_parser, devices, sd_samplers_common |
|
|
|
@@ -9,6 +8,7 @@ from modules.shared import opts, state |
|
import modules.shared as shared |
|
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback |
|
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback |
|
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback |
|
|
|
samplers_k_diffusion = [ |
|
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), |
|
@@ -87,17 +87,17 @@ class CFGDenoiser(torch.nn.Module): |
|
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) |
|
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) |
|
|
|
- assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" |
|
+ assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" |
|
|
|
batch_size = len(conds_list) |
|
repeats = [len(conds_list[i]) for i in range(batch_size)] |
|
|
|
if shared.sd_model.model.conditioning_key == "crossattn-adm": |
|
image_uncond = torch.zeros_like(image_cond) |
|
- make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm} |
|
+ make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm} |
|
else: |
|
image_uncond = image_cond |
|
- make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]} |
|
+ make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]} |
|
|
|
if not is_edit_model: |
|
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) |
|
@@ -161,7 +161,7 @@ class CFGDenoiser(torch.nn.Module): |
|
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes]) |
|
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be |
|
|
|
- denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps) |
|
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model) |
|
cfg_denoised_callback(denoised_params) |
|
|
|
devices.test_for_nans(x_out, "unet") |
|
@@ -181,6 +181,10 @@ class CFGDenoiser(torch.nn.Module): |
|
if self.mask is not None: |
|
denoised = self.init_latent * self.mask + self.nmask * denoised |
|
|
|
+ after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps) |
|
+ cfg_after_cfg_callback(after_cfg_callback_params) |
|
+ denoised = after_cfg_callback_params.x |
|
+ |
|
self.step += 1 |
|
return denoised |
|
|
|
@@ -317,7 +321,7 @@ class KDiffusionSampler: |
|
|
|
sigma_sched = sigmas[steps - t_enc - 1:] |
|
xi = x + noise * sigma_sched[0] |
|
- |
|
+ |
|
extra_params_kwargs = self.initialize(p) |
|
parameters = inspect.signature(self.func).parameters |
|
|
|
@@ -340,9 +344,9 @@ class KDiffusionSampler: |
|
self.model_wrap_cfg.init_latent = x |
|
self.last_latent = x |
|
extra_args={ |
|
- 'cond': conditioning, |
|
- 'image_cond': image_conditioning, |
|
- 'uncond': unconditional_conditioning, |
|
+ 'cond': conditioning, |
|
+ 'image_cond': image_conditioning, |
|
+ 'uncond': unconditional_conditioning, |
|
'cond_scale': p.cfg_scale, |
|
's_min_uncond': self.s_min_uncond |
|
} |
|
@@ -375,9 +379,9 @@ class KDiffusionSampler: |
|
|
|
self.last_latent = x |
|
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ |
|
- 'cond': conditioning, |
|
- 'image_cond': image_conditioning, |
|
- 'uncond': unconditional_conditioning, |
|
+ 'cond': conditioning, |
|
+ 'image_cond': image_conditioning, |
|
+ 'uncond': unconditional_conditioning, |
|
'cond_scale': p.cfg_scale, |
|
's_min_uncond': self.s_min_uncond |
|
}, disable=False, callback=self.callback_state, **extra_params_kwargs)) |
Yeah, patches only work against the specific version of the files they were made for.
Now webui v1.2.1 has been released. I have noticed that SAG has been patched into the
dev
branch on commit 307800143, so I guess we will no longer need to patch it in the next master release of the A1111 webui, probably v1.2.2 or v1.3. Given the above information, I will skip using v1.2.1 on my device and just wait for the next release, so I am not motivated to patch v1.2.1. However, I still made a new SAG patch (
89f9faa63-83d66a6c (v1.2.1)
) by diffing the dev commit against the current v1.2.1 master commit; you can try it if you want. Please note that this v1.2.1 patch is NOT tested, so use it at your own risk.