-
-
Save vt-idiot/4b7ac2f899f4b2cfbaae1e59536248e1 to your computer and use it in GitHub Desktop.
Self Attention Guidance Patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py | |
index 17109732..cd2ef1e9 100644 | |
--- a/modules/script_callbacks.py | |
+++ b/modules/script_callbacks.py | |
@@ -53,7 +53,7 @@ class CFGDenoiserParams: | |
class CFGDenoisedParams: | |
- def __init__(self, x, sampling_step, total_sampling_steps): | |
+ def __init__(self, x, sampling_step, total_sampling_steps, inner_model): | |
self.x = x | |
"""Latent image representation in the process of being denoised""" | |
@@ -63,6 +63,24 @@ class CFGDenoisedParams: | |
self.total_sampling_steps = total_sampling_steps | |
"""Total number of sampling steps planned""" | |
+ self.inner_model = inner_model | |
+ """Inner model reference that is being used for denoising""" | |
+ | |
+class AfterCFGCallbackParams: | |
+ """Parameters passed to callbacks registered with on_cfg_after_cfg.""" | |
+ | |
+ def __init__(self, x, sampling_step, total_sampling_steps): | |
+ self.x = x | |
+ """Latent image representation in the process of being denoised""" | |
+ | |
+ self.sampling_step = sampling_step | |
+ """Current sampling step number""" | |
+ | |
+ self.total_sampling_steps = total_sampling_steps | |
+ """Total number of sampling steps planned""" | |
+ | |
+ self.output_altered = False | |
+ """A flag for CFGDenoiser that indicates whether the output has been altered by the callback""" | |
class UiTrainTabParams: | |
def __init__(self, txt2img_preview_params): | |
@@ -87,6 +100,7 @@ callback_map = dict( | |
callbacks_image_saved=[], | |
callbacks_cfg_denoiser=[], | |
callbacks_cfg_denoised=[], | |
+ callbacks_cfg_after_cfg=[], | |
callbacks_before_component=[], | |
callbacks_after_component=[], | |
callbacks_image_grid=[], | |
@@ -185,6 +199,16 @@ def cfg_denoised_callback(params: CFGDenoisedParams): | |
except Exception: | |
report_exception(c, 'cfg_denoised_callback') | |
+def cfg_after_cfg_callback(params: AfterCFGCallbackParams): | |
+ """Invoke each callback registered via on_cfg_after_cfg with params. | |
+ | |
+ A callback that raises is reported through report_exception; the remaining callbacks still run. | |
+ """ | |
+ for registered in callback_map['callbacks_cfg_after_cfg']: | |
+ try: | |
+ registered.callback(params) | |
+ except Exception: | |
+ report_exception(registered, 'cfg_after_cfg_callback') | |
def before_component_callback(component, **kwargs): | |
for c in callback_map['callbacks_before_component']: | |
@@ -331,13 +351,18 @@ def on_cfg_denoised(callback): | |
""" | |
add_callback(callback_map['callbacks_cfg_denoised'], callback) | |
+def on_cfg_after_cfg(callback): | |
+ """register a function to be called in the k-diffusion CFGDenoiser after the cfg calculations have completed. | |
+ The callback is called with one argument: | |
+ - params: AfterCFGCallbackParams - the denoised latent (params.x) and sampling state details; to replace the output, assign params.x and set params.output_altered = True, otherwise the change is ignored. | |
+ """ | |
+ add_callback(callback_map['callbacks_cfg_after_cfg'], callback) | |
def on_before_component(callback): | |
"""register a function to be called before a component is created. | |
The callback is called with arguments: | |
- component - gradio component that is about to be created. | |
- **kwargs - args to gradio.components.IOComponent.__init__ function | |
- | |
Use elem_id/label fields of kwargs to figure out which component it is. | |
This can be useful to inject your own components somewhere in the middle of vanilla UI. | |
""" | |
@@ -369,11 +394,9 @@ def on_infotext_pasted(callback): | |
def on_script_unloaded(callback): | |
"""register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that | |
the script did should be reverted here""" | |
- | |
add_callback(callback_map['callbacks_script_unloaded'], callback) | |
def on_before_ui(callback): | |
"""register a function to be called before the UI is created.""" | |
- | |
add_callback(callback_map['callbacks_before_ui'], callback) | |
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py | |
index 0fc9f456..f90538a9 100644 | |
--- a/modules/sd_samplers_kdiffusion.py | |
+++ b/modules/sd_samplers_kdiffusion.py | |
@@ -9,6 +9,7 @@ from modules.shared import opts, state | |
import modules.shared as shared | |
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback | |
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback | |
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback | |
samplers_k_diffusion = [ | |
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), | |
@@ -161,7 +162,7 @@ class CFGDenoiser(torch.nn.Module): | |
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes]) | |
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be | |
- denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps) | |
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model) | |
cfg_denoised_callback(denoised_params) | |
devices.test_for_nans(x_out, "unet") | |
@@ -181,6 +182,12 @@ class CFGDenoiser(torch.nn.Module): | |
if self.mask is not None: | |
denoised = self.init_latent * self.mask + self.nmask * denoised | |
+ after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps) | |
+ cfg_after_cfg_callback(after_cfg_callback_params) | |
+ | |
+ if after_cfg_callback_params.output_altered: | |
+ denoised = after_cfg_callback_params.x | |
+ | |
self.step += 1 | |
return denoised | |
@@ -383,4 +390,3 @@ class KDiffusionSampler: | |
}, disable=False, callback=self.callback_state, **extra_params_kwargs)) | |
return samples | |
- |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment