Gist by @takuma104, created June 2, 2023.
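"""Peak CUDA memory benchmark for Stable Diffusion v1.5: runs the four
combinations of xFormers memory-efficient attention (off/on) and LoRA
weights (off/on), verifies that the expected attention processor classes
are installed, saves a 2x2 grid of the generated images per configuration,
and prints one JSON line of memory statistics for each run."""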
import torch
import json
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.models.attention import Attention
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    XFormersAttnProcessor,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
)
from PIL import Image


def image_grid(imgs, rows=2, cols=2):
    # Paste the batch of images into a single rows x cols contact sheet.
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
def on_off(cond):
    return "ON" if cond else "OFF"


def print_memory_usage(width, height, batch, xformers, with_lora):
    # Report the peak CUDA allocation since the last reset as one JSON line.
    mem_bytes = torch.cuda.max_memory_allocated()
    mem_MB = int(mem_bytes / (10**6))
    stats = {
        "width": width,
        "height": height,
        "batch": batch,
        "xformers": on_off(xformers),
        "lora": on_off(with_lora),
        "mem_MB": mem_MB,
    }
    print(json.dumps(stats))


def check_attn_processor(root_module, klass):
    # Assert that every Attention module under root_module uses the
    # expected processor class.
    for _, module in root_module.named_modules():
        if isinstance(module, Attention):
            assert isinstance(module.processor, klass)
if __name__ == "__main__":
    prompt = "masterpiece, best quality, 1girl, at dusk"
    negative_prompt = (
        "(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), "
        "bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2), large breasts"
    )

    sd_model_id = "runwayml/stable-diffusion-v1-5"
    lora_weight_model_id = "sayakpaul/civitai-light-shadow-lora"
    lora_weight_name = "light_and_shadow.safetensors"

    for xformers in [False, True]:
        for batch in [4]:
            for width, height in [(512, 768)]:
                for with_lora in [False, True]:
                    # Measure peak memory per configuration from a clean slate.
                    torch.cuda.reset_peak_memory_stats()

                    pipe = StableDiffusionPipeline.from_pretrained(
                        sd_model_id, torch_dtype=torch.float16, safety_checker=None
                    ).to("cuda")
                    pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                        pipe.scheduler.config, use_karras_sigmas=True
                    )

                    # Toggle xFormers attention and verify the processor swap took effect.
                    if xformers:
                        pipe.enable_xformers_memory_efficient_attention()
                        check_attn_processor(pipe.unet, XFormersAttnProcessor)
                    else:
                        pipe.disable_xformers_memory_efficient_attention()
                        check_attn_processor(pipe.unet, AttnProcessor2_0)

                    # pipe.set_progress_bar_config(disable=True)

                    # Loading LoRA weights should swap in the LoRA-aware processors.
                    if with_lora:
                        pipe.load_lora_weights(lora_weight_model_id, weight_name=lora_weight_name)
                        if xformers:
                            check_attn_processor(pipe.unet, LoRAXFormersAttnProcessor)
                        else:
                            check_attn_processor(pipe.unet, LoRAAttnProcessor2_0)

                    # Fixed seed so all four configurations are comparable.
                    images = pipe(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        width=width,
                        height=height,
                        num_inference_steps=15,
                        num_images_per_prompt=batch,
                        generator=torch.manual_seed(0),
                    ).images
                    image_grid(images).save(
                        f"generated_xf-{on_off(xformers)}_lora-{on_off(with_lora)}.png"
                    )

                    print_memory_usage(width, height, batch, xformers, with_lora)
                    del pipe
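

# Optional post-processing sketch (hypothetical helper, never called by the
# benchmark itself): each configuration above prints one JSON line, so a
# saved stdout log can be tabulated like this. The "results.jsonl" file name
# is an assumption about how the output was captured.
def summarize(path="results.jsonl"):
    with open(path) as f:
        for line in f:
            rec = json.loads(line)
            print(f"xformers={rec['xformers']:>3}  lora={rec['lora']:>3}  "
                  f"peak={rec['mem_MB']} MB")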