# Gist by @wkcn, last active April 26, 2023
# FlashAttention wrapped in a ViT-style Attention module, with a naive reference path.
import torch
from torch import nn
import numpy as np
from flash_attn.flash_attention import FlashAttention


class Attention(nn.Module):
    """Multi-head self-attention with an optional FlashAttention forward path."""

    use_flash_attn: bool = False

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.,
        proj_drop=0.,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = qk_scale or self.head_dim ** -0.5

        self.flash_attn = FlashAttention(attention_dropout=attn_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        if self.use_flash_attn:
            return self.flash_attn_forward(x)
        return self.naive_forward(x)

    def flash_attn_forward(self, x):
        # FlashAttention expects packed qkv of shape (B, N, 3, H, D)
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        # FlashAttention returns (output, attn_weights); the output is (B, N, H, D)
        x = self.flash_attn(qkv)[0]
        x = x.flatten(2)  # (B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def naive_forward(self, x):
        B, N, C = x.shape
        # (B, N, C) -> (B, N, 3, H, C/H) -> (3, B, H, N, C/H) -> (Q, K, V)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        # Each of q, k, v: (B, H, N, C/H)
        q, k, v = qkv.unbind(0)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
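

# Optional cross-check (not part of the original gist): the same attention computed with
# torch.nn.functional.scaled_dot_product_attention (available in PyTorch >= 2.0).
# It reuses the module's qkv/proj weights, so with dropout at 0 it should match
# naive_forward up to fp16 rounding. A sketch under those assumptions, not FlashAttention API.
def sdpa_forward(attn_module, x):
    B, N, C = x.shape
    qkv = attn_module.qkv(x).reshape(B, N, 3, attn_module.num_heads, attn_module.head_dim)
    q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each (B, H, N, D)
    # SDPA's default scale is head_dim ** -0.5, matching self.scale when qk_scale is None.
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    out = out.transpose(1, 2).reshape(B, N, C)  # (B, N, C)
    return attn_module.proj_drop(attn_module.proj(out))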


dim_per_head = 64
num_heads = 64
dim = dim_per_head * num_heads

# head_dim is derived from dim // num_heads inside the module,
# so only dim and num_heads need to be passed here.
attn = Attention(dim, num_heads=num_heads)
attn.cuda()
attn.half()

B, N, C = 128, 14 * 14, dim
x = torch.randn(B, N, C, device='cuda', dtype=torch.float16)

# Naive path (use_flash_attn defaults to False)
y = attn(x)

# FlashAttention path
attn.use_flash_attn = True
y2 = attn(x)
print('Flash Attention forward works!')

y2.sum().backward()
print('Flash Attention backward works!')
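
# Optional sanity check (not part of the original gist): attn_drop and proj_drop default to 0.,
# so the naive output `y` and the flash output `y2` should agree up to fp16 rounding.
# The tolerances below are loose guesses for half precision, not exact error bounds.
max_err = (y - y2).abs().max().item()
print(f'max abs diff between naive and flash outputs: {max_err:.3e}')
assert torch.allclose(y.float(), y2.float(), atol=1e-2, rtol=1e-2)

# Cross-check against the SDPA sketch above (requires PyTorch >= 2.0).
y_sdpa = sdpa_forward(attn, x)
assert torch.allclose(y.float(), y_sdpa.float(), atol=1e-2, rtol=1e-2)
print('Naive, Flash, and SDPA outputs agree within tolerance!')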