Skip to content

Instantly share code, notes, and snippets.

@torridgristle
Created October 12, 2022 15:39
Show Gist options
  • Save torridgristle/ca942c2e1c31ac31111d31931ed1dfbb to your computer and use it in GitHub Desktop.
Save torridgristle/ca942c2e1c31ac31111d31931ed1dfbb to your computer and use it in GitHub Desktop.
Stable Diffusion CFGDenoiser with slew limiting and frequency splitting for detail preservation as an option.
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
class CFGDenoiserSlew(nn.Module):
    """Classifier-free-guidance denoiser that limits the change between steps.

    Each call evaluates the wrapped model on an unconditional/conditional
    pair, combines the two outputs with ``cond_scale``, and then clamps how
    far the combined output may move away from the result stored on the
    previous call (slew limiting).

    Args:
        model: Wrapped denoiser, invoked as ``model(x, sigma, cond=cond)``.
        limit: Clamp bound for the per-element change between steps.
            Author's note: 0.4-0.8 seem good; 1.6 and 3.2 have very little
            difference and might represent the upper bound of values.
        blur: Size handed to ``TF.gaussian_blur`` as ``kernel_size`` for the
            frequency split: the low frequencies come from the limited
            output and the high frequencies from the unlimited output, in an
            attempt to preserve detail and color. Values ``<= 1`` disable the
            split. Even values are raised to the next odd number, because
            ``gaussian_blur`` only accepts odd kernel sizes.
        last_step_is_blur: If truthy, the next step is compared against the
            frequency-split output; otherwise against the plain limited
            output. The author notes the former can look nicer.
    """

    def __init__(self, model, limit=0.2, blur=5, last_step_is_blur=True):
        super().__init__()
        self.inner_model = model
        # Largest sigma seen on the previous call (as a float); used to
        # detect when the sampling cycle restarts for a new image.
        self.last_sigma = 0.0
        # Result kept from the previous call, or None at the start of a run.
        self.last_step = None
        self.limit = limit
        self.blur = blur
        self.last_step_is_blur = last_step_is_blur

    def forward(self, x, sigma, uncond, cond, cond_scale):
        """Return the guided model output, slew-limited against the last step.

        Args:
            x: Input tensor for the wrapped model (duplicated along dim 0).
            sigma: Tensor of noise levels; it is concatenated with itself,
                so it must be at least 1-D. Its maximum is used to detect a
                restart.
            uncond: Unconditional conditioning.
            cond: Conditional conditioning.
            cond_scale: Guidance scale applied to ``cond - uncond``.

        Returns:
            The guided output on the first step of a run; on later steps,
            the limited (and optionally frequency-split) output.
        """
        # Evaluate both halves of the guidance pair in one batched call.
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        cond_in = torch.cat([uncond, cond])
        uncond_out, cond_out = self.inner_model(
            x_in, sigma_in, cond=cond_in
        ).chunk(2)
        result_clean = uncond_out + (cond_out - uncond_out) * cond_scale

        # Reduce sigma to a Python float so the comparison also works when
        # sigma holds more than one element (a multi-element tensor has no
        # unambiguous truth value).
        sigma_value = float(sigma.max())
        if sigma_value > self.last_sigma:
            # Sigma went up, so a new sampling run started: drop the history.
            self.last_step = None
        self.last_sigma = sigma_value

        previous = self.last_step
        if previous is not None and previous.shape != result_clean.shape:
            # A stored step of a different shape cannot belong to this run.
            previous = None

        if previous is None:
            # First step of a run: nothing to limit against yet.
            self.last_step = result_clean
            return result_clean

        change = result_clean - previous
        result = change.clamp(-self.limit, self.limit) + previous

        if not self.last_step_is_blur:
            self.last_step = result  # Pre-blur

        if self.blur > 1:
            kernel_size = int(self.blur)
            if kernel_size % 2 == 0:
                kernel_size += 1  # gaussian_blur rejects even kernel sizes
            # Low frequencies from the limited output, high frequencies
            # from the unlimited output.
            low = TF.gaussian_blur(result, kernel_size)
            high = result_clean - TF.gaussian_blur(result_clean, kernel_size)
            result = low + high

        if self.last_step_is_blur:
            self.last_step = result  # Post-blur

        return result
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment