diffusers flash_attn AttnProcessors for qkvpacked self-attn and regular cross-attn
import torch
from typing import Optional
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func
from diffusers.models.attention import Attention


class FlashAttnProcessor:
    r"""
    Processor for implementing memory efficient attention using flash_attn.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, key_tokens, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        assert attention_mask is None, 'flash_attn does not implement support for attention masks'
        hidden_states = flash_attn_func(
            query, key, value, dropout_p=0.0, causal=False
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = hidden_states.flatten(-2)

        out_proj, dropout = attn.to_out
        hidden_states = out_proj(hidden_states)
        hidden_states = dropout(hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        if attn.rescale_output_factor != 1:
            hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class FlashAttnQKVPackedProcessor:
    r"""
    Processor for implementing memory efficient self-attention using flash_attn_qkvpacked_func.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, key_tokens, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        qkv = attn.to_qkv(hidden_states)
        qkv = qkv.unflatten(-1, (3, attn.heads, -1))

        assert attention_mask is None, 'flash_attn does not implement support for attention masks'
        hidden_states = flash_attn_qkvpacked_func(
            qkv, dropout_p=0.0, causal=False
        )
        hidden_states = hidden_states.to(qkv.dtype)
        hidden_states = hidden_states.flatten(-2)

        out_proj, dropout = attn.to_out
        hidden_states = out_proj(hidden_states)
        hidden_states = dropout(hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        if attn.rescale_output_factor != 1:
            hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
Here I provide two `flash_attn` `AttnProcessor`s for diffusers. The faster one (which packs the QKV projections) can only be used for self-attention; the other handles regular cross-attention. I also provide a helper for applying the two processors selectively to each `Attention` layer in the model (a sketch of how to wire them up appears at the end).

Install `flash_attn` first. We reduce `MAX_JOBS` to 4 to avoid running out of RAM during compilation; you can reduce it further if you want. If something goes wrong with installing it like that, you could consider building it from the latest git commit like I did, and set env vars to show where CUDA is:
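A sketch of the kind of install commands this describes (the `flash-attn` package name, the repo URL, and the CUDA path are assumptions on my part, not taken from the original gist):

```bash
# cap parallel compile jobs so the build doesn't exhaust RAM (lower further if needed)
MAX_JOBS=4 pip install flash-attn --no-build-isolation

# or: build from the latest git commit, telling the build where CUDA lives
# (CUDA_HOME below is an example path; adjust to your system)
CUDA_HOME=/usr/local/cuda PATH="$CUDA_HOME/bin:$PATH" MAX_JOBS=4 \
  pip install --no-build-isolation "git+https://github.com/Dao-AILab/flash-attention.git"
```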
Now, in your diffusers project, set these attention processors on your UNet like so:
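The original helper isn't shown above, so here is a minimal sketch of one way to wire it up, not the author's exact code. It assumes the processor file above is saved as `flash_attn_processors.py`, that self-attention layers are the ones whose processor keys end in `attn1.processor` (as in diffusers' UNet blocks), and it fuses each self-attention layer's `to_q`/`to_k`/`to_v` into the `to_qkv` linear that `FlashAttnQKVPackedProcessor` expects (stock diffusers `Attention` modules don't have one). The model id and dtype are illustrative.

```python
import torch
from torch import nn
from diffusers import StableDiffusionPipeline

# hypothetical filename: wherever you saved the processors defined above
from flash_attn_processors import FlashAttnProcessor, FlashAttnQKVPackedProcessor


def fuse_qkv(attn: nn.Module) -> None:
    """Concatenate to_q/to_k/to_v into a single to_qkv linear.

    Only valid for self-attention, where all three projections take the same input.
    """
    with torch.no_grad():
        weight = torch.cat([attn.to_q.weight, attn.to_k.weight, attn.to_v.weight], dim=0)
        has_bias = attn.to_q.bias is not None
        to_qkv = nn.Linear(
            weight.shape[1], weight.shape[0], bias=has_bias,
            dtype=weight.dtype, device=weight.device,
        )
        to_qkv.weight.copy_(weight)
        if has_bias:
            to_qkv.bias.copy_(torch.cat([attn.to_q.bias, attn.to_k.bias, attn.to_v.bias], dim=0))
    attn.to_qkv = to_qkv


pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # illustrative model id
    torch_dtype=torch.float16,         # flash_attn kernels require fp16/bf16
).to("cuda")
unet = pipe.unet

procs = {}
for name in unet.attn_processors.keys():
    # keys look like "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor";
    # attn1 is self-attention, attn2 is cross-attention
    if name.endswith("attn1.processor"):
        attn_module = unet.get_submodule(name[: -len(".processor")])
        fuse_qkv(attn_module)
        procs[name] = FlashAttnQKVPackedProcessor()
    else:
        procs[name] = FlashAttnProcessor()
unet.set_attn_processor(procs)
```

With the processors in place, `pipe(prompt)` runs as usual; only the attention computation changes.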