Created
December 8, 2023 19:35
-
-
Save aria1th/04bb78207daeee1f3d0800dc422e6254 to your computer and use it in GitHub Desktop.
DeepCacheStandAlone
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DeepCacheStandAlone:
    """
    Standalone version of DeepCache, which can be used without the DeepCacheScript.

    @source https://github.com/horseee/DeepCache

    The session keeps a step counter (``timestep_accumulator``) and a cache of
    U-Net decoder activations. Once the counter exceeds ``cache_enable_step``,
    every step whose counter is a multiple of ``full_run_step_rate`` is a
    "full run" whose activation gets stored; on the remaining steps the stored
    activation is loaded and the blocks it replaces are skipped.

    For multiple switching UNets, you can specify cache_type to use different caches.

    Code Snippet:
    ```python
    # U-Net Encoder
    for i, module in enumerate(self.input_blocks):
        if session.cond_skip_unet_in(i):
            continue
        h = forward_timestep_embed(module, h, emb, context)
        hs.append(h)
    if not session.cond_skip_unet_middle():
        h = forward_timestep_embed(self.middle_block, h, emb, context)
    # U-Net Decoder
    total_out_blocks = len(self.output_blocks)
    for idx, module in enumerate(self.output_blocks):
        if session.cond_skip_unet_out(idx, total_out_blocks):
            continue
        if session.should_load_unet_out(idx, total_out_blocks):
            h = session.get_cache()
        else:
            hsp = hs.pop()
            h = torch.cat([h, hsp], dim=1)
            del hsp
            if len(hs) > 0:
                output_shape = hs[-1].shape
            else:
                output_shape = None
            h = forward_timestep_embed(module, h, emb, context, output_shape=output_shape)
        if session.should_cache_unet_out(idx, total_out_blocks):
            session.put_cache(h)
    ```
    """

    def __init__(self, params: dict = None, enable: bool = True) -> None:
        """
        Creates a caching session.

        params may contain 'cache_in_level' (default 0), 'cache_enable_step'
        (default 8) and 'full_run_step_rate' (default 3); missing keys fall
        back to those defaults. enable=False turns every caching decision off.
        """
        if params is None:
            params = {}
        from collections import defaultdict
        self.cache_in_level = params.get('cache_in_level', 0)
        self.cache_enable_step = params.get('cache_enable_step', 8)
        self.full_run_step_rate = params.get('full_run_step_rate', 3)
        # One dict per cache_type; created lazily on first access.
        self.caches = defaultdict(lambda: {"timestep": set()})
        self.timestep_accumulator = 0
        self.enable = enable
        self.debug_mode = True
        self.skip_statistics = {}  # will log in_true / in_false / out_true / out_false / middle_true / middle_false
        self.total_accumulator = 0
        self.accumulate_every = 2  # controlnet forward every 2 steps

    def report(self):
        """
        Reports the cache statistics.
        """
        stats = self.skip_statistics
        print("DeepCache : report")
        # total, in, out, middle success rate and count
        total = sum(stats.values())
        in_sum = stats.get('in_true', 0) + stats.get('in_false', 0)
        out_sum = stats.get('out_true', 0) + stats.get('out_false', 0)
        middle_sum = stats.get('middle_true', 0) + stats.get('middle_false', 0)
        print(f"DeepCache : total {total} in {in_sum} out {out_sum} middle {middle_sum}")
        print("Success rates")
        print(f"IN : {stats.get('in_true', 0) / in_sum if in_sum > 0 else 0}, total {in_sum}")
        print(f"OUT : {stats.get('out_true', 0) / out_sum if out_sum > 0 else 0}, total {out_sum}")
        print(f"MIDDLE : {stats.get('middle_true', 0) / middle_sum if middle_sum > 0 else 0}, total {middle_sum}")

    def should_activate(self):
        """
        Returns if DeepCache should be activated.

        Activation requires the session to be enabled, full_run_step_rate >= 1,
        and the timestep counter to be strictly greater than cache_enable_step.
        """
        if not self.enable:
            self.debug_print("DeepCache : Disabled")
            return False
        if self.full_run_step_rate < 1:
            self.debug_print(f"DeepCache : full_run_step_rate {self.full_run_step_rate} < 1")
            return False
        if not self.timestep_accumulator > self.cache_enable_step:
            self.debug_print(f"DeepCache : timestep {self.timestep_accumulator} <= cache_enable_step {self.cache_enable_step}")
            return False
        return True

    def can_use_cache(self, cache_type="last"):
        """
        Returns if DeepCache can be used, i.e. a cached tensor is available for
        the current step and the current step is not a full-run step.
        """
        # get_cache() goes first: it returns None whenever the session is not
        # activated, so the modulo below is never evaluated with a rate < 1.
        return (
            self.get_cache(cache_type) is not None
            and self.timestep_accumulator % self.full_run_step_rate != 0
            and self.timestep_accumulator > self.cache_enable_step
        )

    def should_cache_unet_out(self, index, total_out_blocks: int) -> bool:
        """
        Returns if the output of the current out block should be cached.
        The function should be called exclusively with should_load_unet_out branch, to execute forward+cache / load+continue.
        Usage : if self.should_cache_unet_out(index, len(unet.output_blocks)): h = model.forward();self.put_cache(h)
        """
        # Only full-run steps (counter divisible by full_run_step_rate) produce a cache.
        if not self.should_activate() or self.timestep_accumulator % self.full_run_step_rate != 0:
            return False
        return index == total_out_blocks - self.cache_in_level - 1

    def should_load_unet_out(self, index, total_out_blocks: int, cache_type="last") -> bool:
        """
        Returns if current out block should be loaded.
        The function should be called before the forward_timestep_embed call, to replace the input with the cached tensor.
        The function should be called before should_cache_unet_out call.
        Usage : if should_load_unet_out(index, len(unet.output_blocks)): h = self.get_cache()
        """
        return self.can_use_cache(cache_type) and index == total_out_blocks - self.cache_in_level - 1

    def _record_skip(self, kind: str, skipped: bool) -> None:
        """
        Increments the '<kind>_true' or '<kind>_false' counter in skip_statistics.
        """
        key = f"{kind}_true" if skipped else f"{kind}_false"
        self.skip_statistics[key] = self.skip_statistics.get(key, 0) + 1

    def cond_skip_unet_in(self, index, cache_type="last") -> bool:
        """
        Returns if current in block should be skipped, and records the decision in skip_statistics.
        The function should be called before the forward_timestep_embed call.
        Usage : if self.cond_skip_unet_in(index): continue
        """
        result = self.can_use_cache(cache_type) and index > self.cache_in_level
        self._record_skip('in', result)
        # The decision has to be returned, otherwise callers never skip anything.
        return result

    def cond_skip_unet_out(self, index, total_out_blocks, cache_type="last") -> bool:
        """
        Returns if current out block should be skipped, and records the decision in skip_statistics.
        The function should be called before should_load_unet_out and should_cache_unet_out call.
        Usage : if self.cond_skip_unet_out(index, len(unet.output_blocks)): continue
        """
        result = self.can_use_cache(cache_type) and index < total_out_blocks - self.cache_in_level - 1
        self._record_skip('out', result)
        return result

    def cond_skip_unet_middle(self, cache_type="last") -> bool:
        """
        Returns if middle block should be skipped, and records the decision in skip_statistics.
        Usage : if not self.cond_skip_unet_middle(): h = forward_timestep_embed(middle_block, ...)
        """
        result = self.can_use_cache(cache_type)
        self._record_skip('middle', result)
        return result

    def put_cache(self, h, cache_type="last"):
        """
        Registers cache for the current timestep and the following
        full_run_step_rate - 1 timesteps, and as the "last" entry.
        Usage : if self.should_cache_unet_out(index, len(unet.output_blocks)): self.put_cache(h)
        """
        # Validate before touching the cache so a failed call leaves no partial state.
        assert h is not None, "Cannot cache None"
        timestep = self.timestep_accumulator
        target_cache = self.caches[cache_type]
        target_cache["timestep"].add(timestep)
        # maybe move to cpu and load later for low vram?
        target_cache["last"] = h
        for offset in range(self.full_run_step_rate):
            # register for each step too
            target_cache[f"timestep_{timestep + offset}"] = h
        self.debug_print(f"DeepCache : put cache for timestep {timestep}")

    def debug_print(self, *args, **kwargs):
        """
        Prints the given arguments only when debug_mode is set.
        """
        if self.debug_mode:
            print(*args, **kwargs)

    def get_cache(self, cache_type="last"):
        """
        Returns the cached tensor for the current timestep and cache key, or None.
        None is returned when the session is not activated, when the current
        step is a full-run step, or when nothing was cached for this timestep.
        A successful lookup also refreshes the "last" entry.
        Usage : if self.should_load_unet_out(index, len(unet.output_blocks)): h = self.get_cache()
        """
        target_cache = self.caches[cache_type]
        if not self.should_activate():
            self.debug_print("DeepCache : Disabled")
            return None
        # should_activate() guarantees full_run_step_rate >= 1 and
        # timestep > cache_enable_step from here on.
        current_timestep = self.timestep_accumulator
        if current_timestep % self.full_run_step_rate == 0:
            self.debug_print("DeepCache : Step is divisible, running full forward")
            return None
        key = f"timestep_{current_timestep}"
        if key in target_cache:
            cached = target_cache[key]
            target_cache["last"] = cached  # update last
            self.debug_print(f"DeepCache : load cache for timestep {current_timestep}")
            return cached
        self.debug_print(f"DeepCache : cache for timestep {current_timestep} not found")
        return None

    def increment_timestep(self):
        """
        Increments the timestep accumulator. Should be called after each forward pass.
        The timestep only advances on every accumulate_every-th call.
        """
        self.total_accumulator += 1
        if self.total_accumulator % self.accumulate_every != 0:
            return
        self.timestep_accumulator += 1

    def reset_timestep(self):
        """
        Resets the timestep accumulator. Should be called after each generation / maybe before hires.fix
        """
        self.timestep_accumulator = 0

    def clear_cache(self):
        """
        Clears the cache, both counters and the skip statistics. Should be called after each generation.
        """
        self.caches.clear()
        self.timestep_accumulator = 0
        self.total_accumulator = 0
        self.skip_statistics.clear()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment