Created February 1, 2024 22:50
Testing Flash Attention (https://github.com/Dao-AILab/flash-attention) against PyTorch's `F.scaled_dot_product_attention` (https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
import torch
import torch.nn.functional as F
import time
from flash_attn import flash_attn_qkvpacked_func

# Ensure torch supports bf16 on the current device; CUDA support for bf16 might be device-dependent
assert torch.cuda.is_bf16_supported(), "CUDA device does not support bf16."
# Data preparation
batch_size = 1
seqlen = 16384
nheads = 8
headdim = 64

# Create the packed qkv tensor and convert it to bf16 (Flash Attention requires fp16/bf16)
qkv = torch.rand(batch_size, seqlen, 3, nheads, headdim, dtype=torch.float32, device='cuda', requires_grad=True)
qkv = qkv.to(torch.bfloat16)
qkv.retain_grad()  # The bf16 cast makes qkv a non-leaf tensor, so retain its gradient explicitly
q, k, v = qkv.split(1, dim=2)  # Split the packed tensor into q, k, v components
q, k, v = [x.squeeze(2).clone().detach().requires_grad_(True) for x in (q, k, v)]  # Drop the extra dimension and track gradients independently of qkv
# No attention mask is used, for simplicity
mask = None
# Forward pass benchmarking
# Flash Attention (expects packed qkv of shape (batch, seqlen, 3, nheads, headdim) in fp16/bf16)
torch.cuda.synchronize()
start_time = time.time()
flash_output = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=1.0 / (headdim ** 0.5), causal=False, deterministic=False)
torch.cuda.synchronize()  # CUDA kernels launch asynchronously; synchronize before reading the clock
flash_forward_duration = time.time() - start_time

# PyTorch Scaled Dot Product Attention (expects (batch, nheads, seqlen, headdim))
q_pt, k_pt, v_pt = [x.transpose(1, 2) for x in (q, k, v)]
torch.cuda.synchronize()
start_time = time.time()
pytorch_output = F.scaled_dot_product_attention(q_pt, k_pt, v_pt, attn_mask=mask, dropout_p=0.0)
torch.cuda.synchronize()
pytorch_forward_duration = time.time() - start_time

# Transpose the PyTorch output back to (batch, seqlen, nheads, headdim) to match the Flash Attention layout
pytorch_output = pytorch_output.transpose(1, 2).to(torch.bfloat16)
# Backward pass benchmarking
# Define a simple loss function
loss_fn = torch.nn.MSELoss()

# Calculate losses against a shared random target
target = torch.randn_like(flash_output)
flash_loss = loss_fn(flash_output, target)
pytorch_loss = loss_fn(pytorch_output, target)

# Flash Attention backward
torch.cuda.synchronize()
start_time = time.time()
flash_loss.backward()
torch.cuda.synchronize()
flash_backward_duration = time.time() - start_time

# PyTorch Scaled Dot Product Attention backward
# Clear gradients before the next backward pass
q.grad = None
k.grad = None
v.grad = None
torch.cuda.synchronize()
start_time = time.time()
pytorch_loss.backward()
torch.cuda.synchronize()
pytorch_backward_duration = time.time() - start_time

# Stack the q, k, v gradients along dim 2 to match the packed qkv layout for comparison
combined_pytorch_grads = torch.cat([grad.unsqueeze(2) for grad in (q.grad, k.grad, v.grad)], dim=2).to(torch.bfloat16)
# Outputs
print(f"Flash Attention Forward Duration: {flash_forward_duration:.6f} seconds")
print(f"PyTorch Scaled Dot Product Attention Forward Duration: {pytorch_forward_duration:.6f} seconds")
print(f"Flash Attention Backward Duration: {flash_backward_duration:.6f} seconds")
print(f"PyTorch Scaled Dot Product Attention Backward Duration: {pytorch_backward_duration:.6f} seconds")
print("Flash Attention Output (sample):", flash_output[0, 0, 0])  # Print a sample of the output
print("PyTorch Scaled Dot Product Attention Output (sample):", pytorch_output[0, 0, 0])  # Print a sample of the output

# Assert outputs are within tolerance
output_tolerance = 0.2  # Adjust tolerance as needed
assert torch.allclose(flash_output, pytorch_output, atol=output_tolerance), "Outputs are not within the specified tolerance."
print("All output elements are within the specified tolerance.")

# Assert gradients are within tolerance
grad_tolerance = 1e-6  # Adjust tolerance for gradients as needed
assert torch.allclose(qkv.grad, combined_pytorch_grads, atol=grad_tolerance), "Gradient outputs are not within the specified tolerance."
print("All gradient elements are within the specified tolerance.")
print("Flash Attention Grad (sample):", qkv.grad[0, 0, 0, 0])  # Print a sample of the gradient
print("PyTorch Scaled Dot Product Attention Grad (sample):", combined_pytorch_grads[0, 0, 0, 0])  # Print a sample of the gradient
Example output on A10 GPU: