@jaanli
Created February 1, 2024 22:50
import torch
import torch.nn.functional as F
import time
from flash_attn import flash_attn_qkvpacked_func
# Ensure torch supports bf16 on the current device; CUDA support for bf16 might be device-dependent
assert torch.cuda.is_bf16_supported(), "CUDA device does not support bf16."
# Data Preparation
batch_size = 1
seqlen = 16384
nheads = 8
headdim = 64
# Create the packed qkv tensor in fp32, then cast to bf16 (FlashAttention requires fp16/bf16 inputs)
qkv = torch.rand(batch_size, seqlen, 3, nheads, headdim, dtype=torch.float32, device='cuda', requires_grad=True)
qkv = qkv.to(torch.bfloat16)
qkv.retain_grad() # .to() returns a non-leaf tensor, so retain_grad() is needed for qkv.grad to be populated
q, k, v = qkv.split(1, dim=2) # Split the packed tensor into q, k, v along the packing dimension
q, k, v = [x.squeeze(2).clone().detach().requires_grad_(True) for x in (q, k, v)] # Drop the split dimension and make each a leaf tensor with its own gradient
# Assuming no attention mask is used for simplicity
mask = None
# Forward pass benchmarking
# Flash Attention
torch.cuda.synchronize() # CUDA kernels launch asynchronously; synchronize so the timer brackets the actual work
start_time = time.time()
flash_output = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=1.0/(headdim ** 0.5), causal=False, deterministic=False)
torch.cuda.synchronize()
flash_forward_duration = time.time() - start_time
# PyTorch Scaled Dot Product Attention
# scaled_dot_product_attention expects (batch, heads, seqlen, headdim), so transpose rather than flatten
q_pt, k_pt, v_pt = [x.transpose(1, 2) for x in (q, k, v)]
torch.cuda.synchronize()
start_time = time.time()
pytorch_output = F.scaled_dot_product_attention(q_pt, k_pt, v_pt, attn_mask=mask, dropout_p=0.0)
torch.cuda.synchronize()
pytorch_forward_duration = time.time() - start_time
# Transpose back to (batch, seqlen, heads, headdim) to match the Flash Attention output layout
pytorch_output = pytorch_output.transpose(1, 2).to(torch.bfloat16)
# Backward pass benchmarking
# Define a simple loss function
loss_fn = torch.nn.MSELoss()
# Calculate loss
target = torch.randn_like(flash_output)
flash_loss = loss_fn(flash_output, target)
pytorch_loss = loss_fn(pytorch_output, target)
# Flash Attention backward
torch.cuda.synchronize()
start_time = time.time()
flash_loss.backward()
torch.cuda.synchronize()
flash_backward_duration = time.time() - start_time
# PyTorch Scaled Dot Product Attention backward
# q, k, v only receive gradients from the PyTorch path, but clear them anyway as a precaution
q.grad = None
k.grad = None
v.grad = None
torch.cuda.synchronize()
start_time = time.time()
pytorch_loss.backward()
torch.cuda.synchronize()
pytorch_backward_duration = time.time() - start_time
# Combine gradients of q, k, v for PyTorch Attention to match the structure of qkv for comparison
combined_pytorch_grads = torch.cat([grad.contiguous().view(batch_size, seqlen, 1, nheads, headdim) for grad in (q.grad, k.grad, v.grad)], dim=2).to(torch.bfloat16)
# Outputs
print(f"Flash Attention Forward Duration: {flash_forward_duration:.6f} seconds")
print(f"PyTorch Scaled Dot Product Attention Forward Duration: {pytorch_forward_duration:.6f} seconds")
print(f"Flash Attention Backward Duration: {flash_backward_duration:.6f} seconds")
print(f"PyTorch Scaled Dot Product Attention Backward Duration: {pytorch_backward_duration:.6f} seconds")
print("Flash Attention Output (sample):", flash_output[0, 0, 0]) # Print a sample of the output
print("PyTorch Scaled Dot Product Attention Output (sample):", pytorch_output[0, 0, 0]) # Print a sample of the output
# Assert outputs are within tolerance
output_tolerance = 0.2 # Adjust tolerance as needed
assert torch.allclose(flash_output, pytorch_output, atol=output_tolerance), "Outputs are not within the specified tolerance."
print("All output elements are within the specified tolerance.")
# Assert gradient outputs are within tolerance
grad_tolerance = 1e-6 # Adjust tolerance for gradients as needed
assert torch.allclose(qkv.grad, combined_pytorch_grads, atol=grad_tolerance), "Gradient outputs are not within the specified tolerance."
print("All gradient elements are within the specified tolerance.")
print("Flash Attention Grad (sample):", combined_pytorch_grads[0, 0, 0, 0]) # Print a sample of the output
print("PyTorch Scaled Dot Product Attention Grad (sample):", qkv.grad[0, 0, 0, 0]) # Print a sample of the output
jaanli commented Feb 1, 2024

Example output on A10 GPU:

python test_flash_bw.py
Flash Attention Forward Duration: 0.002987 seconds
PyTorch Scaled Dot Product Attention Forward Duration: 0.090282 seconds
Flash Attention Backward Duration: 0.015947 seconds
PyTorch Scaled Dot Product Attention Backward Duration: 0.009706 seconds
Flash Attention Output (sample): tensor([0.5000, 0.5000, 0.5000, 0.4980, 0.4980, 0.5000, 0.5039, 0.5000, 0.5000,
        0.4980, 0.5000, 0.4961, 0.5000, 0.5000, 0.5000, 0.4980, 0.5000, 0.5000,
        0.5000, 0.5039, 0.5000, 0.5000, 0.5000, 0.5000, 0.5039, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5039, 0.5000, 0.5000, 0.5039, 0.5039, 0.5000, 0.5000,
        0.5000, 0.4980, 0.4980, 0.5039, 0.5000, 0.4980, 0.4980, 0.4980, 0.5000,
        0.5039, 0.4980, 0.5000, 0.4961, 0.5000, 0.4980, 0.5000, 0.5000, 0.4980,
        0.5000, 0.5039, 0.5000, 0.5000, 0.5039, 0.5039, 0.5000, 0.4980, 0.5000,
        0.5000], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SelectBackward0>)
PyTorch Scaled Dot Product Attention Output (sample): tensor([0.4980, 0.5000, 0.5000, 0.5000, 0.5039, 0.5000, 0.5039, 0.5000, 0.4980,
        0.5000, 0.5000, 0.5039, 0.5039, 0.4961, 0.5000, 0.4980, 0.5000, 0.4961,
        0.4980, 0.5039, 0.4980, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5039, 0.5000, 0.5000, 0.5000, 0.5039, 0.5039, 0.5000,
        0.5039, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.4961,
        0.4980, 0.5000, 0.4961, 0.4961, 0.5000, 0.4980, 0.4941, 0.5039, 0.5000,
        0.5000, 0.4980, 0.5000, 0.5000, 0.5039, 0.5039, 0.4980, 0.5000, 0.5039,
        0.5000], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SelectBackward0>)
All output elements are within the specified tolerance.
All gradient elements are within the specified tolerance.
Flash Attention Grad (sample): tensor([-3.1650e-10, -3.5925e-11, -2.1828e-10, -4.3656e-10,  2.8740e-10,
        -5.4342e-11,  2.4784e-11,  4.6612e-11, -3.0013e-10, -2.8558e-10,
        -3.0741e-10,  1.7008e-10, -1.6371e-10,  2.8922e-10, -3.1650e-10,
        -4.0927e-10,  3.1605e-11, -2.4784e-11, -4.8658e-11, -3.8017e-10,
         5.4570e-11, -7.1850e-11,  1.0141e-10, -1.2187e-10, -1.2005e-10,
        -3.4925e-10, -1.8372e-10,  2.8422e-11, -4.4338e-11, -5.1296e-10,
        -5.0568e-10, -3.4743e-10, -6.5484e-10,  5.5252e-11, -2.6921e-10,
        -5.7526e-11,  8.1400e-11,  1.1505e-10, -8.8676e-11, -1.3370e-10,
         4.8885e-11, -1.6735e-10, -1.6371e-10, -5.1159e-11, -2.9286e-10,
        -3.9654e-10, -7.6398e-11, -1.7917e-10,  9.5042e-11,  1.9463e-10,
        -1.3245e-11,  4.4020e-10, -4.5839e-10,  3.7744e-11, -7.1395e-11,
         7.4124e-11, -8.2309e-11, -1.4779e-11, -1.8554e-10, -2.8558e-10,
         1.7280e-10, -7.6852e-11, -6.3665e-11, -3.8017e-10], device='cuda:0',
       dtype=torch.bfloat16)
PyTorch Scaled Dot Product Attention Grad (sample): tensor([ 2.0373e-10, -1.5098e-10,  6.6393e-11, -1.0232e-10, -1.3188e-10,
         1.7099e-10,  9.5497e-11,  2.2919e-10,  3.1469e-10,  3.0013e-10,
         1.3824e-10,  2.9104e-10, -9.8225e-11,  8.2309e-11, -3.7835e-10,
        -4.7748e-11, -1.7644e-10,  1.5802e-11,  5.6843e-11,  1.8645e-10,
        -2.3829e-10,  1.0687e-10,  2.8376e-10,  6.7303e-11,  2.2101e-10,
         2.4556e-10,  1.8827e-10, -5.3205e-11, -3.0923e-11, -2.1487e-11,
        -8.9130e-11, -3.9836e-10, -1.1710e-11, -7.5488e-11,  7.0031e-11,
        -6.0481e-11,  6.2300e-11,  2.7649e-10,  6.0027e-11,  7.7307e-12,
         1.2005e-10, -2.3647e-10,  1.7167e-11, -1.3824e-10,  2.7649e-10,
         4.6566e-10,  2.2101e-10,  1.8099e-10,  2.7285e-10,  3.8744e-10,
         1.1642e-10,  2.7467e-10, -2.4011e-10, -1.3006e-10,  7.8217e-11,
         4.3838e-10,  1.4097e-10, -7.0486e-11,  2.8194e-10,  1.2642e-10,
         2.4738e-10, -2.7831e-10,  1.3279e-10, -7.1395e-11], device='cuda:0',
       dtype=torch.bfloat16)
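For a closer apples-to-apples comparison, note that F.scaled_dot_product_attention can itself dispatch to a FlashAttention kernel, a memory-efficient kernel, or the plain math path depending on the inputs and PyTorch version. A minimal sketch of pinning it to the flash backend, assuming the PyTorch 2.0/2.1 API torch.backends.cuda.sdp_kernel (newer releases replace it with torch.nn.attention.sdpa_kernel):

# Sketch: restrict SDPA to its FlashAttention backend so both timings exercise a flash kernel
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    sdpa_flash_out = F.scaled_dot_product_attention(q_pt, k_pt, v_pt, dropout_p=0.0)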
