
@deltheil
Created October 12, 2023 08:24
Self-Attention Guidance with Refiners: prerequisites
# Improving Sample Quality of Diffusion Models Using Self-Attention Guidance
# https://arxiv.org/abs/2210.00939
# https://github.com/SusungHong/Self-Attention-Guidance
# (1) Install (see also https://github.com/finegrain-ai/refiners#install)
git clone https://github.com/finegrain-ai/refiners.git
cd refiners
poetry install --all-extras
poetry run pip install --upgrade torch torchvision
# (2) Convert weights into Refiners format
poetry shell
python scripts/conversion/convert_transformers_clip_text_model.py --to clip_text.safetensors --half
python scripts/conversion/convert_diffusers_unet.py --to unet.safetensors --half
python scripts/conversion/convert_diffusers_autoencoder_kl.py --to lda.safetensors --half
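
# (3) Example usage: the converted SD 1.5 weights can then be loaded for inference
# with Self-Attention Guidance. The sketch below is illustrative only and assumes
# StableDiffusion_1 mirrors the SDXL interface shown in the comment below
# (compute_clip_text_embedding, set_self_attention_guidance, steps, decode_latents).

import torch

from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.fluxion.utils import manual_seed


device = "cuda"

sd15 = StableDiffusion_1(device=device, dtype=torch.float16)
sd15.clip_text_encoder.load_from_safetensors("clip_text.safetensors")
sd15.lda.load_from_safetensors("lda.safetensors")
sd15.unet.load_from_safetensors("unet.safetensors")

with torch.no_grad():
    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"

    # SD 1.5 returns a single embedding tensor (prompt and negative prompt stacked for CFG)
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    sd15.set_self_attention_guidance(enable=True, scale=0.75)

    manual_seed(seed=2)
    # 512x512 output -> 4x64x64 latent (the VAE downscales by 8)
    x = torch.randn(1, 4, 64, 64, device=device, dtype=torch.float16)

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.decode_latents(x=x)

predicted_image.save("output_sd15.png")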
@deltheil (Author)

Alternatively, for SDXL, convert the weights as follows:

python scripts/conversion/convert_transformers_clip_text_model.py --from "stabilityai/stable-diffusion-xl-base-1.0" --subfolder2 text_encoder_2 --to clip_text_xl.safetensors --half
python scripts/conversion/convert_diffusers_unet.py --from "stabilityai/stable-diffusion-xl-base-1.0" --to unet_xl.safetensors --half
python scripts/conversion/convert_diffusers_autoencoder_kl.py --from "madebyollin/sdxl-vae-fp16-fix" --subfolder "" --to lda_xl.safetensors --half

Then run inference with Self-Attention Guidance enabled:

import torch

from refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL
from refiners.fluxion.utils import manual_seed


device = "cuda"

sdxl = StableDiffusion_XL(device=device, dtype=torch.float16)
sdxl.clip_text_encoder.load_from_safetensors("clip_text_xl.safetensors")
sdxl.lda.load_from_safetensors("lda_xl.safetensors")
sdxl.unet.load_from_safetensors("unet_xl.safetensors")

with torch.no_grad():
    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"

    # Embeds both the prompt and the negative prompt (stacked for classifier-free guidance)
    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text=prompt, negative_text=negative_prompt
    )
    # Default SDXL size/crop conditioning (original size, crop coordinates, target size)
    time_ids = sdxl.default_time_ids

    # Enable Self-Attention Guidance (https://arxiv.org/abs/2210.00939); scale controls its strength
    sdxl.set_self_attention_guidance(enable=True, scale=0.75)

    manual_seed(seed=2)
    # SDXL generates 1024x1024 images from a 4x128x128 latent (the VAE downscales by 8)
    x = torch.randn(1, 4, 128, 128, device=device, dtype=torch.float16)

    # Denoising loop; condition_scale is the classifier-free guidance scale
    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
            condition_scale=5,
        )
    predicted_image = sdxl.lda.decode_latents(x=x)

predicted_image.save("output.png")
print("done: see output.png")
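
To gauge the effect of SAG on the output, the same loop can be re-run with the same seed but guidance disabled, using the same call as above:

sdxl.set_self_attention_guidance(enable=False)  # then re-run the denoising loop with manual_seed(seed=2)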
