Skip to content

Instantly share code, notes, and snippets.

@laksjdjf
Created November 6, 2023 15:20
Show Gist options
  • Save laksjdjf/126d5bf6092d943a573197a49db30859 to your computer and use it in GitHub Desktop.
Save laksjdjf/126d5bf6092d943a573197a49db30859 to your computer and use it in GitHub Desktop.
# ref:https://github.com/v0xie/sd-webui-cads
'''
1. Put this file in ComfyUI/custom_nodes.
2. Add the "Apply CADS" node, found in the "loaders" category.
'''
import copy
import math
import warnings

import numpy as np
import torch
class CADS:
    """ComfyUI node that applies CADS noise to the cross-attention conditioning.

    The node clones the incoming model and installs a UNet function wrapper.
    On every UNet call, the wrapper mixes seeded Gaussian noise into the
    ``c_crossattn`` conditioning. The amount of noise follows a linear
    annealing schedule over the normalized timestep (see
    ``cads_linear_schedule``).
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs for ComfyUI."""
        return {
            "required": {
                "model": ("MODEL", ),
                "rescale": (["true", "false"], ),
                # Declared once. The previous literal repeated this key at the end,
                # and Python silently collapsed the duplicate into this position.
                "apply_negative_prompt": (["true", "false"], ),
                "t1": ("FLOAT", {
                    "default": 0.6,
                    "min": 0.0,
                    "max": 1.0,
                    "step": 0.01,
                }),
                "t2": ("FLOAT", {
                    "default": 0.9,
                    "min": 0.0,
                    "max": 1.0,
                    "step": 0.01,
                }),
                "noise_scale": ("FLOAT", {
                    "default": 0.25,
                    "min": 0.0,
                    "max": 1.0,
                    "step": 0.01,
                }),
                "mixing_factor": ("FLOAT", {
                    "default": 1.0,
                    "min": 0.0,
                    "max": 1.0,
                    "step": 0.01,
                }),
                "seed": ("INT", {
                    "default": 0,
                    "min": 0,
                    "max": 1000000000,
                    "step": 1,
                    "display": "number",
                }),
            },
        }

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "apply"
    CATEGORY = "loaders"

    def apply(self, model, rescale, t1, t2, noise_scale, mixing_factor, seed, apply_negative_prompt):
        """Return a clone of ``model`` whose UNet calls add CADS noise to the conditioning.

        Arguments:
            model: ComfyUI model to patch. It is cloned via ``model.clone()``
                and the wrapper is installed on the clone.
            rescale: "true"/"false". Whether to rescale the noisy conditioning.
            t1, t2 (float): Schedule thresholds on the normalized timestep.
            noise_scale (float): Multiplier on the added noise.
            mixing_factor (float): Blend weight, passed to ``add_noise`` as ``psi``.
            seed (int): Seed for the conditioning noise.
            apply_negative_prompt: "true"/"false". Whether entries marked 1 in
                ``cond_or_uncond`` are also noised.

        Returns:
            A 1-tuple containing the patched model clone.
        """
        # Still stored on self so that any external code reading these attributes keeps working.
        self.rescale = rescale == "true"
        self.t1 = t1
        self.t2 = t2
        self.noise_scale = noise_scale
        self.mixing_factor = mixing_factor
        self.seed = seed
        self.apply_negative_prompt = apply_negative_prompt == "true"

        # The wrapper below reads these locals rather than self.*. If this node
        # object is executed again with new inputs, models patched earlier
        # therefore keep their own settings.
        do_rescale = self.rescale
        use_negative = self.apply_negative_prompt

        new_model = model.clone()

        def apply_model(model_function, kwargs):
            x = kwargs["input"]
            t = kwargs["timestep"]
            c = kwargs["c"]
            cond_or_uncond = c["transformer_options"]["cond_or_uncond"]
            # Map the sampler's timestep input through the model's sampling
            # object, then normalize. NOTE(review): this assumes a 1000-step
            # model schedule; confirm for non-standard models.
            current_t = new_model.model.model_sampling.timestep(t)[0].item() / 1000
            gamma = self.cads_linear_schedule(current_t, t1, t2)
            # Assumes the batch is one equal-size chunk per entry of cond_or_uncond.
            cross_attns = list(c["c_crossattn"].chunk(len(cond_or_uncond)))
            for i, kind in enumerate(cond_or_uncond):
                if kind == 0 or (kind == 1 and use_negative):
                    cross_attns[i] = self.add_noise(
                        cross_attns[i], gamma, noise_scale, mixing_factor, do_rescale, seed=seed
                    )
            # Shallow copy, so the caller's conditioning dict is not mutated.
            new_c = copy.copy(c)
            new_c["c_crossattn"] = torch.cat(cross_attns, dim=0)
            return model_function(x, t, **new_c)

        new_model.set_model_unet_function_wrapper(apply_model)
        return (new_model, )

    def cads_linear_schedule(self, t, tau1, tau2):
        """CADS annealing schedule.

        Returns 1.0 for ``t <= tau1`` and 0.0 for ``t >= tau2``. In between, it
        decreases linearly as ``(tau2 - t) / (tau2 - tau1)``. The first two
        checks also prevent a division by zero when ``tau1 == tau2``.
        """
        if t <= tau1:
            return 1.0
        if t >= tau2:
            return 0.0
        return (tau2 - t) / (tau2 - tau1)

    def add_noise(self, y, gamma, noise_scale, psi, rescale=False, seed=None):
        """CADS: add scheduled, seeded Gaussian noise to a conditioning tensor.

        Computes ``sqrt(gamma) * y + noise_scale * sqrt(1 - gamma) * n``, where
        ``n`` is seeded standard-normal noise shaped like ``y``. When
        ``rescale`` is true, the noisy tensor is re-standardized to the mean
        and std of the *original* ``y``, then blended back as
        ``psi * rescaled + (1 - psi) * noisy``.

        Arguments:
            y: Input conditioning tensor.
            gamma: Noise level for the current timestep (1.0 = no noise).
                Clamped to [0, 1].
            noise_scale (float): Multiplier on the added noise.
            psi (float): Mixing factor between the rescaled and raw noisy tensor.
            rescale (bool): Whether to rescale the noisy tensor.
            seed (int, optional): Noise seed. Defaults to ``self.seed``.

        Returns:
            The noised (and optionally rescaled) tensor.
        """
        if seed is None:
            seed = self.seed
        gamma = min(max(float(gamma), 0.0), 1.0)
        if rescale:
            # The statistics must come from the clean conditioning. Taking them
            # after the noise is added turns the rescale into an identity.
            y_mean, y_std = torch.mean(y), torch.std(y)
        noise = self.randn_like_with_seed(y, seed)
        y = math.sqrt(gamma) * y + noise_scale * math.sqrt(1.0 - gamma) * noise
        if rescale:
            y_scaled = (y - torch.mean(y)) / torch.std(y) * y_std + y_mean
            if not torch.isnan(y_scaled).any():
                y = psi * y_scaled + (1 - psi) * y
            else:
                # Keep the un-rescaled noisy tensor and report the problem.
                warnings.warn("NaN encountered in rescaling", UserWarning, stacklevel=2)
        return y

    def randn_like_with_seed(self, x, seed):
        """Return standard-normal noise with ``x``'s shape, dtype and device, from ``seed``.

        A dedicated generator on ``x``'s device is used instead of reseeding
        and restoring the global RNGs. This leaves global random state
        untouched and does not require CUDA to be present.
        """
        generator = torch.Generator(device=x.device)
        generator.manual_seed(seed)
        return torch.randn(x.shape, generator=generator, dtype=x.dtype, device=x.device)
# ComfyUI registration: node type id -> implementing class.
NODE_CLASS_MAPPINGS = dict(CADS=CADS)

# Human-readable title shown in the ComfyUI node menu for each node type id.
NODE_DISPLAY_NAME_MAPPINGS = dict(CADS="Apply CADS")

# Only the two registration tables are exported.
__all__ = [
    "NODE_CLASS_MAPPINGS",
    "NODE_DISPLAY_NAME_MAPPINGS",
]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment