Created
December 14, 2023 00:13
-
-
Save ttulttul/2b09f0f14bb35639ada7ed37b1f0428d to your computer and use it in GitHub Desktop.
A batch unsampler node for ComfyUI
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging as logger

import torch

import comfy.model_management
import comfy.sample
import comfy.samplers
import comfy.utils
class BatchUnsampler:
    """ComfyUI node that "unsamples" a latent and returns the intermediate steps as a batch.

    The sampler's sigma schedule is flipped so that each step moves the latent
    toward noise instead of away from it. A sampler callback captures the
    intermediate latent at each step, and the captured latents are
    concatenated along dim 0 into a single LATENT output.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe this node's required inputs for ComfyUI."""
        return {"required":
                    {"model": ("MODEL",),
                     "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                     "end_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
                     "step_increment": ("INT", {"default": 1, "min": 1, "max": 10000}),
                     "cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}),
                     "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
                     "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
                     "normalize": (["disable", "enable"], ),
                     "positive": ("CONDITIONING", ),
                     "negative": ("CONDITIONING", ),
                     "latent_image": ("LATENT", ),
                     }}

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("latent_batch",)
    FUNCTION = "batch_unsampler"
    CATEGORY = "tests"

    @staticmethod
    def _normalize_latent(x):
        """Return a copy of *x* shifted to zero mean and scaled to unit std.

        Mean and std are taken over the whole tensor. If the std is not
        positive (e.g. an all-zero latent, or NaN for a single element), the
        tensor is only mean-centred, so no division by zero occurs.
        """
        centred = x - x.mean()
        std = centred.std()
        if std > 0:
            centred = centred / std
        return centred

    def batch_unsampler(self, model, cfg, sampler_name, steps, end_at_step, step_increment, scheduler, normalize, positive, negative, latent_image):
        """Unsample ``latent_image`` and return the captured intermediate latents as one batch.

        Args:
            model: ComfyUI MODEL input; ``model.model`` is handed to the KSampler.
            cfg: classifier-free guidance scale.
            sampler_name: one of ``comfy.samplers.KSampler.SAMPLERS``.
            steps: number of steps in the (reversed) schedule.
            end_at_step: steps to drop from the end; clamped to ``steps - 1``.
                The sampler runs up to step ``steps - end_at_step``.
            step_increment: keep every ``step_increment``-th captured step
                (1 keeps all of them).
            scheduler: one of ``comfy.samplers.KSampler.SCHEDULERS``.
            normalize: ``"enable"`` to mean-centre and unit-scale each returned latent.
            positive: positive CONDITIONING.
            negative: negative CONDITIONING.
            latent_image: LATENT dict; ``"samples"`` holds the tensor and an
                optional ``"noise_mask"`` restricts where sampling applies.

        Returns:
            A 1-tuple holding a LATENT dict whose ``"samples"`` tensor is the
            captured latents concatenated along dim 0 (on the CPU). If the
            sampler never invoked the callback, the input LATENT dict is
            returned unchanged.
        """
        normalize = normalize == "enable"
        device = comfy.model_management.get_torch_device()
        latent = latent_image
        samples_in = latent["samples"]

        # end_at_step counts steps to skip at the end; convert it to the
        # absolute last_step index the KSampler expects.
        last_step = steps - min(end_at_step, steps - 1)

        # All-zero noise: unsampling starts from the input latent itself.
        noise = torch.zeros(samples_in.size(), dtype=samples_in.dtype, layout=samples_in.layout, device="cpu")

        noise_mask = None
        if "noise_mask" in latent:
            # prepare_mask indexes its second argument for dimensions, so it
            # must receive the shape rather than the tensor.
            noise_mask = comfy.sample.prepare_mask(latent["noise_mask"], noise.shape, device)

        real_model = model.model
        noise = noise.to(device)
        samples_in = samples_in.to(device)

        positive = comfy.sample.convert_cond(positive)
        negative = comfy.sample.convert_cond(negative)
        models, inference_memory = comfy.sample.get_additional_models(positive, negative, model.model_dtype())

        captured = []
        try:
            comfy.model_management.load_models_gpu([model] + models, model.memory_required(noise.shape) + inference_memory)

            sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=1.0, model_options=model.model_options)
            # Flip the sigmas (the sampling schedule) so the sampler
            # "unsamples" the latent, adding noise rather than removing it.
            # The +0.0001 presumably keeps the first sigma (0 in a standard
            # schedule) from being exactly zero.
            sigmas = sampler.sigmas.flip(0) + 0.0001
            pbar = comfy.utils.ProgressBar(steps)

            def callback(step, x0, x, total_steps):
                # Copy the intermediate latent to the CPU so a long schedule
                # does not pin one latent per step in GPU memory, and so later
                # sampler steps cannot alias the stored tensor.
                # NOTE(review): x comes from inside the sampler and may be in
                # the model's internal latent scaling rather than the input
                # LATENT's space -- confirm before decoding these latents.
                if step % step_increment == 0:
                    captured.append(x.detach().to("cpu", copy=True))
                pbar.update_absolute(step + 1, total_steps)

            # Some samplers do not seem to call the callback at all; that case
            # is handled below by returning the input latent.
            sampler.sample(noise, positive, negative, cfg=cfg, latent_image=samples_in, force_full_denoise=False, denoise_mask=noise_mask, sigmas=sigmas, start_step=0, last_step=last_step, callback=callback)
        finally:
            # Always release the additional models, even if loading/sampling failed.
            comfy.sample.cleanup_additional_models(models)

        if not captured:
            # Return the input LATENT dict (not its bare tensor) so downstream
            # nodes still receive a valid LATENT.
            logger.warning("BatchUnsampler: No latents were produced.")
            return (latent,)

        if normalize:
            # Technically this doesn't "normalize": unsampling is not
            # guaranteed to end at the std the schedule implies, so each
            # latent is rescaled explicitly.
            captured = [self._normalize_latent(x) for x in captured]

        # ComfyUI LATENTs are dicts carrying the batch under "samples".
        return ({"samples": torch.cat(captured)},)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment