@laksjdjf / gist 435c512bc19636e9c9af4ee7bea9eb86
Last active May 2, 2024 09:09
'''
https://arxiv.org/abs/2312.00858

1. put this file in ComfyUI/custom_nodes
2. load the node from <loaders>

start_step, end_step: apply this method when the timestep is between start_step and end_step
cache_interval: number of steps between full (non-cached) UNet evaluations (1 means no caching)
cache_depth: UNet block depth at which the feature map is cached
'''
import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import forward_timestep_embed, timestep_embedding, th, apply_control
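
# DeepCache (https://arxiv.org/abs/2312.00858) wrapper for ComfyUI's UNet:
# on "cache" steps the full UNet runs and the skip feature at cache_depth is stored;
# on the steps in between, only the shallow blocks are recomputed and the deep
# blocks are replaced by the cached feature, trading a little quality for speed.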
class DeepCache:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", ),
                "cache_interval": ("INT", {
                    "default": 5,
                    "min": 1,
                    "max": 1000,
                    "step": 1,
                    "display": "number"
                }),
                "cache_depth": ("INT", {
                    "default": 3,
                    "min": 0,
                    "max": 12,
                    "step": 1,
                    "display": "number"
                }),
                "start_step": ("INT", {
                    "default": 0,
                    "min": 0,
                    "max": 1000,
                    "step": 1,
                    "display": "number"
                }),
                "end_step": ("INT", {
                    "default": 1000,
                    "min": 0,
                    "max": 1000,
                    "step": 1,
                }),
            },
        }

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "apply"
    CATEGORY = "loaders"
    def apply(self, model, cache_interval, cache_depth, start_step, end_step):
        new_model = model.clone()

        current_t = -1
        current_step = -1
        cache_h = None
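
        # apply_model replaces the sampler's UNet call. It re-implements the UNet
        # forward pass so blocks can be skipped; current_t is the last timestep seen,
        # current_step counts steps inside the caching window, and cache_h holds the
        # feature map cached at cache_depth.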
        def apply_model(model_function, kwargs):
            nonlocal current_t, current_step, cache_h

            xa = kwargs["input"]
            t = kwargs["timestep"]
            c_concat = kwargs["c"].get("c_concat", None)
            c_crossattn = kwargs["c"].get("c_crossattn", None)
            y = kwargs["c"].get("y", None)
            control = kwargs["c"].get("control", None)
            transformer_options = kwargs["c"].get("transformer_options", None)

            # https://github.com/comfyanonymous/ComfyUI/blob/629e4c552cc30a75d2756cbff8095640af3af163/comfy/model_base.py#L51-L69
            sigma = t
            xc = new_model.model.model_sampling.calculate_input(sigma, xa)
            if c_concat is not None:
                xc = torch.cat([xc] + [c_concat], dim=1)

            context = c_crossattn
            dtype = new_model.model.get_dtype()
            xc = xc.to(dtype)
            t = new_model.model.model_sampling.timestep(t).float()
            context = context.to(dtype)

            extra_conds = {}
            for o in kwargs:
                extra = kwargs[o]
                if hasattr(extra, "to"):
                    extra = extra.to(dtype)
                extra_conds[o] = extra

            x = xc
            timesteps = t
            y = None if y is None else y.to(dtype)
            transformer_options["original_shape"] = list(x.shape)
            transformer_options["current_index"] = 0
            transformer_patches = transformer_options.get("patches", {})

            """
            Apply the model to an input batch.
            :param x: an [N x C x ...] Tensor of inputs.
            :param timesteps: a 1-D batch of timesteps.
            :param context: conditioning plugged in via crossattn
            :param y: an [N] Tensor of labels, if class-conditional.
            :return: an [N x C x ...] Tensor of outputs.
            """
            unet = new_model.model.diffusion_model

            # Assume the timestep only goes back up when a new sampling run starts,
            # so reset the step counter; this may cause errors with Refiner etc.
            if t[0].item() > current_t:
                current_step = -1

            current_t = t[0].item()
            apply = 1000 - end_step <= current_t <= 1000 - start_step  # t runs 999 -> 0

            if apply:
                current_step += 1
            else:
                current_step = -1
            current_t = t[0].item()
            # https://github.com/comfyanonymous/ComfyUI/blob/629e4c552cc30a75d2756cbff8095640af3af163/comfy/ldm/modules/diffusionmodules/openaimodel.py#L598-L659
            assert (y is not None) == (
                unet.num_classes is not None
            ), "must specify y if and only if the model is class-conditional"

            hs = []
            t_emb = timestep_embedding(timesteps, unet.model_channels, repeat_only=False).to(unet.dtype)
            emb = unet.time_embed(t_emb)

            if unet.num_classes is not None:
                assert y.shape[0] == x.shape[0]
                emb = emb + unet.label_emb(y)

            h = x.type(unet.dtype)
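
            # Encoder (input) blocks: always run down to cache_depth; on non-cache
            # steps, stop there because the deeper features will come from cache_h.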
            for id, module in enumerate(unet.input_blocks):
                transformer_options["block"] = ("input", id)
                h = forward_timestep_embed(module, h, emb, context, transformer_options)
                h = apply_control(h, control, 'input')

                if "input_block_patch" in transformer_patches:
                    patch = transformer_patches["input_block_patch"]
                    for p in patch:
                        h = p(h, transformer_options)

                hs.append(h)

                if "input_block_patch_after_skip" in transformer_patches:
                    patch = transformer_patches["input_block_patch_after_skip"]
                    for p in patch:
                        h = p(h, transformer_options)

                if id == cache_depth and apply:
                    if not current_step % cache_interval == 0:
                        break  # skip everything after the cache position
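
            # Middle block: only evaluated on full (non-cached) steps.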
            if current_step % cache_interval == 0 or not apply:
                transformer_options["block"] = ("middle", 0)
                h = forward_timestep_embed(unet.middle_block, h, emb, context, transformer_options)
                h = apply_control(h, control, 'middle')
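
            # Decoder (output) blocks: on cached steps, skip the deep blocks and
            # splice in cache_h at the block that mirrors cache_depth; on full
            # steps, store that feature map as the new cache.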
            for id, module in enumerate(unet.output_blocks):
                if id < len(unet.output_blocks) - cache_depth - 1 and apply:
                    if not current_step % cache_interval == 0:
                        continue  # skip everything before the cache position

                if id == len(unet.output_blocks) - cache_depth - 1 and apply:
                    if current_step % cache_interval == 0:
                        cache_h = h  # cache
                    else:
                        h = cache_h  # load cache

                transformer_options["block"] = ("output", id)
                hsp = hs.pop()
                hsp = apply_control(hsp, control, 'output')

                if "output_block_patch" in transformer_patches:
                    patch = transformer_patches["output_block_patch"]
                    for p in patch:
                        h, hsp = p(h, hsp, transformer_options)

                h = th.cat([h, hsp], dim=1)
                del hsp

                if len(hs) > 0:
                    output_shape = hs[-1].shape
                else:
                    output_shape = None
                h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
            h = h.type(x.dtype)
            if unet.predict_codebook_ids:
                model_output = unet.id_predictor(h)
            else:
                model_output = unet.out(h)

            return new_model.model.model_sampling.calculate_denoised(sigma, model_output, xa)

        new_model.set_model_unet_function_wrapper(apply_model)
        return (new_model, )

NODE_CLASS_MAPPINGS = {
    "DeepCache": DeepCache,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "DeepCache": "Deep Cache",
}

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
@jav12z commented Dec 27, 2023

Hi, many thanks for this. Is there any possibility of having the implementations for SVD and SDXL as well? Thanks again and happy holidays.

@Piscabo commented Dec 28, 2023

I love this script and would like to know if it is possible to integrate the following. I'm a rookie, so I'm not able to do it myself.

import multiprocessing

# Define a function to perform the computation for a single iteration
def process_iteration(args):
    # Perform the computation based on the given arguments
    ...

# Modify the relevant part of your script where the loop or computation occurs
def parallel_computation():
    # Determine the number of iterations or tasks
    num_iterations = 100  # Replace with the actual number

    # Define the number of processes to use
    num_processes = multiprocessing.cpu_count()  # Use available CPU cores or specify a number

    # Create a multiprocessing Pool
    with multiprocessing.Pool(processes=num_processes) as pool:
        # Generate arguments for each iteration/task (modify based on your code structure)
        iterations_args = [(arg1, arg2, ...) for i in range(num_iterations)]
        # ^ Replace (arg1, arg2, ...) with the actual arguments needed for each iteration

        # Map the function to the arguments for parallel execution
        pool.map(process_iteration, iterations_args)

# Call the function to execute the parallel computation
if __name__ == "__main__":
    parallel_computation()

@thekarmakazi

Thanks for this, man. Any chance the XL or video versions could be implemented? Appreciate the share!

@laksjdjf (Author)

It should be compatible with SDXL.
