Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save papuSpartan/f97c4b423352bf26e628910970b9e555 to your computer and use it in GitHub Desktop.
Save papuSpartan/f97c4b423352bf26e628910970b9e555 to your computer and use it in GitHub Desktop.
SAG (Self-Attention Guidance) patch for AUTOMATIC1111 stable-diffusion-webui, dev branch at commit 335428c2c8139dfe07ba096a6defa75036660244
From cef07987cc6e1776d3c3d88691ce94c8fc2c3a0c Mon Sep 17 00:00:00 2001
From: papuSpartan <30642826+papuSpartan@users.noreply.github.com>
Date: Wed, 3 May 2023 20:06:27 -0500
Subject: [PATCH] patch in callbacks for SAG to dev
---
modules/script_callbacks.py | 31 ++++++++++++++++++++++++++++++-
modules/sd_samplers_kdiffusion.py | 9 ++++++++-
2 files changed, 38 insertions(+), 2 deletions(-)
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 17109732..e0436a00 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -53,7 +53,7 @@ class CFGDenoiserParams:
class CFGDenoisedParams:
- def __init__(self, x, sampling_step, total_sampling_steps):
+ def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
self.x = x
"""Latent image representation in the process of being denoised"""
@@ -63,6 +63,22 @@ class CFGDenoisedParams:
self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned"""
+ self.inner_model = inner_model
+ """Inner model reference that is being used for denoising"""
+
+
+
+class AfterCFGCallbackParams:
+    def __init__(self, x, sampling_step, total_sampling_steps):
+        self.x = x
+        """Latent image representation in the process of being denoised"""
+        self.sampling_step = sampling_step
+        """Current sampling step number"""
+        self.total_sampling_steps = total_sampling_steps
+        """Total number of sampling steps planned"""
+        self.output_altered = False
+        """Set to True by a callback that replaced x, so that CFGDenoiser returns the new x as its output"""
+
class UiTrainTabParams:
def __init__(self, txt2img_preview_params):
@@ -87,6 +103,7 @@ callback_map = dict(
callbacks_image_saved=[],
callbacks_cfg_denoiser=[],
callbacks_cfg_denoised=[],
+ callbacks_cfg_after_cfg=[],
callbacks_before_component=[],
callbacks_after_component=[],
callbacks_image_grid=[],
@@ -185,6 +202,12 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
except Exception:
report_exception(c, 'cfg_denoised_callback')
+def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
+    for registered in callback_map['callbacks_cfg_after_cfg']:  # dispatch params to every on_cfg_after_cfg registration, in order
+        try:
+            registered.callback(params)
+        except Exception:  # a failing script is reported and skipped; remaining callbacks still run
+            report_exception(registered, 'cfg_after_cfg_callback')
def before_component_callback(component, **kwargs):
for c in callback_map['callbacks_before_component']:
@@ -331,6 +354,12 @@ def on_cfg_denoised(callback):
"""
add_callback(callback_map['callbacks_cfg_denoised'], callback)
+def on_cfg_after_cfg(callback):
+    """register a function to be called in the k-diffusion cfg_denoiser method after cfg calculations (and any mask blending) are completed.
+    The callback is called with one argument:
+        - params: AfterCFGCallbackParams - the post-CFG latent (params.x) and sampling step state; a callback that replaces params.x must also set params.output_altered = True for the sampler to use it.
+    """
+    add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
def on_before_component(callback):
"""register a function to be called before a component is created.
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index eb98e599..8e63859b 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -9,6 +9,7 @@ from modules.shared import opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
samplers_k_diffusion = [
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
@@ -161,7 +162,7 @@ class CFGDenoiser(torch.nn.Module):
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
- denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
cfg_denoised_callback(denoised_params)
devices.test_for_nans(x_out, "unet")
@@ -181,7 +182,13 @@ class CFGDenoiser(torch.nn.Module):
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
+ after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
+ cfg_after_cfg_callback(after_cfg_callback_params)
+
+ if after_cfg_callback_params.output_altered:
+ denoised = after_cfg_callback_params.x
self.step += 1
+
return denoised
--
2.40.0.windows.1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment