@mcmonkey4eva · Last active December 12, 2023
CacheSampley
import torch
import latent_preview
import comfy
# Explicit submodule imports (comfy.samplers, comfy.sample, etc. are accessed as attributes below).
import comfy.samplers
import comfy.sample
import comfy.utils
import comfy.model_management
def slerp(val, low, high):
    # Spherical interpolation between two noise tensors; falls back to a plain linear mix
    # when the inputs are nearly parallel, to avoid dividing by a tiny sin(omega).
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm * high_norm).sum(1)
    if dot.mean() > 0.9995:
        return low * val + high * (1 - val)
    omega = torch.acos(dot)
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res
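
# Noise helpers: each latent in the batch gets its own CPU-generated noise, from seed + i when no
# variation seed is in play, or a slerp between noise from the shared base seed and noise from
# (var_seed + i) when var_seed_strength > 0. The swarm_ prefix suggests this mirrors SwarmUI's
# variation-seed behavior, though that is an inference from the naming, not stated in the gist.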
def swarm_partial_noise(seed, latent_image):
    generator = torch.manual_seed(seed)
    return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")

def swarm_fixed_noise(seed, latent_image, var_seed, var_seed_strength):
    noises = []
    for i in range(latent_image.size()[0]):
        if var_seed_strength > 0:
            noise = swarm_partial_noise(seed, latent_image[i])
            var_noise = swarm_partial_noise(var_seed + i, latent_image[i])
            noise = slerp(var_seed_strength, noise, var_noise)
        else:
            noise = swarm_partial_noise(seed + i, latent_image[i])
        noises.append(noise)
    return torch.stack(noises, axis=0)
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, timestep_embedding, forward_timestep_embed
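
# CacheySampley: a hacked KSamplerAdvanced-style node that temporarily replaces the UNet forward
# pass with a version that caches input/middle/output block activations and reuses them on later
# steps. The cache_*_start thresholds and cache_disable_step are compared against the model
# timestep (roughly 1000 down to 0 over a run); cache_in_block / cache_in_block2 / cache_out_block
# choose which block indices are recomputed versus reused; cache_disable_step turns caching off
# entirely once the timestep falls below it; and full_run_step_rate forces a full, uncached
# forward pass whenever the timestep has dropped by at least that much since the last full pass.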
class CacheySampley:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.5, "round": 0.001}),
                "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
                "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
                "positive": ("CONDITIONING", ),
                "negative": ("CONDITIONING", ),
                "latent_image": ("LATENT", ),
                "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
                "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
                "var_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "var_seed_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.05, "round": 0.001}),
                "add_noise": (["enable", "disable"], ),
                "return_with_leftover_noise": (["disable", "enable"], ),
                "cache_in_start": ("INT", {"default": 600, "min": 0, "max": 1000, "step": 25, "round": 1}),
                "cache_in_start2": ("INT", {"default": 400, "min": 0, "max": 1000, "step": 25, "round": 1}),
                "cache_mid_start": ("INT", {"default": 800, "min": 0, "max": 1000, "step": 25, "round": 1}),
                "cache_out_start": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 25, "round": 1}),
                "cache_in_block": ("INT", {"default": 6, "min": 0, "max": 8}),
                "cache_in_block2": ("INT", {"default": 4, "min": 0, "max": 8}),
                "cache_out_block": ("INT", {"default": 3, "min": 0, "max": 8}),
                "cache_disable_step": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 25, "round": 1}),
                "full_run_step_rate": ("INT", {"default": 1000, "min": 0, "max": 1000, "step": 25, "round": 1}),
            }
        }

    CATEGORY = "hacks"
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "sample"
    def sample(self, model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, var_seed, var_seed_strength, add_noise, return_with_leftover_noise,
               cache_in_start, cache_in_start2, cache_mid_start, cache_out_start, cache_in_block, cache_in_block2, cache_out_block, cache_disable_step, full_run_step_rate):
        # Per-call cache of block activations; 'ts' tracks the timestep of the last full (uncached) pass.
        CACHE_LAST = {'ts': 1000}

        def hijacked_unet_forward(unet, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
            transformer_options["original_shape"] = list(x.shape)
            transformer_options["current_index"] = 0
            transformer_patches = transformer_options.get("patches", {})
            assert (y is not None) == (
                unet.num_classes is not None
            ), "must specify y if and only if the model is class-conditional"
            hs = []
            t_emb = timestep_embedding(timesteps, unet.model_channels, repeat_only=False).to(unet.dtype)
            emb = unet.time_embed(t_emb)
            if unet.num_classes is not None:
                assert y.shape[0] == x.shape[0]
                emb = emb + unet.label_emb(y)
            h = x.type(unet.dtype)
            timestep_index = float(timesteps[0])
            do_full = False
            if timestep_index > cache_in_start and timestep_index > cache_mid_start:
                CACHE_LAST['ts'] = timestep_index
            elif timestep_index < CACHE_LAST['ts'] - full_run_step_rate:
                # Timestep has dropped far enough since the last full pass: force an uncached run.
                CACHE_LAST['ts'] = timestep_index
                do_full = True
            for id, module in enumerate(unet.input_blocks):
                if id > cache_in_block and timestep_index < cache_in_start and timestep_index > cache_disable_step and not do_full:
                    # Deep input blocks: reuse the activation cached on the last computed pass.
                    h = CACHE_LAST[f'in{id}']
                elif id > cache_in_block2 and timestep_index < cache_in_start2 and timestep_index > cache_disable_step and not do_full:
                    h = CACHE_LAST[f'in{id}']
                else:
                    transformer_options["block"] = ("input", id)
                    h = forward_timestep_embed(module, h, emb, context, transformer_options)
                    if control is not None and 'input' in control and len(control['input']) > 0:
                        ctrl = control['input'].pop()
                        if ctrl is not None:
                            h += ctrl
                    CACHE_LAST[f'in{id}'] = h
                    #print(f"in {id} is {h.mean()}")
                hs.append(h)
            if timestep_index < cache_mid_start and timestep_index > cache_disable_step and not do_full:
                h = CACHE_LAST['mid']
            else:
                transformer_options["block"] = ("middle", 0)
                h = forward_timestep_embed(unet.middle_block, h, emb, context, transformer_options)
                if control is not None and 'middle' in control and len(control['middle']) > 0:
                    ctrl = control['middle'].pop()
                    if ctrl is not None:
                        h += ctrl
                CACHE_LAST['mid'] = h
                #print(f"mid is {h.mean()}")
            for id, module in enumerate(unet.output_blocks):
                hsp = hs.pop()
                if id < cache_out_block and timestep_index < cache_out_start and timestep_index > cache_disable_step and not do_full:
                    h = CACHE_LAST[f'out{id}']
                else:
                    transformer_options["block"] = ("output", id)
                    if control is not None and 'output' in control and len(control['output']) > 0:
                        ctrl = control['output'].pop()
                        if ctrl is not None:
                            hsp += ctrl
                    if "output_block_patch" in transformer_patches:
                        patch = transformer_patches["output_block_patch"]
                        for p in patch:
                            h, hsp = p(h, hsp, transformer_options)
                    h = torch.cat([h, hsp], dim=1)
                    del hsp
                    if len(hs) > 0:
                        output_shape = hs[-1].shape
                    else:
                        output_shape = None
                    h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
                    CACHE_LAST[f'out{id}'] = h
                    #print(f"out {id} is {h.mean()}")
            h = h.type(x.dtype)
            if unet.predict_codebook_ids:
                return unet.id_predictor(h)
            else:
                return unet.out(h)
        device = comfy.model_management.get_torch_device()
        latent_samples = latent_image["samples"]
        disable_noise = add_noise == "disable"
        # Swap in the caching forward pass for the duration of this sample call, then restore it.
        orig_func = getattr(UNetModel, "forward")
        setattr(UNetModel, "forward", hijacked_unet_forward)
        try:
            if disable_noise:
                noise = torch.zeros(latent_samples.size(), dtype=latent_samples.dtype, layout=latent_samples.layout, device="cpu")
            else:
                noise = swarm_fixed_noise(noise_seed, latent_samples, var_seed, var_seed_strength)
            noise_mask = None
            if "noise_mask" in latent_image:
                noise_mask = latent_image["noise_mask"]
            previewer = latent_preview.get_previewer(device, model.model.latent_format)
            pbar = comfy.utils.ProgressBar(steps)

            def callback(step, x0, x, total_steps):
                preview_bytes = None
                if previewer:
                    preview_bytes = previewer.decode_latent_to_preview_image("JPEG", x0)
                pbar.update_absolute(step + 1, total_steps, preview_bytes)

            samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_samples,
                                          denoise=1.0, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step,
                                          force_full_denoise=return_with_leftover_noise == "disable", noise_mask=noise_mask, callback=callback, seed=noise_seed)
            out = latent_image.copy()
            out["samples"] = samples
            return (out, )
        finally:
            setattr(UNetModel, "forward", orig_func)

NODE_CLASS_MAPPINGS = {
    "CacheySampley": CacheySampley,
}
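
# Usage (a rough sketch, not part of the original gist): ComfyUI loads any .py file under
# ComfyUI/custom_nodes/ that exposes NODE_CLASS_MAPPINGS, so saving this file there and
# restarting ComfyUI should make the node appear under the "hacks" category, taking the inputs
# declared in INPUT_TYPES and returning a single LATENT, used in place of KSamplerAdvanced.
# The defaults and thresholds above are the author's; expect to tune the cache_* values per
# model, since reusing cached block outputs skips work at the likely cost of some detail.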
@fishelegs

Great work~

How to support controlnet?
