-
-
Save mcmonkey4eva/25f2b4d95a5ede0934885d74c0669293 to your computer and use it in GitHub Desktop.
CacheSampley
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import latent_preview | |
import comfy | |
def slerp(val, low, high):
    """Spherically interpolate between two tensors along dimension 1.

    Args:
        val: Interpolation weight; 0 yields ``low`` and 1 yields ``high``.
        low: Start tensor. Must have at least two dimensions, because the
            norm and dot product are taken over ``dim=1``.
        high: End tensor, same shape as ``low``.

    Returns:
        The interpolated tensor, same shape as the inputs.

    Note:
        Exactly parallel or anti-parallel slices give ``sin(omega) == 0``
        and therefore NaN/inf in the spherical branch; only the case where
        the *mean* cosine exceeds 0.9995 is routed to the linear fallback.
    """
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm * high_norm).sum(1)
    if dot.mean() > 0.9995:
        # Nearly parallel inputs: fall back to a plain lerp. The weights must
        # follow the same orientation as the spherical formula below
        # (val=0 -> low, val=1 -> high).
        return low * (1 - val) + high * val
    omega = torch.acos(dot)
    so = torch.sin(omega)
    low_weight = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1)
    high_weight = (torch.sin(val * omega) / so).unsqueeze(1)
    return low_weight * low + high_weight * high
def swarm_partial_noise(seed, latent_image):
    """Draw seeded Gaussian noise shaped like ``latent_image`` on the CPU.

    Reseeds torch's global generator with ``seed`` (a side effect visible to
    later random calls) and samples from it, so equal seeds give equal noise.
    The result copies the dtype and layout of ``latent_image``.
    """
    rng = torch.manual_seed(seed)
    return torch.randn(
        latent_image.size(),
        generator=rng,
        dtype=latent_image.dtype,
        layout=latent_image.layout,
        device="cpu",
    )
