Skip to content

Instantly share code, notes, and snippets.

@vt-idiot
Created May 19, 2023 13:29
Show Gist options
  • Save vt-idiot/4b7ac2f899f4b2cfbaae1e59536248e1 to your computer and use it in GitHub Desktop.
Save vt-idiot/4b7ac2f899f4b2cfbaae1e59536248e1 to your computer and use it in GitHub Desktop.
Self Attention Guidance Patch
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 17109732..cd2ef1e9 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -53,7 +53,7 @@ class CFGDenoiserParams:
class CFGDenoisedParams:
- def __init__(self, x, sampling_step, total_sampling_steps):
+ def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
self.x = x
"""Latent image representation in the process of being denoised"""
@@ -63,6 +63,19 @@ class CFGDenoisedParams:
self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned"""
+ self.inner_model = inner_model
+ """Inner model reference that is being used for denoising"""
+
+class AfterCFGCallbackParams:  # parameter object handed to callbacks registered via on_cfg_after_cfg
+ def __init__(self, x, sampling_step, total_sampling_steps):
+ self.x = x
+ """Latent image representation in the process of being denoised"""
+ self.sampling_step = sampling_step  # sampling step supplied by the caller (CFGDenoiser passes state.sampling_step)
+ self.total_sampling_steps = total_sampling_steps
+ """Total number of sampling steps planned"""
+
+ self.output_altered = False
+ """A flag for CFGDenoiser that indicates whether the output has been altered by the callback"""
class UiTrainTabParams:
def __init__(self, txt2img_preview_params):
@@ -87,6 +100,7 @@ callback_map = dict(
callbacks_image_saved=[],
callbacks_cfg_denoiser=[],
callbacks_cfg_denoised=[],
+ callbacks_cfg_after_cfg=[],
callbacks_before_component=[],
callbacks_after_component=[],
callbacks_image_grid=[],
@@ -185,6 +199,12 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
except Exception:
report_exception(c, 'cfg_denoised_callback')
+def cfg_after_cfg_callback(params: AfterCFGCallbackParams):  # invoke every callback registered under 'callbacks_cfg_after_cfg' with the same params object
+ for c in callback_map['callbacks_cfg_after_cfg']:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'cfg_after_cfg_callback')  # hand the failure to report_exception and keep going, so one bad callback does not stop the rest
def before_component_callback(component, **kwargs):
for c in callback_map['callbacks_before_component']:
@@ -331,13 +351,18 @@ def on_cfg_denoised(callback):
"""
add_callback(callback_map['callbacks_cfg_denoised'], callback)
+def on_cfg_after_cfg(callback):
+ """register a function to be called in the k-diffusion CFGDenoiser after the cfg calculation has completed.
+ The callback is called with one argument:
+ - params: AfterCFGCallbackParams - the denoised latent (x) and sampling state details; set output_altered to True if x was replaced.
+ """
+ add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
def on_before_component(callback):
"""register a function to be called before a component is created.
The callback is called with arguments:
- component - gradio component that is about to be created.
- **kwargs - args to gradio.components.IOComponent.__init__ function
-
Use elem_id/label fields of kwargs to figure out which component it is.
This can be useful to inject your own components somewhere in the middle of vanilla UI.
"""
@@ -369,11 +394,9 @@ def on_infotext_pasted(callback):
def on_script_unloaded(callback):
"""register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
the script did should be reverted here"""
-
add_callback(callback_map['callbacks_script_unloaded'], callback)
def on_before_ui(callback):
"""register a function to be called before the UI is created."""
-
- add_callback(callback_map['callbacks_before_ui'], callback)
+ add_callback(callback_map['callbacks_before_ui'], callback)
\ No newline at end of file
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 0fc9f456..f90538a9 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -9,6 +9,7 @@ from modules.shared import opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
samplers_k_diffusion = [
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
@@ -161,7 +162,7 @@ class CFGDenoiser(torch.nn.Module):
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
- denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
cfg_denoised_callback(denoised_params)
devices.test_for_nans(x_out, "unet")
@@ -181,6 +182,12 @@ class CFGDenoiser(torch.nn.Module):
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
+ after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
+ cfg_after_cfg_callback(after_cfg_callback_params)
+
+ if after_cfg_callback_params.output_altered:
+ denoised = after_cfg_callback_params.x
+
self.step += 1
return denoised
@@ -383,4 +390,3 @@ class KDiffusionSampler:
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples
-
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment