@Birch-san
Last active December 19, 2023 22:07
FlashAttnProcessor
import torch
from typing import Optional
from flash_attn import flash_attn_func
from diffusers.models.attention import Attention


class FlashAttnProcessor:
    r"""
    Processor for implementing memory-efficient attention using flash_attn.
    """
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # [batch, channel, height, width] -> [batch, height*width, channel]
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, key_tokens, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto attention scores of shape
            #   [batch*heads, query_tokens, key_tokens].
            # (carried over from the xformers processor; flash_attn does not support
            #  attention masks, as asserted below.)
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # [batch, tokens, heads*dim] -> [batch, tokens, heads, dim], the layout flash_attn expects
        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        assert attention_mask is None, 'flash_attn does not implement support for attention masks'
        hidden_states = flash_attn_func(
            query, key, value, dropout_p=0.0, causal=False
        )
        hidden_states = hidden_states.to(query.dtype)
        # [batch, tokens, heads, dim] -> [batch, tokens, heads*dim]
        hidden_states = hidden_states.flatten(-2)

        out_proj, dropout = attn.to_out
        hidden_states = out_proj(hidden_states)
        hidden_states = dropout(hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        if attn.rescale_output_factor != 1:
            hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

Birch-san commented Jul 21, 2023

Here I provide two flash_attn AttnProcessors for diffusers. The faster one (which packs the QKV projections) can only be used with self-attention. I also provide a helper for applying the two processors selectively to each Attention layer in the model.

Install flash_attn:

MAX_JOBS=4 pip install flash-attn --no-build-isolation

We reduce MAX_JOBS to 4 to avoid running out of RAM during compilation. You can reduce it further if you want.

If something goes wrong installing it that way, you could consider building it from the latest git commit like I did, with env vars to tell the build where CUDA lives:

git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
pip install packaging
CUDA_DIR=/usr/local/cuda-12.1 PATH="$CUDA_DIR/bin:$PATH" LD_LIBRARY_PATH="$CUDA_DIR/lib64" MAX_JOBS=4 python setup.py install
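
Before wiring it into diffusers, a quick sanity check that the build imports and runs (a minimal sketch, not part of the gist; shapes are arbitrary):

import torch
from flash_attn import flash_attn_func

# flash_attn wants fp16/bf16 CUDA tensors shaped [batch, tokens, heads, head_dim]
q = torch.randn(1, 16, 8, 64, dtype=torch.float16, device='cuda')
k = torch.randn_like(q)
v = torch.randn_like(q)
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)
print(out.shape)  # torch.Size([1, 16, 8, 64])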

Now, in your diffusers project…
Set these attention processors in your UNet like so:

import torch
from torch import cat
from torch.nn import Module, Linear
from diffusers import UNet2DConditionModel
from diffusers.models.attention import Attention

from .flash_attn_processor import FlashAttnProcessor, FlashAttnQKVPackedProcessor

unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
  f'stabilityai/stable-diffusion-xl-base-0.9',
  torch_dtype=torch.float16,
  use_safetensors=True,
  variant='fp16',
  subfolder='unet',
).eval()

use_flash_attn_qkv_packed = True
if use_flash_attn_qkv_packed:
  cross_attn_processor = FlashAttnProcessor()
  self_attn_processor = FlashAttnQKVPackedProcessor()

  def set_flash_attn_processor(mod: Module) -> None:
    if isinstance(mod, Attention):
      # relying on a side-channel to determine (unreliably) whether a layer is self-attention
      if mod.to_k.in_features == mod.to_q.in_features:
        # probably self-attention
        # note: to_q/to_k/to_v in the SD UNet have no bias, so the fused projection shouldn't either
        mod.to_qkv = Linear(mod.to_q.in_features, mod.to_q.out_features*3, bias=False, dtype=mod.to_q.weight.dtype, device=mod.to_q.weight.device)
        mod.to_qkv.weight.data = cat([mod.to_q.weight, mod.to_k.weight, mod.to_v.weight]).detach()
        del mod.to_q, mod.to_k, mod.to_v
        mod.set_processor(self_attn_processor)
      else:
        mod.set_processor(cross_attn_processor)
  unet.apply(set_flash_attn_processor)
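
FlashAttnQKVPackedProcessor itself isn't reproduced in this gist; as a rough sketch of what the QKV-packed self-attention processor can look like (assuming the fused to_qkv Linear installed by the helper above and flash_attn's flash_attn_qkvpacked_func; not necessarily identical to the real processor):

import torch
from typing import Optional
from flash_attn import flash_attn_qkvpacked_func
from diffusers.models.attention import Attention

class FlashAttnQKVPackedProcessor:
    r"""
    Sketch of a self-attention-only processor: one fused to_qkv projection,
    then flash_attn's packed-QKV kernel.
    """
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ):
        assert encoder_hidden_states is None, 'QKV packing projects Q, K and V from the same sequence, so it only supports self-attention'
        assert attention_mask is None, 'flash_attn does not implement support for attention masks'
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        # attn.to_qkv is the fused Linear created by set_flash_attn_processor above;
        # split its output into [batch, tokens, 3, heads, dim] for the packed kernel
        qkv = attn.to_qkv(hidden_states).unflatten(-1, (3, attn.heads, -1))

        hidden_states = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=False)
        hidden_states = hidden_states.flatten(-2)

        out_proj, dropout = attn.to_out
        hidden_states = out_proj(hidden_states)
        hidden_states = dropout(hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        if attn.rescale_output_factor != 1:
            hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states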


okaris commented Dec 19, 2023

Thank you for sharing this @Birch-san. Why doesn’t FlashAttnQKVPackedProcessor work with cross-attention?
