Experimental MSW-MSA attention node for ComfyUI
# By https://github.com/blepping
# License: Apache 2.0
# Experimental MSW-MSA attention implementation ported to ComfyUI from: https://github.com/megvii-research/HiDiffusion
# Lightly tested, may or may not actually work correctly.
#
# *** NOTE ***
# This is *NOT* a full implementation of HiDiffusion, only the MSW-MSA attention component, which is mainly
# for performance. By itself it will not enable generating at higher resolution than the model normally supports.
#
# Usage: Copy this file into your custom_nodes directory and connect the ApplyMSWMSAAttention node.
# The default block values are for SD1.5; for SDXL I think you'd use: input=4,5 output=3,4,5
# SDXL note: This doesn't seem to help much for performance. Even at high resolution and using a slow sampler
# I didn't get more than a 5-10% speedup.
# You may optionally set a start and end time range.
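#
# MSW-MSA replaces the model's global self-attention in the selected blocks with
# Swin-style attention over large, randomly shifted windows, which reduces the
# cost of self-attention, especially at high resolutions.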
import torch


class ApplyMSWMSAAttention:
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "model_patches"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input_blocks": ("STRING", {"default": "0,1"}),
                "middle_blocks": ("STRING", {"default": ""}),
                "output_blocks": ("STRING", {"default": "9,10,11"}),
                "time_mode": (
                    (
                        "percent",
                        "timestep",
                        "sigma",
                    ),
                ),
                "start_time": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 999.0}),
                "end_time": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 999.0}),
                "model": ("MODEL",),
            },
        }
    # reference: https://github.com/microsoft/Swin-Transformer
    # Window functions adapted from https://github.com/megvii-research/HiDiffusion
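    # window_partition reshapes flattened attention tokens of shape
    # (batch, height * width, channels) into non-overlapping windows of
    # window_size, optionally rolling the feature map first so the windows are
    # shifted, and returns shape (batch * num_windows, window_h * window_w, channels).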
    @staticmethod
    def window_partition(x, window_size, shift_size, height, width) -> torch.Tensor:
        batch, _features, channels = x.shape
        x = x.view(batch, height, width, channels)
        if not isinstance(shift_size, (list, tuple)):
            shift_size = (shift_size,) * 2
        if shift_size[0] + shift_size[1] > 0:
            x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
        x = x.view(
            batch,
            height // window_size[0],
            window_size[0],
            width // window_size[1],
            window_size[1],
            channels,
        )
        windows = (
            x.permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(-1, window_size[0], window_size[1], channels)
        )
        return windows.view(-1, window_size[0] * window_size[1], channels)
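
    # window_reverse is the inverse of window_partition: it stitches the per-window
    # tokens back into a (batch, height * width, channels) tensor and rolls the
    # feature map back by shift_size so positions line up again.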
    @staticmethod
    def window_reverse(windows, window_size, shift_size, height, width) -> torch.Tensor:
        batch, features, channels = windows.shape
        windows = windows.view(-1, window_size[0], window_size[1], channels)
        batch = int(
            windows.shape[0] / (height * width / window_size[0] / window_size[1]),
        )
        x = windows.view(
            batch,
            height // window_size[0],
            width // window_size[1],
            window_size[0],
            window_size[1],
            -1,
        )
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, height, width, -1)
        if not isinstance(shift_size, (list, tuple)):
            shift_size = (shift_size,) * 2
        if shift_size[0] + shift_size[1] > 0:
            x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
        return x.view(batch, height * width, channels)
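
    # parse_blocks turns a comma-separated string like "0,1" into a set of
    # (block_name, block_index) pairs, e.g. {("input", 0), ("input", 1)}.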
    @classmethod
    def parse_blocks(cls, name, s):
        vals = (rawval.strip() for rawval in s.split(","))
        return {(name, int(val.strip())) for val in vals if val}
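
    # patch clones the model and registers attn1 input/output patches that apply
    # MSW-MSA windowed attention on the selected blocks within the time range.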
    def patch(
        self,
        model,
        input_blocks,
        middle_blocks,
        output_blocks,
        time_mode,
        start_time,
        end_time,
    ):
        use_blocks = self.parse_blocks("input", input_blocks)
        use_blocks |= self.parse_blocks("middle", middle_blocks)
        use_blocks |= self.parse_blocks("output", output_blocks)

        window_args = None
        model = model.clone()
        ms = model.get_model_object("model_sampling")
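
        # Build a check_time(sigma) closure for the selected time mode. Sampling
        # proceeds from high sigma to low sigma, so start_time marks the earlier
        # part of sampling and end_time the later part.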
        match time_mode:
            case "sigma":
                def check_time(sigma):
                    return sigma <= start_time and sigma >= end_time
            case "percent":
                if start_time > 1.0 or start_time < 0.0:
                    raise ValueError(
                        "invalid value for start percent",
                    )
                if end_time > 1.0 or end_time < 0.0:
                    raise ValueError(
                        "invalid value for end percent",
                    )
                start_sigma = ms.percent_to_sigma(start_time)
                end_sigma = ms.percent_to_sigma(end_time)

                def check_time(sigma):
                    return sigma <= start_sigma and sigma >= end_sigma
            case "timestep":
                def check_time(sigma):
                    timestep = ms.timestep(sigma)
                    return timestep <= start_time and timestep >= end_time
            case _:
                raise ValueError("invalid time mode")
        def attn1_patch(n, context_attn1, value_attn1, extra_options):
            nonlocal window_args
            window_args = None
            block = extra_options["block"]
            if block not in use_blocks or not check_time(extra_options["sigmas"].max()):
                return n, context_attn1, value_attn1
            # MSW-MSA
            batch, features, channels = n.shape
            orig_height, orig_width = extra_options["original_shape"][-2:]
            downsample_ratio = int(
                ((orig_height * orig_width) // features) ** 0.5,
            )
            height, width = (
                orig_height // downsample_ratio,
                orig_width // downsample_ratio,
            )
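            # Window size is half the feature map in each dimension (four windows
            # total). One of four shift offsets is picked at random for each
            # attention call, per HiDiffusion's MSW-MSA, so window boundaries move
            # between calls.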
            window_size = (height // 2, width // 2)
            match int(torch.rand(1).item() * 4):
                case 0:
                    shift_size = (0, 0)
                case 1:
                    shift_size = (window_size[0] // 4, window_size[1] // 4)
                case 2:
                    shift_size = (window_size[0] // 4 * 2, window_size[1] // 4 * 2)
                case _:
                    shift_size = (window_size[0] // 4 * 3, window_size[1] // 4 * 3)
            window_args = (window_size, shift_size, height, width)
            result = self.window_partition(n, *window_args)
            return (result,) * 3
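
        # attn1_output_patch undoes the window partition on the attention output
        # (when one was applied) so later layers see the usual
        # (batch, height * width, channels) layout.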
        def attn1_output_patch(n, _extra_options):
            nonlocal window_args
            if window_args is None:
                return n
            result = self.window_reverse(n, *window_args)
            window_args = None
            return result

        model.set_model_attn1_patch(attn1_patch)
        model.set_model_attn1_output_patch(attn1_output_patch)
        return (model,)
NODE_CLASS_MAPPINGS = {
    "ApplyMSWMSAAttention": ApplyMSWMSAAttention,
}