def swarm_fixed_noise(seed, latent_image, var_seed, var_seed_strength):
    """Build per-item seeded noise for a batch, optionally blended with variation noise.

    For each entry along dimension 0 of ``latent_image``:

    * if ``var_seed_strength`` is positive, noise from ``seed`` (the same seed
      for every item) is slerp-blended toward noise from ``var_seed + index``
      using ``var_seed_strength`` as the weight;
    * otherwise the item's noise comes from ``seed + index``.

    Returns:
        The per-item noise tensors stacked along a new leading dimension.
    """
    batch = []
    for index, item in enumerate(latent_image):
        if not var_seed_strength > 0:
            batch.append(swarm_partial_noise(seed + index, item))
            continue
        base = swarm_partial_noise(seed, item)
        variation = swarm_partial_noise(var_seed + index, item)
        batch.append(slerp(var_seed_strength, base, variation))
    return torch.stack(batch, dim=0)
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, timestep_embedding, forward_timestep_embed | |
class CacheySampley:
    """Sampler node that reuses cached UNet block outputs between sampling steps.

    While ``sample`` runs, ``UNetModel.forward`` is temporarily replaced by a
    variant that stores each block's output and, depending on the current
    timestep and the ``cache_*`` settings, returns the stored output instead
    of recomputing the block. The original forward is restored afterwards.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node's inputs: standard sampler settings plus cache controls."""
        return {
            "required": {
                "model": ("MODEL",),
                "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.5, "round": 0.001}),
                "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
                "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
                "positive": ("CONDITIONING", ),
                "negative": ("CONDITIONING", ),
                "latent_image": ("LATENT", ),
                "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
                "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
                "var_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "var_seed_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.05, "round": 0.001}),
                "add_noise": (["enable", "disable"], ),
                "return_with_leftover_noise": (["disable", "enable"], ),
                "cache_in_start": ("INT", {"default": 600, "min": 0, "max": 1000, "step": 25, "round": 1}),
                "cache_in_start2": ("INT", {"default": 400, "min": 0, "max": 1000, "step": 25, "round": 1}),
                "cache_mid_start": ("INT", {"default": 800, "min": 0, "max": 1000, "step": 25, "round": 1}),
                "cache_out_start": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 25, "round": 1}),
                "cache_in_block": ("INT", {"default": 6, "min": 0, "max": 8}),
                "cache_in_block2": ("INT", {"default": 4, "min": 0, "max": 8}),
                "cache_out_block": ("INT", {"default": 3, "min": 0, "max": 8}),
                "cache_disable_step": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 25, "round": 1}),
                "full_run_step_rate": ("INT", {"default": 1000, "min": 0, "max": 1000, "step": 25, "round": 1}),
            }
        }

    CATEGORY = "hacks"
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "sample"

    def sample(self, model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, var_seed, var_seed_strength, add_noise, return_with_leftover_noise,
               cache_in_start, cache_in_start2, cache_mid_start, cache_out_start, cache_in_block, cache_in_block2, cache_out_block, cache_disable_step, full_run_step_rate):
        """Run sampling with a block-caching UNet forward patched in.

        Cache controls (all compared against the timestep of the first batch
        element of each UNet call):

        * ``cache_in_start`` / ``cache_in_block``: below this timestep, input
          blocks with index greater than ``cache_in_block`` reuse their
          cached output. ``cache_in_start2`` / ``cache_in_block2`` form a
          second threshold pair with the same meaning.
        * ``cache_mid_start``: below this timestep the middle block is reused.
        * ``cache_out_start`` / ``cache_out_block``: below this timestep,
          output blocks with index less than ``cache_out_block`` are reused.
        * ``cache_disable_step``: caching applies only above this timestep.
        * ``full_run_step_rate``: once the timestep has dropped more than this
          far below the last recorded timestep, one uncached pass is forced.

        A block is always computed when no cached output exists for it yet.

        Returns:
            A one-tuple holding a shallow copy of ``latent_image`` whose
            ``"samples"`` entry is the sampling result.
        """
        # Per-call cache: 'ts' is the last timestep recorded by the refresh
        # logic below; 'in{i}', 'mid' and 'out{i}' hold block outputs.
        CACHE_LAST = {'ts': 1000}

        def hijacked_unet_forward(unet, x, timesteps=None, context=None, y=None, control=None, transformer_options=None, **kwargs):
            # A fresh dict per call: this function writes into
            # transformer_options, so a shared default would leak state.
            if transformer_options is None:
                transformer_options = {}
            transformer_options["original_shape"] = list(x.shape)
            transformer_options["current_index"] = 0
            transformer_patches = transformer_options.get("patches", {})

            assert (y is not None) == (
                unet.num_classes is not None
            ), "must specify y if and only if the model is class-conditional"
            hs = []
            t_emb = timestep_embedding(timesteps, unet.model_channels, repeat_only=False).to(unet.dtype)
            emb = unet.time_embed(t_emb)

            if unet.num_classes is not None:
                assert y.shape[0] == x.shape[0]
                emb = emb + unet.label_emb(y)

            h = x.type(unet.dtype)
            timestep_index = float(timesteps[0])

            # Decide whether this call must bypass the cache entirely.
            do_full = False
            if timestep_index > cache_in_start and timestep_index > cache_mid_start:
                CACHE_LAST['ts'] = timestep_index
            elif timestep_index < CACHE_LAST['ts'] - full_run_step_rate:
                CACHE_LAST['ts'] = timestep_index
                do_full = True
            caching_active = timestep_index > cache_disable_step and not do_full

            for idx, module in enumerate(unet.input_blocks):
                in_key = f'in{idx}'
                wants_cache = caching_active and (
                    (idx > cache_in_block and timestep_index < cache_in_start)
                    or (idx > cache_in_block2 and timestep_index < cache_in_start2)
                )
                # Only reuse when something was actually cached; sampling that
                # starts below a threshold has no earlier pass to draw from.
                if wants_cache and in_key in CACHE_LAST:
                    h = CACHE_LAST[in_key]
                else:
                    transformer_options["block"] = ("input", idx)
                    h = forward_timestep_embed(module, h, emb, context, transformer_options)
                    if control is not None and 'input' in control and len(control['input']) > 0:
                        ctrl = control['input'].pop()
                        if ctrl is not None:
                            h += ctrl
                    CACHE_LAST[in_key] = h
                hs.append(h)

            if caching_active and timestep_index < cache_mid_start and 'mid' in CACHE_LAST:
                h = CACHE_LAST['mid']
            else:
                transformer_options["block"] = ("middle", 0)
                h = forward_timestep_embed(unet.middle_block, h, emb, context, transformer_options)
                if control is not None and 'middle' in control and len(control['middle']) > 0:
                    ctrl = control['middle'].pop()
                    if ctrl is not None:
                        h += ctrl
                CACHE_LAST['mid'] = h

            for idx, module in enumerate(unet.output_blocks):
                # Always pop so the skip-connection stack stays aligned, even
                # when the block itself is served from the cache.
                hsp = hs.pop()
                out_key = f'out{idx}'
                wants_cache = caching_active and idx < cache_out_block and timestep_index < cache_out_start
                if wants_cache and out_key in CACHE_LAST:
                    h = CACHE_LAST[out_key]
                else:
                    transformer_options["block"] = ("output", idx)
                    if control is not None and 'output' in control and len(control['output']) > 0:
                        ctrl = control['output'].pop()
                        if ctrl is not None:
                            # NOTE(review): in-place add on a tensor that may
                            # also be stored in CACHE_LAST['in…'] — confirm
                            # this is intended when output control is used.
                            hsp += ctrl

                    if "output_block_patch" in transformer_patches:
                        patch = transformer_patches["output_block_patch"]
                        for p in patch:
                            h, hsp = p(h, hsp, transformer_options)

                    h = torch.cat([h, hsp], dim=1)
                    del hsp
                    if len(hs) > 0:
                        output_shape = hs[-1].shape
                    else:
                        output_shape = None
                    h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
                    CACHE_LAST[out_key] = h

            h = h.type(x.dtype)
            if unet.predict_codebook_ids:
                return unet.id_predictor(h)
            else:
                return unet.out(h)

        device = comfy.model_management.get_torch_device()
        latent_samples = latent_image["samples"]
        disable_noise = add_noise == "disable"

        # Patch at class level for the duration of sampling; the finally
        # clause restores the original even if sampling raises.
        orig_func = getattr(UNetModel, "forward")
        setattr(UNetModel, "forward", hijacked_unet_forward)
        try:
            if disable_noise:
                noise = torch.zeros(latent_samples.size(), dtype=latent_samples.dtype, layout=latent_samples.layout, device="cpu")
            else:
                noise = swarm_fixed_noise(noise_seed, latent_samples, var_seed, var_seed_strength)

            noise_mask = None
            if "noise_mask" in latent_image:
                noise_mask = latent_image["noise_mask"]

            previewer = latent_preview.get_previewer(device, model.model.latent_format)
            pbar = comfy.utils.ProgressBar(steps)

            def callback(step, x0, x, total_steps):
                # Report progress and, when a previewer exists, a JPEG preview of x0.
                preview_bytes = None
                if previewer:
                    preview_bytes = previewer.decode_latent_to_preview_image("JPEG", x0)
                pbar.update_absolute(step + 1, total_steps, preview_bytes)

            samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_samples,
                                          denoise=1.0, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step,
                                          force_full_denoise=return_with_leftover_noise == "disable", noise_mask=noise_mask, callback=callback, seed=noise_seed)
            out = latent_image.copy()
            out["samples"] = samples
            return (out, )
        finally:
            setattr(UNetModel, "forward", orig_func)
# Maps a node identifier string to the class implementing it.
# Presumably read by the host application's node loader — confirm against the loader.
NODE_CLASS_MAPPINGS = {
    "CacheySampley": CacheySampley,
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Great work~
How to support controlnet?