Skip to content

Instantly share code, notes, and snippets.

@morrisalp
Created January 14, 2024 18:24
Show Gist options
  • Save morrisalp/2c8d150b6187a3bf50b4a89695da78e1 to your computer and use it in GitHub Desktop.
Save morrisalp/2c8d150b6187a3bf50b4a89695da78e1 to your computer and use it in GitHub Desktop.
Diffusers SDXL pipeline with gradients enabled (bypasses the @torch.no_grad decorator on __call__), tested with diffusers v0.25.0. Call the pipeline with output_type="latent" to get latents that carry gradients.
from diffusers import StableDiffusionXLPipeline
class CustomPipeline(StableDiffusionXLPipeline):
    """SDXL pipeline whose ``__call__`` runs with autograd enabled.

    The parent ``StableDiffusionXLPipeline.__call__`` is decorated with
    ``@torch.no_grad``; this subclass invokes the function underneath that
    decorator so the computation graph is kept. Call it with
    ``output_type="latent"`` to get latents that carry gradients. For other
    output types the image is detached before post-processing.
    """

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load the pipeline and patch its image processor to detach inputs.

        All arguments are forwarded unchanged to the parent ``from_pretrained``.

        Returns:
            The loaded pipeline, with ``image_processor.postprocess`` replaced
            by a wrapper that calls ``.detach()`` on the image first.

        Raises:
            ValueError: if the loaded pipeline has a watermarker attached
                (``pipe.watermark is not None``); watermarking is not
                supported on the grad-enabled path.
        """
        pipe = super().from_pretrained(*args, **kwargs)
        # An explicit raise (not `assert`) so the check survives `python -O`.
        if pipe.watermark is not None:
            raise ValueError(
                "watermarking is currently not supported by CustomPipeline; "
                "load it with add_watermarker=False"
            )

        processor = pipe.image_processor
        # Bound method captured once; the closure holds only the processor,
        # not the whole pipeline.
        original_postprocess = processor.postprocess

        def postprocess_no_grad(image, *args, **kwargs):
            # Non-latent outputs are converted away from torch tensors during
            # post-processing, which cannot be done on a tensor that requires
            # grad, so detach it first.
            return original_postprocess(image.detach(), *args, **kwargs)

        processor.postprocess = postprocess_no_grad
        return pipe

    def __call__(self, *args, **kwargs):
        """Run SDXL generation without the parent's ``@torch.no_grad``.

        All arguments are forwarded to ``StableDiffusionXLPipeline.__call__``.
        """
        parent_call = super().__call__.__func__
        # Peel off exactly one decorator layer (the @torch.no_grad wrapper,
        # which records the original function in __wrapped__). If the parent
        # is ever undecorated, fall back to calling it directly instead of
        # failing with AttributeError.
        undecorated = getattr(parent_call, "__wrapped__", parent_call)
        return undecorated(self, *args, **kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment