

@takuma104
Last active July 29, 2023 20:09
Generating (almost) reproducible pictures using Diffusers with xFormers

This gist is out of date. Please see this gist instead.



[Figure: fig_xformers_flash_attention_seed_7]

When using Diffusers with xFormers enabled for Stable Diffusion, I noticed that the generated pictures differ slightly from run to run (like a spot-the-difference puzzle), even when the seed value is fixed.

According to the xFormers team, the default backend, Cutlass, does not guarantee deterministic behavior. (Incidentally, a patch has since been merged into the main branch that adds a note about this non-deterministic behavior to the documentation, plus a warning when torch.use_deterministic_algorithms is enabled.)

In the same thread, I was told that another xFormers backend, Flash Attention, behaves deterministically. So I wrote a patch that prioritizes the Flash Attention backend.
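Why would an attention kernel be non-deterministic at all? One common cause (a general assumption here, not something stated in the xFormers discussion) is that floating-point addition is not associative: if a kernel splits a reduction across parallel threads and combines the partial sums in whatever order they finish, the rounding can change from run to run. The pure-Python sketch below shows the order dependence with no GPU involved; the numbers are made up to make the rounding visible.

```python
import math


def accumulate(values):
    """Sum values left to right, the way a single worker would."""
    total = 0.0
    for v in values:
        total += v
    return total


# The same 16 numbers; only the order of accumulation differs.
values = [1e16, 1.0, -1e16, 1.0] * 4

original_order = accumulate(values)               # each 1.0 added to 1e16 is rounded away
small_first = accumulate(sorted(values, key=abs))  # the eight 1.0s are summed first
exact = math.fsum(values)                          # correctly rounded reference sum

print(original_order, small_first, exact)  # 1.0 8.0 8.0
```

A deterministic kernel fixes the reduction order, so it always lands on the same result, even if that result is not the exactly rounded one.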

The figure above shows the result of the test. The seed value is fixed and generation is repeated. In each row, the leftmost image is run #0, and the remaining images show the difference from run #0 (the closer to black, the smaller the difference). With the default (Cutlass) backend the cat's paw changes between runs, but when Flash Attention is prioritized, almost the same image is generated each time.

There is a reason I say "almost." Flash Attention can be used for U-Net inference but not for VAE inference, because the shape of the VAE's attention inputs is not supported by Flash Attention. For this reason, I changed the code to fall back to the existing Cutlass backend whenever Flash Attention cannot be used. Using Cutlass during VAE inference still results in a very small difference (about 1 level out of a maximum luminance of 255), which is practically invisible to humans.
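A difference of about 1 level out of 255 is what separates "almost" from "exactly" here. One way to make that distinction explicit, sketched below in plain Python, is to report the maximum per-pixel difference, the share of pixels that changed, and whether every pixel stays within a tolerance. The helper name `drift_summary`, the tolerance of 1, and the four-pixel sample data are made up for illustration; images are assumed to be flat lists of 8-bit values.

```python
def drift_summary(ref, res, tol=1):
    """Summarize how two 8-bit images (flat pixel lists) differ.

    Returns (max_diff, changed_fraction, within_tol), where within_tol is
    True when no pixel differs by more than `tol` levels out of 255.
    """
    if len(ref) != len(res) or not ref:
        raise ValueError("images must be non-empty and the same size")
    diffs = [abs(a - b) for a, b in zip(ref, res)]
    max_diff = max(diffs)
    changed = sum(1 for d in diffs if d > 0) / len(diffs)
    return max_diff, changed, max_diff <= tol


# Hypothetical VAE-style drift: a couple of pixels off by one level.
print(drift_summary([10, 20, 30, 40], [10, 21, 30, 39]))  # (1, 0.5, True)
# Hypothetical visible change (e.g. a paw moving): far outside the tolerance.
print(drift_summary([10, 20, 30, 40], [10, 90, 30, 40]))  # (70, 0.25, False)
```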

Patched Diffusers (work in progress)

https://github.com/takuma104/diffusers/tree/force_xformers_flash_attention

diff: https://github.com/huggingface/diffusers/compare/main...takuma104:diffusers:force_xformers_flash_attention

The essential part of this patch is the following. If use_flash_attention is True, the op argument of memory_efficient_attention() is chosen so that Flash Attention is prioritized. If Flash Attention does not support the inputs (their shape, dtype, and so on), the code falls back to op=None, i.e. the xFormers default (Cutlass).

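if use_flash_attention:
    # MemoryEfficientAttentionFlashAttentionOp is a (forward op, backward op) pair.
    op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
    fw, bw = op
    # Ask the forward op whether it can handle these inputs (shape, dtype, device, ...).
    if not fw.supports(xformers.ops.fmha.Inputs(query=query, key=key, value=value, attn_bias=attention_mask)):
        logger.warning('Flash Attention is not availabe for the input arguments. Fallback to default xFormers\' backend.')
        # op=None lets xFormers pick its default backend (usually Cutlass).
        op = None
else:
    # Unchanged behavior: let xFormers pick the backend.
    op = None
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask, op=op)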

To generate the above figure

Run the attached diffusers_sd_xformers_flash_attention.py.

However, as noted above, the default (Cutlass) results are not guaranteed to be reproducible, so your default-backend images may differ from mine.
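To check reproducibility on your own setup without eyeballing the figure, the comparison against run #0 can be reduced to one number per run. The sketch below is framework-free: `toy_generate` is a hypothetical stand-in for a seeded pipeline call, and images are assumed to be flat lists of 8-bit pixel values.

```python
import random


def max_abs_diff(ref, res):
    """Largest per-pixel absolute difference between two images of equal size,
    each given as a flat sequence of 8-bit values (0 means identical)."""
    if len(ref) != len(res):
        raise ValueError("images must have the same number of pixels")
    return max(abs(a - b) for a, b in zip(ref, res))


def repeat_and_compare(generate, seed, runs=4):
    """Call generate(seed) `runs` times and return, for runs #1..#(runs-1),
    the maximum difference from run #0."""
    outputs = [generate(seed) for _ in range(runs)]
    ref = outputs[0]
    return [max_abs_diff(ref, out) for out in outputs[1:]]


def toy_generate(seed, n_pixels=64):
    """Hypothetical stand-in for a seeded pipeline call. It uses a private,
    seeded RNG and no parallel reductions, so it is fully deterministic."""
    rng = random.Random(seed)
    return [rng.randrange(256) for _ in range(n_pixels)]


print(repeat_and_compare(toy_generate, seed=7))  # [0, 0, 0]
```

With a real pipeline, all zeros would mean bit-identical runs, while small values such as 1 would match the "almost" reproducible case described above.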

Simple Benchmarking

Run the attached diffusers_sd_xformers_flash_attention_profile.py.

On my machine (RTX 3060), the results were as follows. Using Flash Attention for U-Net inference is slightly faster (4.09 it/s vs. 3.62 it/s), and peak memory usage is the same as the default (3849 MB).

$ python diffusers_xformers_profile.py
A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'
Fetching 16 files: 100%|███████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 26163.30it/s]
default auto backend ---
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:04<00:00,  3.62it/s]
Peak memory use: 3849MB
flash attention backend ---
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:03<00:00,  4.09it/s]
Flash Attention is not availabe for the input arguments. Fallback to default xFormers' backend.
Peak memory use: 3849MB
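To put "slightly faster" into numbers, the two iteration rates in the log above can be converted into time per 15-step run and a speedup ratio. This uses only the figures printed in the log and assumes each progress-bar rate covers the 15 denoising steps.

```python
# Throughput read off the two progress bars above (15 steps each).
steps = 15
default_its = 3.62  # default (Cutlass) backend, it/s
flash_its = 4.09    # Flash Attention prioritized, it/s

default_sec = steps / default_its  # seconds for the denoising loop
flash_sec = steps / flash_its
speedup = flash_its / default_its

print(f"{default_sec:.2f}s vs {flash_sec:.2f}s, speedup {speedup:.2f}x")
# 4.14s vs 3.67s, speedup 1.13x
```

That is roughly a 13% faster U-Net loop, consistent with the 00:04 and 00:03 elapsed times shown by the progress bars.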
diffusers_sd_xformers_flash_attention.py

import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

plt.rcParams["figure.figsize"] = (10, 5)
plt.rcParams['figure.facecolor'] = 'white'


def generate_tuxedo_cat_picture(fn_prefix, seed=0):
    # Generate the same prompt 4 times, re-seeding the generator each time,
    # so any difference between the saved images comes from the attention backend.
    prompt = "a tuxedo cat, oil painting"
    for n in range(4):
        generator = torch.Generator(device='cuda').manual_seed(seed)
        image = pipe(prompt, generator=generator, num_inference_steps=15,
                     guidance_scale=7.5).images[0]
        image.save(f"{fn_prefix}_{n}.png")


def calc_difference_image(image0, image1, normalize=False):
    # Per-pixel absolute difference; int32 avoids uint8 wrap-around on subtraction.
    image0 = np.array(image0, dtype=np.int32)
    image1 = np.array(image1, dtype=np.int32)
    abs_diff = np.abs(image0 - image1)
    if normalize and abs_diff.max() > 0:
        # Stretch to the full 0-255 range so tiny differences become visible.
        abs_diff = abs_diff / abs_diff.max() * 255
    return Image.fromarray(abs_diff.astype(np.uint8))


def render_figure(fn):
    def plot_row(axs, fn_prefix, name):
        # Column 0: reference image (run #0); other columns: |run #0 - run #i|.
        ref = Image.open(f'{fn_prefix}_0.png')
        for i, ax in enumerate(axs):
            if i == 0:
                ax.set_title(f'Ref ({name})')
                ax.imshow(ref)
            else:
                ax.set_title(f'Ref - Result#{i} ({name})')
                res = Image.open(f'{fn_prefix}_{i}.png')
                ax.imshow(calc_difference_image(ref, res))

    fig, axs = plt.subplots(2, 4)
    for ax in axs.flatten():
        ax.set_aspect('equal', 'box')
        ax.axis('off')
    plot_row(axs[0], 'xformers_default', 'default')
    plot_row(axs[1], 'xformers_flash_attention', 'flashattn.')
    fig.tight_layout()
    fig.savefig(fn)


if __name__ == '__main__':
    model_id = "stabilityai/stable-diffusion-2-1"
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16)
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to("cuda")

    seed = 7  # cherry-picked value
    # You may repeat the following lines with different seed values
    # to see the results for other seeds.

    # Enable xFormers and leave the choice of operator to xFormers, as
    # before. xFormers will usually select the Cutlass backend.
    print('default auto backend ---')
    pipe.enable_xformers_memory_efficient_attention()
    generate_tuxedo_cat_picture('xformers_default', seed=seed)

    # Enable xFormers and force Flash Attention (new; requires the patched branch above).
    print('flash attention backend ---')
    pipe.enable_xformers_memory_efficient_attention(use_flash_attention=True)
    generate_tuxedo_cat_picture('xformers_flash_attention', seed=seed)

    render_figure(f'fig_xformers_flash_attention_seed_{seed}.png')
diffusers_sd_xformers_flash_attention_profile.py

import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler


# https://github.com/facebookresearch/xformers/blob/main/HOWTO.md
def mem_profile_start():
    # Release cached blocks and reset the peak counter before measuring.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()


def mem_profile_end():
    # Wait for queued GPU work to finish, then report peak allocated memory
    # in units of 2**20 bytes (printed as "MB").
    torch.cuda.synchronize()
    max_memory = torch.cuda.max_memory_allocated() // 2 ** 20
    print(f"Peak memory use: {max_memory}MB")


def generate_tuxedo_cat_picture(fn_prefix, seed=0):
    prompt = "a tuxedo cat, oil painting"
    for n in range(1):  # a single run is enough for profiling
        generator = torch.Generator(device='cuda').manual_seed(seed)
        image = pipe(prompt, generator=generator, num_inference_steps=15,
                     guidance_scale=7.5).images[0]
        image.save(f"{fn_prefix}_{n}.png")


if __name__ == '__main__':
    model_id = "stabilityai/stable-diffusion-2-1"
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16)
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to("cuda")

    seed = 7  # cherry-picked value
    # You may repeat the following lines with different seed values
    # to see the results for other seeds.

    # Enable xFormers and leave the choice of operator to xFormers, as
    # before. xFormers will usually select the Cutlass backend.
    print('default auto backend ---')
    pipe.enable_xformers_memory_efficient_attention()
    mem_profile_start()
    generate_tuxedo_cat_picture('xformers_default', seed=seed)
    mem_profile_end()

    # Enable xFormers and force Flash Attention (new; requires the patched branch above).
    print('flash attention backend ---')
    pipe.enable_xformers_memory_efficient_attention(use_flash_attention=True)
    mem_profile_start()
    generate_tuxedo_cat_picture('xformers_flash_attention', seed=seed)
    mem_profile_end()