Experimental MSW-MSA attention node for ComfyUI
# By https://github.com/blepping
# License: Apache 2.0
# Experimental MSW-MSA attention implementation ported to ComfyUI from: https://github.com/megvii-research/HiDiffusion
# Lightly tested, may or may not actually work correctly.
#
# *** NOTE ***
# This is *NOT* a full implementation of HiDiffusion, only the MSW-MSA attention component which is mainly
# for performance. By itself it will not enable generating at higher resolution than the model normally supports.
#
# Usage: Copy into custom_nodes directory, connect the ApplyMSWMSAAttention node.
# The default block values are for SD1.5; for SDXL I think you'd use: input=4,5 output=3,4,5
# SDXL note: This doesn't seem to help much with performance on SDXL. Even at high resolution
# and using a slow sampler I didn't see more than a 5-10% speedup.
# You may optionally set a start and end time range.
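#
# Example settings (illustrative, not authoritative): for SDXL you might try
# input_blocks="4,5", output_blocks="3,4,5"; to restrict the patch to the first
# half of sampling, time_mode="percent", start_time=0.0, end_time=0.5.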
import torch


class ApplyMSWMSAAttention:
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "model_patches"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input_blocks": ("STRING", {"default": "0,1"}),
                "middle_blocks": ("STRING", {"default": ""}),
                "output_blocks": ("STRING", {"default": "9,10,11"}),
                "time_mode": (
                    (
                        "percent",
                        "timestep",
                        "sigma",
                    ),
                ),
                "start_time": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 999.0}),
                "end_time": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 999.0}),
                "model": ("MODEL",),
            },
        }

    # reference: https://github.com/microsoft/Swin-Transformer
    # Window functions adapted from https://github.com/megvii-research/HiDiffusion
    @staticmethod
    def window_partition(x, window_size, shift_size, height, width) -> torch.Tensor:
        # Splits a (batch, height * width, channels) tensor into shifted windows,
        # flattening the windows into the batch dimension.
        batch, _features, channels = x.shape
        x = x.view(batch, height, width, channels)
        if not isinstance(shift_size, (list, tuple)):
            shift_size = (shift_size,) * 2
        if shift_size[0] + shift_size[1] > 0:
            x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
        x = x.view(
            batch,
            height // window_size[0],
            window_size[0],
            width // window_size[1],
            window_size[1],
            channels,
        )
        windows = (
            x.permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(-1, window_size[0], window_size[1], channels)
        )
        return windows.view(-1, window_size[0] * window_size[1], channels)

    @staticmethod
    def window_reverse(windows, window_size, shift_size, height, width) -> torch.Tensor:
        # Inverse of window_partition: merges the windows back into a
        # (batch, height * width, channels) tensor and undoes the shift.
        _batch, _features, channels = windows.shape
        windows = windows.view(-1, window_size[0], window_size[1], channels)
        batch = int(
            windows.shape[0] / (height * width / window_size[0] / window_size[1]),
        )
        x = windows.view(
            batch,
            height // window_size[0],
            width // window_size[1],
            window_size[0],
            window_size[1],
            -1,
        )
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, height, width, -1)
        if not isinstance(shift_size, (list, tuple)):
            shift_size = (shift_size,) * 2
        if shift_size[0] + shift_size[1] > 0:
            x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
        return x.view(batch, height * width, channels)

    @classmethod
    def parse_blocks(cls, name, s):
        # Parses a comma-separated block list: parse_blocks("input", "0,1")
        # returns {("input", 0), ("input", 1)}.
        vals = (rawval.strip() for rawval in s.split(","))
        return {(name, int(val)) for val in vals if val}

    def patch(
        self,
        model,
        input_blocks,
        middle_blocks,
        output_blocks,
        time_mode,
        start_time,
        end_time,
    ):
        use_blocks = self.parse_blocks("input", input_blocks)
        use_blocks |= self.parse_blocks("middle", middle_blocks)
        use_blocks |= self.parse_blocks("output", output_blocks)
        window_args = None
        model = model.clone()
        ms = model.get_model_object("model_sampling")

        # Build a check_time(sigma) predicate for the selected time mode.
        # Sigmas (and timesteps) decrease as sampling progresses, so the start
        # of the range corresponds to the higher value.
        match time_mode:
            case "sigma":

                def check_time(sigma):
                    return sigma <= start_time and sigma >= end_time

            case "percent":
                if start_time > 1.0 or start_time < 0.0:
                    raise ValueError("invalid value for start percent")
                if end_time > 1.0 or end_time < 0.0:
                    raise ValueError("invalid value for end percent")
                start_sigma = ms.percent_to_sigma(start_time)
                end_sigma = ms.percent_to_sigma(end_time)

                def check_time(sigma):
                    return sigma <= start_sigma and sigma >= end_sigma

            case "timestep":

                def check_time(sigma):
                    timestep = ms.timestep(sigma)
                    return timestep <= start_time and timestep >= end_time

            case _:
                raise ValueError("invalid time mode")

        def attn1_patch(n, context_attn1, value_attn1, extra_options):
            nonlocal window_args
            window_args = None
            block = extra_options["block"]
            if block not in use_blocks or not check_time(extra_options["sigmas"].max()):
                return n, context_attn1, value_attn1
            # MSW-MSA: split the flattened latent into a 2x2 grid of windows,
            # shifted by a randomly chosen offset so the window boundaries vary
            # between attention calls.
            _batch, features, _channels = n.shape
            orig_height, orig_width = extra_options["original_shape"][-2:]
            downsample_ratio = int(
                ((orig_height * orig_width) // features) ** 0.5,
            )
            height, width = (
                orig_height // downsample_ratio,
                orig_width // downsample_ratio,
            )
            window_size = (height // 2, width // 2)
            match int(torch.rand(1).item() * 4):
                case 0:
                    shift_size = (0, 0)
                case 1:
                    shift_size = (window_size[0] // 4, window_size[1] // 4)
                case 2:
                    shift_size = (window_size[0] // 4 * 2, window_size[1] // 4 * 2)
                case _:
                    shift_size = (window_size[0] // 4 * 3, window_size[1] // 4 * 3)
            window_args = (window_size, shift_size, height, width)
            result = self.window_partition(n, *window_args)
            return (result,) * 3

        def attn1_output_patch(n, _extra_options):
            nonlocal window_args
            if window_args is None:
                return n
            # Undo the window partition after self-attention has run.
            result = self.window_reverse(n, *window_args)
            window_args = None
            return result

        model.set_model_attn1_patch(attn1_patch)
        model.set_model_attn1_output_patch(attn1_output_patch)
        return (model,)


NODE_CLASS_MAPPINGS = {
    "ApplyMSWMSAAttention": ApplyMSWMSAAttention,
}
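

# Minimal round-trip sanity check for the window helpers above (an optional,
# illustrative sketch, not part of the node). Running this file directly should
# confirm that window_reverse undoes window_partition when the latent height
# and width divide evenly into the window size.
if __name__ == "__main__":
    test_height, test_width, test_channels = 32, 48, 4
    test_x = torch.randn(2, test_height * test_width, test_channels)
    test_window = (test_height // 2, test_width // 2)
    test_shift = (test_window[0] // 4, test_window[1] // 4)
    windows = ApplyMSWMSAAttention.window_partition(
        test_x, test_window, test_shift, test_height, test_width,
    )
    restored = ApplyMSWMSAAttention.window_reverse(
        windows, test_window, test_shift, test_height, test_width,
    )
    assert torch.allclose(test_x, restored), "window round-trip failed"
    print("window_partition/window_reverse round-trip OK")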