sd-webui-SAG-patch

Patch for webui extension sd-webui-SAG on A1111

This is an unofficial gist for patching A1111's webui so that the sd_webui_SAG extension works.

Guide

To apply a patch to your webui, please:

  • First check your current A1111 webui version:

    $ git rev-parse HEAD
    <webui commit hash>
    

    And the SAG extension version:

    $ cd extensions/sd_webui_SAG && git rev-parse HEAD
    <SAG extension commit hash>
    
  • Download the patch to your webui folder.

    For example, if the first 8 characters of your A1111 webui commit hash are AAAAAAAA and of your SAG extension commit are BBBBBBBB (both are placeholders), find the following patch file in this gist and download it into your webui folder:

    sd-webui-SAG-patch-AAAAAAAA-BBBBBBBB.patch

    If your two commit hashes don't match those in the patch filename, the patch may not apply cleanly; you can dry-run it first, as shown below.

  • In your webui folder, apply the patch by running:

    git apply <patch file name>  
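
    Two optional sanity checks (my additions, not part of the original guide): you can print just the first 8 characters of a commit hash when matching against the patch filename, and you can dry-run the patch, which reports problems without modifying any files:

        git rev-parse --short=8 HEAD          # short hash for filename matching
        git apply --check <patch file name>   # dry run; fails if the patch won't apply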
    

Patch

89f9faa63-83d66a6c (v1.2.1)

This patch is NOT tested. Use it at your own risk. If the patch doesn't work properly, you can use git restore to recover the patched files.
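
For example, since the diff below only touches two modules, a recovery sketch (assuming you have no other local changes to these files) could be:

    git restore modules/script_callbacks.py modules/sd_samplers_kdiffusion.py

Alternatively, git apply -R <patch file name> reverses a previously applied patch.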

Date: 2023-05-14

This patch, 89f9faa63-83d66a6c (click to jump), is made from the diff between the dev branch's current commit f6a2a98f and the latest master release v1.2.1. I used dev commit f6a2a98f because SAG support was merged into dev a few commits earlier (commit 307800143), so I diffed directly between the two versions.
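
The exact command used isn't recorded in this gist, but a patch like this one can be produced roughly as follows (commit hashes from the paragraph above; the file list matches the diff below):

    git diff 89f9faa63 f6a2a98f -- modules/script_callbacks.py modules/sd_samplers_kdiffusion.py > sd-webui-SAG-patch-89f9faa63-83d66a6c.patch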

The patch is expected to work for:

  • A1111/stable-diffusion-webui's v1.2.1 release, i.e. commit 89f9faa63 on the master branch
  • SAG's extension at commit 83d66a6c

To apply the patch:

git apply sd-webui-SAG-patch-89f9faa63-83d66a6c.patch
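
If it applies without errors, git status should list the two modules from the diff below as modified:

    $ git status --short
     M modules/script_callbacks.py
     M modules/sd_samplers_kdiffusion.py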

I expect the next webui release to include this SAG support natively, so we won't need to maintain a patch file manually anymore.

5ab7f213-83d66a6c (v1.1)

Credit: this patch is fully based on user @papuSpartan's patch for the dev branch. The original discussion is in SAG issue #13.

Date: 2023-05-07

This patch sd-webui-SAG-patch-5ab7f213-83d66a6c (click to jump) is made for:

  • A1111/stable-diffusion-webui's v1.1 release; specifically, commit 5ab7f213 on the master branch
  • SAG's extension at commit 83d66a6c

To apply the patch:

git apply sd-webui-SAG-patch-5ab7f213-83d66a6c.patch
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 17109732..6e767f4d 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -53,7 +53,7 @@ class CFGDenoiserParams:
class CFGDenoisedParams:
- def __init__(self, x, sampling_step, total_sampling_steps):
+ def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
self.x = x
"""Latent image representation in the process of being denoised"""
@@ -63,6 +63,23 @@ class CFGDenoisedParams:
self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned"""
+ self.inner_model = inner_model
+ """Inner model reference that is being used for denoising"""
+
+
+
+class AfterCFGCallbackParams:
+ def __init__(self, x, sampling_step, total_sampling_steps):
+ self.x = x
+ """Latent image representation in the process of being denoised"""
+
+ self.total_sampling_steps = total_sampling_steps
+ """Total number of sampling steps planned"""
+
+ self.output_altered = False
+ """A flag for CFGDenoiser that indicates whether the output has been altered by the callback"""
+
+
class UiTrainTabParams:
def __init__(self, txt2img_preview_params):
@@ -87,6 +104,7 @@ callback_map = dict(
callbacks_image_saved=[],
callbacks_cfg_denoiser=[],
callbacks_cfg_denoised=[],
+ callbacks_cfg_after_cfg=[],
callbacks_before_component=[],
callbacks_after_component=[],
callbacks_image_grid=[],
@@ -186,6 +204,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
report_exception(c, 'cfg_denoised_callback')
+def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
+ for c in callback_map['callbacks_cfg_after_cfg']:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'cfg_after_cfg_callback')
+
+
def before_component_callback(component, **kwargs):
for c in callback_map['callbacks_before_component']:
try:
@@ -332,6 +358,14 @@ def on_cfg_denoised(callback):
add_callback(callback_map['callbacks_cfg_denoised'], callback)
+def on_cfg_after_cfg(callback):
+ """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations has completed.
+ The callback is called with one argument:
+ - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.
+ """
+ add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
+
+
def on_before_component(callback):
"""register a function to be called before a component is created.
The callback is called with arguments:
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index eb98e599..eefee346 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -9,6 +9,7 @@ from modules.shared import opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
samplers_k_diffusion = [
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
@@ -161,7 +162,7 @@ class CFGDenoiser(torch.nn.Module):
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
- denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
cfg_denoised_callback(denoised_params)
devices.test_for_nans(x_out, "unet")
@@ -181,6 +182,12 @@ class CFGDenoiser(torch.nn.Module):
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
+ after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
+ cfg_after_cfg_callback(after_cfg_callback_params)
+
+ if after_cfg_callback_params.output_altered:
+ denoised = after_cfg_callback_params.x
+
self.step += 1
return denoised
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 17109732..3c21a362 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -32,27 +32,42 @@ class CFGDenoiserParams:
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
self.x = x
"""Latent image representation in the process of being denoised"""
-
+
self.image_cond = image_cond
"""Conditioning image"""
-
+
self.sigma = sigma
"""Current sigma noise step value"""
-
+
self.sampling_step = sampling_step
"""Current Sampling step number"""
-
+
self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned"""
-
+
self.text_cond = text_cond
""" Encoder hidden states of text conditioning from prompt"""
-
+
self.text_uncond = text_uncond
""" Encoder hidden states of text conditioning from negative prompt"""
class CFGDenoisedParams:
+ def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
+ self.x = x
+ """Latent image representation in the process of being denoised"""
+
+ self.sampling_step = sampling_step
+ """Current Sampling step number"""
+
+ self.total_sampling_steps = total_sampling_steps
+ """Total number of sampling steps planned"""
+
+ self.inner_model = inner_model
+ """Inner model reference used for denoising"""
+
+
+class AfterCFGCallbackParams:
def __init__(self, x, sampling_step, total_sampling_steps):
self.x = x
"""Latent image representation in the process of being denoised"""
@@ -87,6 +102,7 @@ callback_map = dict(
callbacks_image_saved=[],
callbacks_cfg_denoiser=[],
callbacks_cfg_denoised=[],
+ callbacks_cfg_after_cfg=[],
callbacks_before_component=[],
callbacks_after_component=[],
callbacks_image_grid=[],
@@ -186,6 +202,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
report_exception(c, 'cfg_denoised_callback')
+def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
+ for c in callback_map['callbacks_cfg_after_cfg']:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'cfg_after_cfg_callback')
+
+
def before_component_callback(component, **kwargs):
for c in callback_map['callbacks_before_component']:
try:
@@ -240,7 +264,7 @@ def add_callback(callbacks, fun):
callbacks.append(ScriptCallback(filename, fun))
-
+
def remove_current_script_callbacks():
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -332,6 +356,14 @@ def on_cfg_denoised(callback):
add_callback(callback_map['callbacks_cfg_denoised'], callback)
+def on_cfg_after_cfg(callback):
+ """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
+ The callback is called with one argument:
+ - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
+ """
+ add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
+
+
def on_before_component(callback):
"""register a function to be called before a component is created.
The callback is called with arguments:
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 0fc9f456..61f23ad7 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -1,7 +1,6 @@
from collections import deque
import torch
import inspect
-import einops
import k_diffusion.sampling
from modules import prompt_parser, devices, sd_samplers_common
@@ -9,6 +8,7 @@ from modules.shared import opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
samplers_k_diffusion = [
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
@@ -87,17 +87,17 @@ class CFGDenoiser(torch.nn.Module):
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
- assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+ assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
if shared.sd_model.model.conditioning_key == "crossattn-adm":
image_uncond = torch.zeros_like(image_cond)
- make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
+ make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
else:
image_uncond = image_cond
- make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
+ make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
if not is_edit_model:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
@@ -161,7 +161,7 @@ class CFGDenoiser(torch.nn.Module):
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
- denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
cfg_denoised_callback(denoised_params)
devices.test_for_nans(x_out, "unet")
@@ -181,6 +181,10 @@ class CFGDenoiser(torch.nn.Module):
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
+ after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
+ cfg_after_cfg_callback(after_cfg_callback_params)
+ denoised = after_cfg_callback_params.x
+
self.step += 1
return denoised
@@ -317,7 +321,7 @@ class KDiffusionSampler:
sigma_sched = sigmas[steps - t_enc - 1:]
xi = x + noise * sigma_sched[0]
-
+
extra_params_kwargs = self.initialize(p)
parameters = inspect.signature(self.func).parameters
@@ -340,9 +344,9 @@ class KDiffusionSampler:
self.model_wrap_cfg.init_latent = x
self.last_latent = x
extra_args={
- 'cond': conditioning,
- 'image_cond': image_conditioning,
- 'uncond': unconditional_conditioning,
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond
}
@@ -375,9 +379,9 @@ class KDiffusionSampler:
self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
- 'cond': conditioning,
- 'image_cond': image_conditioning,
- 'uncond': unconditional_conditioning,
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
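
For extension authors, here is a minimal sketch of how the new hook added by these patches could be consumed. The callback name my_after_cfg is hypothetical; on_cfg_after_cfg, AfterCFGCallbackParams, and the output_altered flag all come from the diffs above (note that the second patch applies params.x unconditionally, so output_altered only matters for the first one):

    # Minimal sketch of an extension script using the post-CFG hook (patched webui assumed).
    from modules import script_callbacks

    def my_after_cfg(params):
        # params is an AfterCFGCallbackParams instance; params.x is the denoised
        # latent after CFG. Post-process it here before sampling continues.
        params.x = params.x  # placeholder: real code would modify the latent
        params.output_altered = True  # tells CFGDenoiser to pick up the modified params.x

    script_callbacks.on_cfg_after_cfg(my_after_cfg)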
@alenknight

is this working with the latest version of A1111?

@wklchris (Author) commented May 14, 2023

> is this working with the latest version of A1111?

I guess not, because these two patched files changed from v1.1 to v1.2. The patch file 5ab7f213-83d66a6c.patch is made for the matching webui and SAG extension versions, but you may try applying it on the latest webui. I tried the latest A1111 (v1.2.0, commit b08500ce) but it was buggy on my device, so I switched back to v1.1 to wait for bug fixes.

I also noticed that there are known issues reported by the community and A1111 is hotfixing this version in the dev branch, so I wonder how long the lifespan of v1.2.0 will be. In my opinion this version is therefore not worth patching. If it becomes usable after bug fixes (probably in v1.2.1? At least not the current commit), I will try patching it at that time.

@alenknight

I'm a bit confused by this... who has an A1111 commit that's all 'A's? AAAAAAAA, and a SAG commit of BBBBBBBB? It's usually something like 5abcd###

@alenknight

Ah, OK. Yeah, I'm getting an error again: error: corrupt patch at line 109

@wklchris (Author)

> Ah, OK. Yeah, I'm getting an error again: error: corrupt patch at line 109

Yeah, patches only work for the specific versions of the files they were made from.

Now webui v1.2.1 has been released. I have noticed that SAG support has been merged into the dev branch at commit 307800143, so I guess we will no longer need to patch it in the next master release of A1111 webui, probably v1.2.2 or v1.3.

Given the above, I will skip v1.2.1 on my device and just wait for the next release, so I am not motivated to patch v1.2.1 myself. However, I still made a new SAG patch (89f9faa63-83d66a6c (v1.2.1)) by diffing the dev commit against the current v1.2.1 master commit; you can try it if you want. Please note that this v1.2.1 patch is NOT tested, so use it at your own risk.

@alenknight

understood. thanks again...
