@laksjdjf
Created February 28, 2024 13:34
import torch


class VisualStylePrompting:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ("MODEL",),
            "reference": ("LATENT",),
            "depth": ("INT", {"default": 0, "min": -1, "max": 12}),
            "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
            "start_step": ("FLOAT", {"default": 0, "min": 0, "max": 1, "step": 0.01}),
            "end_step": ("FLOAT", {"default": 1, "min": 0, "max": 1, "step": 0.01}),
        }}

    RETURN_TYPES = ("MODEL", "LATENT")
    FUNCTION = "reference_only"
    CATEGORY = "loaders"

    def reference_only(self, model, reference, depth, batch_size, start_step, end_step):
        model_reference = model.clone()

        # Convert the start/end fractions of the sampling schedule to sigma values
        # (a lower percent means an earlier step and therefore a higher sigma).
        start_sigma = model_reference.model.model_sampling.percent_to_sigma(start_step)
        end_sigma = model_reference.model.model_sampling.percent_to_sigma(end_step)

        # Build an empty latent batch with the same spatial size as the reference.
        size_latent = list(reference["samples"].shape)
        size_latent[0] = batch_size
        latent = {}
        latent["samples"] = torch.zeros(size_latent)

        self.depth = depth
        # SDXL UNets carry a label embedding; use its presence to tell SDXL apart from SD1.x/2.x.
        self.sdxl = hasattr(model_reference.model.diffusion_model, "label_emb")
        self.num_blocks = 8 if self.sdxl else 11

        def reference_apply(q, k, v, extra_options):
            block_name, block_id = extra_options["block"]
            if block_name == "output":
                block_number = self.num_blocks - block_id
            else:
                # Input and middle blocks get a large number so they are never patched.
                block_number = 100

            q = q.clone()
            k = k.clone()
            v = v.clone()

            sigma = extra_options["sigmas"][0].item()
            # Share keys/values only inside the configured step window and up to the requested depth.
            if end_sigma <= sigma <= start_sigma and block_number <= self.depth:
                # Every sample in the batch attends to the reference sample's keys and values.
                k[1:] = k[:1]
                v[1:] = v[:1]
            return q, k, v

        model_reference.set_model_attn1_patch(reference_apply)

        # Prepend the reference latent to the empty batch; index 0 is the reference.
        out_latent = torch.cat((reference["samples"], latent["samples"]))

        if "noise_mask" in latent:
            mask = latent["noise_mask"]
        else:
            mask = torch.ones(out_latent.shape[2:], dtype=torch.float32, device="cpu")
        if len(mask.shape) < 3:
            mask = mask.unsqueeze(0)
        if mask.shape[0] < latent["samples"].shape[0]:
            mask = mask.repeat(latent["samples"].shape[0], 1, 1)

        # Zero mask on the reference slot keeps it fixed during sampling; the other samples are fully denoised.
        out_mask = torch.zeros((1, mask.shape[1], mask.shape[2]), dtype=torch.float32, device="cpu")
        return (model_reference, {"samples": out_latent, "noise_mask": torch.cat((out_mask, mask))})

NODE_CLASS_MAPPINGS = {
    "VisualStylePrompting": VisualStylePrompting,
}
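
A minimal standalone sketch of the key/value sharing used in reference_apply, with hypothetical tensor shapes: assigning k[:1] into k[1:] (and likewise for v) makes every sample in the batch attend to the reference sample's keys and values in self-attention.

import torch

# Hypothetical shapes: batch of 3 (index 0 = reference), 4 tokens, 8 channels.
q = torch.randn(3, 4, 8)
k = torch.randn(3, 4, 8)
v = torch.randn(3, 4, 8)

# The sharing step from reference_apply: all samples now use the reference's keys/values.
k[1:] = k[:1]
v[1:] = v[:1]

# Plain scaled dot-product self-attention over the shared keys and values.
scale = q.shape[-1] ** -0.5
attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
out = attn @ v

print(torch.equal(k[1], k[0]))  # True: keys are shared across the batch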