import torch
from einops import rearrange, repeat

def block_to_key(block):
    if block[0] == "input":
        return "in" + str(block[1])
    elif block[0] == "output":
        return "out" + str(block[1])
    elif block[0] == "middle":
        return "mid"
    else:
        raise ValueError("Invalid block type")

def str_to_list(s):
    return [int(x.strip()) for x in s.split(",")]
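
# Quick sketch of the helpers above (illustrative values, not from the gist):
#   block_to_key(("input", 4))  -> "in4"
#   block_to_key(("middle", 0)) -> "mid"
#   str_to_list("0, 1, 0, 1")   -> [0, 1, 0, 1]
# A repeat string such as "0,0,1" runs transformer block 0 twice and then
# block 1 once inside the matching SpatialTransformer.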

class TransformerRepeat:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", ),
                "start": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1, "display": "number"}),
                "end": ("INT", {"default": 1000, "min": 0, "max": 1000, "step": 1, "display": "number"}),
                # Defaults appear to match SDXL's per-block transformer depths
                # (2 blocks for in4/in5 and out3-out5, 10 for the rest).
                "in4": ("STRING", {"default": "0,1"}),
                "in5": ("STRING", {"default": "0,1"}),
                "in7": ("STRING", {"default": "0,1,2,3,4,5,6,7,8,9"}),
                "in8": ("STRING", {"default": "0,1,2,3,4,5,6,7,8,9"}),
                "mid": ("STRING", {"default": "0,1,2,3,4,5,6,7,8,9"}),
                "out0": ("STRING", {"default": "0,1,2,3,4,5,6,7,8,9"}),
                "out1": ("STRING", {"default": "0,1,2,3,4,5,6,7,8,9"}),
                "out2": ("STRING", {"default": "0,1,2,3,4,5,6,7,8,9"}),
                "out3": ("STRING", {"default": "0,1"}),
                "out4": ("STRING", {"default": "0,1"}),
                "out5": ("STRING", {"default": "0,1"}),
            },
        }

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "apply"
    CATEGORY = "_for_testing"
    def apply(self, model, start, end, **kwargs):
        new_model = model.clone()
        self.start = start
        self.end = end
        self.org_forwards = {}
        self.repeat_dic = {key: str_to_list(kwargs[key]) for key in kwargs}

        # Patch applied around the UNet call.
        def apply_model(model_function, kwargs):
            sigmas = kwargs["timestep"]
            t = new_model.model.model_sampling.timestep(sigmas)
            # t counts down from ~1000 as sampling progresses, so the patch is
            # active only while 1000 - end <= t <= 1000 - start.
            if t[0] < (1000 - end) or t[0] > (1000 - start):
                return model_function(kwargs["input"], kwargs["timestep"], **kwargs["c"])
            self.replace_transformer(new_model)
            retval = model_function(kwargs["input"], kwargs["timestep"], **kwargs["c"])
            self.restore_transformer(new_model)
            return retval

        new_model.set_model_unet_function_wrapper(apply_model)
        return (new_model, )
    def replace_transformer(self, model):
        for name, module in model.model.diffusion_model.named_modules():
            if module.__class__.__name__ == 'SpatialTransformer':
                self.org_forwards[name] = module.forward
                module.forward = self.forward_hooker(module, self.org_forwards[name])

    def restore_transformer(self, model):
        for name, module in model.model.diffusion_model.named_modules():
            if name in self.org_forwards:
                module.forward = self.org_forwards[name]
        self.org_forwards = {}
    def forward_hooker(self, module, forward):
        def forward_hook(x, context=None, transformer_options={}):
            # note: if no context is given, cross-attention defaults to self-attention
            if not isinstance(context, list):
                context = [context] * len(module.transformer_blocks)
            b, c, h, w = x.shape
            x_in = x
            x = module.norm(x)
            if not module.use_linear:
                x = module.proj_in(x)
            x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
            if module.use_linear:
                x = module.proj_in(x)
            # Run the transformer blocks in the user-specified order; indices
            # may repeat or skip blocks, e.g. [0, 0, 1] runs block 0 twice.
            for i in self.repeat_dic[block_to_key(transformer_options["block"])]:
                block = module.transformer_blocks[i]
                transformer_options["block_index"] = i
                x = block(x, context=context[i], transformer_options=transformer_options)
            if module.use_linear:
                x = module.proj_out(x)
            x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
            if not module.use_linear:
                x = module.proj_out(x)
            return x + x_in
        return forward_hook

NODE_CLASS_MAPPINGS = {
    "TransformerRepeat": TransformerRepeat,
}
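
# Minimal sanity check of the pure helpers (a sketch; runs without ComfyUI).
if __name__ == "__main__":
    assert str_to_list("0, 1, 2") == [0, 1, 2]
    assert block_to_key(("input", 4)) == "in4"
    assert block_to_key(("middle", 0)) == "mid"
    assert block_to_key(("output", 3)) == "out3"
    print("helper checks passed")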