Skip to content

Instantly share code, notes, and snippets.

@aria1th
Created December 8, 2023 19:35
Show Gist options
  • Save aria1th/04bb78207daeee1f3d0800dc422e6254 to your computer and use it in GitHub Desktop.
DeepCacheStandAlone
class DeepCacheStandAlone:
    """
    @source https://github.com/horseee/DeepCache
    Standalone version of DeepCache, which can be used without the DeepCacheScript.

    Caches an intermediate U-Net decoder activation on "full" steps (every
    ``full_run_step_rate``-th step) and replays it on the steps in between,
    allowing most encoder/decoder blocks to be skipped.
    For multiple switching UNets, you can specify cache_type to use different caches.

    Code Snippet:
    ```python
    # U-Net Encoder
    for i, module in enumerate(self.input_blocks):
        if session.cond_skip_unet_in(i):
            continue
        h = forward_timestep_embed(module, h, emb, context)
        hs.append(h)
    if not session.cond_skip_unet_middle():
        h = forward_timestep_embed(self.middle_block, h, emb, context)
    # U-Net Decoder
    total_out_blocks = len(self.output_blocks)
    for idx, module in enumerate(self.output_blocks):
        if session.cond_skip_unet_out(idx, total_out_blocks):
            continue
        if session.should_load_unet_out(idx, total_out_blocks):
            h = session.get_cache()
        else:
            hsp = hs.pop()
            h = torch.cat([h, hsp], dim=1)
            del hsp
        if len(hs) > 0:
            output_shape = hs[-1].shape
        else:
            output_shape = None
        h = forward_timestep_embed(module, h, emb, context, output_shape=output_shape)
        if session.should_cache_unet_out(idx, total_out_blocks):
            session.put_cache(h)
    ```
    """

    def __init__(self, params: dict = None, enable: bool = True) -> None:
        """
        :param params: optional dict with keys ``cache_in_level``,
            ``cache_enable_step``, ``full_run_step_rate``; missing keys fall
            back to defaults (0, 8, 3 respectively).
        :param enable: master switch; when False every query returns the
            "do a full forward" answer.
        """
        if params is None:
            params = {}
        from collections import defaultdict
        # Depth (counted from the deep end of the U-Net) at which the
        # activation is cached / reloaded.
        self.cache_in_level = params.get('cache_in_level', 0)
        # Caching only activates once the accumulated timestep exceeds this.
        self.cache_enable_step = params.get('cache_enable_step', 8)
        # Every N-th step runs the full U-Net and refreshes the cache.
        self.full_run_step_rate = params.get('full_run_step_rate', 3)
        # cache_type -> {"timestep": set of cached steps, "last": tensor,
        #                "timestep_<t>": tensor per pre-registered step}
        self.caches = defaultdict(lambda: {"timestep": set()})
        self.timestep_accumulator = 0
        self.enable = enable
        self.debug_mode = True
        self.skip_statistics = {}  # will log in_true / in_false / out_true / out_false / middle_true / middle_false
        self.total_accumulator = 0
        self.accumulate_every = 2  # controlnet forward every 2 steps

    def report(self):
        """
        Reports the cache statistics.
        """
        print("DeepCache : report")
        # total, in, out, middle success rate and count
        total = sum(self.skip_statistics.values())
        in_sum = self.skip_statistics.get('in_true', 0) + self.skip_statistics.get('in_false', 0)
        out_sum = self.skip_statistics.get('out_true', 0) + self.skip_statistics.get('out_false', 0)
        middle_sum = self.skip_statistics.get('middle_true', 0) + self.skip_statistics.get('middle_false', 0)
        print(f"DeepCache : total {total} in {in_sum} out {out_sum} middle {middle_sum}")
        print("Success rates")
        print(f"IN : {self.skip_statistics.get('in_true', 0) / in_sum if in_sum > 0 else 0}, total {in_sum}")
        print(f"OUT : {self.skip_statistics.get('out_true', 0) / out_sum if out_sum > 0 else 0}, total {out_sum}")
        print(f"MIDDLE : {self.skip_statistics.get('middle_true', 0) / middle_sum if middle_sum > 0 else 0}, total {middle_sum}")

    def should_activate(self) -> bool:
        """
        Returns if DeepCache should be activated.
        False when disabled, misconfigured (rate < 1), or still in the
        warm-up window (timestep <= cache_enable_step).
        """
        if not self.enable:
            self.debug_print("DeepCache : Disabled")
            return False
        if self.full_run_step_rate < 1:
            self.debug_print(f"DeepCache : full_run_step_rate {self.full_run_step_rate} < 1")
            return False
        if not self.timestep_accumulator > self.cache_enable_step:
            self.debug_print(f"DeepCache : timestep {self.timestep_accumulator} <= cache_enable_step {self.cache_enable_step}")
            return False
        return True

    def can_use_cache(self, cache_type="last") -> bool:
        """
        Returns if DeepCache can be used: a cached tensor exists for the
        current step, and this is neither a full-run step nor warm-up.
        """
        return self.get_cache(cache_type) is not None and self.timestep_accumulator % self.full_run_step_rate != 0 and self.timestep_accumulator > self.cache_enable_step

    def should_cache_unet_out(self, index, total_out_blocks: int) -> bool:
        """
        Returns if current in block should be cached.
        The function should be called exclusively with should_load_unet_out branch, to execute forward+cache / load+continue.
        Usage : if self.should_cache_unet_out(index): h = model.forward();self.put_cache(h)
        """
        # Only cache on full-run steps (step divisible by full_run_step_rate).
        if not self.should_activate() or self.timestep_accumulator % self.full_run_step_rate != 0:
            return False
        return index == total_out_blocks - self.cache_in_level - 1

    def should_load_unet_out(self, index, total_out_blocks: int, cache_type="last") -> bool:
        """
        Returns if current out block should be loaded.
        The function should be called before the forward_timestep_embed call, to replace the input with the cached tensor.
        The function should be called before should_cache_unet_out call.
        Usage : if should_load_unet_out(index, len(unet.output_blocks)): h = self.get_cache()
        """
        return self.can_use_cache(cache_type) and index == total_out_blocks - self.cache_in_level - 1

    def cond_skip_unet_in(self, index, cache_type="last") -> bool:
        """
        Returns if current in block should be skipped.
        The function should be called before the forward_timestep_embed call.
        Usage : if self.cond_skip_unet_in(index): continue
        """
        result = self.can_use_cache(cache_type) and index > self.cache_in_level
        if result:
            self.skip_statistics['in_true'] = self.skip_statistics.get('in_true', 0) + 1
        else:
            self.skip_statistics['in_false'] = self.skip_statistics.get('in_false', 0) + 1
        # BUGFIX: the original computed `result` but never returned it, so the
        # caller always saw None (falsy) and never skipped any encoder block.
        return result

    def cond_skip_unet_out(self, index, total_out_blocks, cache_type="last") -> bool:
        """
        Returns if current out block should be skipped.
        The function should be called before should_load_net_out and should_cache_unet_out call.
        Usage : if self.cond_skip_unet_out(index, len(unet.output_blocks)): continue
        """
        result = self.can_use_cache(cache_type) and index < total_out_blocks - self.cache_in_level - 1
        if result:
            self.skip_statistics['out_true'] = self.skip_statistics.get('out_true', 0) + 1
        else:
            self.skip_statistics['out_false'] = self.skip_statistics.get('out_false', 0) + 1
        # BUGFIX: missing return (see cond_skip_unet_in).
        return result

    def cond_skip_unet_middle(self, cache_type="last") -> bool:
        """
        Returns if middle block should be skipped.
        Usage : if self.cond_skip_unet_middle(): continue
        """
        result = self.can_use_cache(cache_type)
        if result:
            self.skip_statistics['middle_true'] = self.skip_statistics.get('middle_true', 0) + 1
        else:
            self.skip_statistics['middle_false'] = self.skip_statistics.get('middle_false', 0) + 1
        # BUGFIX: missing return (see cond_skip_unet_in).
        return result

    def put_cache(self, h, cache_type="last"):
        """
        Registers cache
        Usage : if self.should_cache_unet_out(index): self.put_cache(h)

        :raises AssertionError: if h is None.
        """
        timestep = self.timestep_accumulator
        target_cache = self.caches[cache_type]
        target_cache["timestep"].add(timestep)
        assert h is not None, "Cannot cache None"
        # maybe move to cpu and load later for low vram?
        target_cache["last"] = h
        for _i in range(self.full_run_step_rate):
            # register for each step too, so the next full_run_step_rate - 1
            # steps can look themselves up directly
            target_cache[f"timestep_{timestep + _i}"] = h
        self.debug_print(f"DeepCache : put cache for timestep {timestep}")

    def debug_print(self, *args, **kwargs):
        """Print only when debug_mode is on."""
        if self.debug_mode:
            print(*args, **kwargs)

    def get_cache(self, cache_type="last"):
        """
        Returns the cached tensor for the given timestep and cache key.
        Returns None when caching is inactive or no entry exists for the
        current step (e.g. on full-run steps).
        Usage : if self.should_load_unet_out(index, len(unet.output_blocks)): h = self.get_cache()
        """
        target_cache = self.caches[cache_type]
        if not self.should_activate():
            self.debug_print("DeepCache : Disabled")
            return None
        current_timestep = self.timestep_accumulator
        if current_timestep < self.cache_enable_step:
            self.debug_print(f"DeepCache : timestep {current_timestep} < cache_enable_step {self.cache_enable_step}")
            return None
        elif self.full_run_step_rate < 1:
            self.debug_print(f"DeepCache : full_run_step_rate {self.full_run_step_rate} < 1")
            return None
        elif current_timestep % self.full_run_step_rate != 0:
            if f"timestep_{current_timestep}" in target_cache:
                target_cache["last"] = target_cache[f"timestep_{current_timestep}"]  # update last
                self.debug_print(f"DeepCache : load cache for timestep {current_timestep}")
                return target_cache[f"timestep_{current_timestep}"]
            self.debug_print(f"DeepCache : cache for timestep {current_timestep} not found")
        self.debug_print(f"DeepCache : Step is divisible, running full forward")
        return None

    def increment_timestep(self):
        """
        Increments the timestep accumulator. Should be called after each forward pass.
        Only every `accumulate_every`-th call advances the timestep
        (controlnet runs an extra forward per step).
        """
        self.total_accumulator += 1
        if self.total_accumulator % self.accumulate_every != 0:
            return
        self.timestep_accumulator += 1

    def reset_timestep(self):
        """
        Resets the timestep accumulator. Should be called after each generation / maybe before hires.fix
        """
        self.timestep_accumulator = 0

    def clear_cache(self):
        """
        Clears the cache. Should be called after each generation.
        """
        self.caches.clear()
        self.timestep_accumulator = 0
        self.total_accumulator = 0
        self.skip_statistics.clear()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment