@honglu2875
Created January 27, 2023 23:56
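A quick benchmark of a GPT-style causal attention helper. It compares two ways of supplying the fill value for the masked-out positions in torch.where: materializing it as a tensor on the attention weights' device on every call (original=True, as in the original implementation) versus passing the Python scalar directly (original=False).
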
import torch
from torch import nn
import time

# Causal mask buffer: a lower-triangular matrix of ones, reshaped to (1, 1, 2048, 2048).
bias = torch.tril(torch.ones((2048, 2048), dtype=torch.uint8, device='cuda')).view(
    1, 1, 2048, 2048
)
def _attn(
    query,
    key,
    value,
    attention_mask=None,
    head_mask=None,
    original=True,
):
    # Compute the causal mask from the causal mask buffer.
    query_length, key_length = query.size(-2), key.size(-2)
    causal_mask = bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)

    # Keep the attention weights computation in fp32 to avoid overflow issues.
    query = query.to(torch.float32)
    key = key.to(torch.float32)
    attn_weights = torch.matmul(query, key.transpose(-1, -2))

    mask_value = torch.finfo(attn_weights.dtype).min
    # Needs to be a tensor, otherwise we get the error: `RuntimeError: expected scalar type float but found double`.
    # Needs to be on the same device, otherwise: `RuntimeError: ..., x and y to be on the same device`.
    if original:
        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
    attn_weights = torch.where(causal_mask, attn_weights, mask_value)
    attn_weights = attn_weights / 16

    if attention_mask is not None:
        # Apply the attention mask.
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = attn_weights.to(value.dtype)

    # Mask heads if we want to.
    if head_mask is not None:
        attn_weights = attn_weights * head_mask

    attn_output = torch.matmul(attn_weights, value)
    return attn_output, attn_weights

# Benchmark: time 100 calls with each variant of mask_value.
key = torch.rand((8, 1024, 1024), dtype=torch.float32, device='cuda')
query = torch.rand((8, 1024, 1024), dtype=torch.float32, device='cuda')
value = torch.rand((8, 1024, 1024), dtype=torch.float32, device='cuda')

start = time.perf_counter()
for _ in range(100):
    _attn(query, key, value, original=True)
print(time.perf_counter() - start)

start = time.perf_counter()
for _ in range(100):
    _attn(query, key, value, original=False)
print(time.perf_counter() - start)
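
Note that CUDA kernels launch asynchronously, so time.perf_counter() on its own may not reflect when the GPU work actually finishes. A minimal sketch of a tighter timing loop, reusing the _attn, query, key, and value defined above (the helper name bench is just for illustration), would synchronize the device around the measured region:

def bench(original, iters=100):
    # Warm-up call so one-time CUDA setup is not included in the measurement.
    _attn(query, key, value, original=original)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        _attn(query, key, value, original=original)
    # Wait for all queued kernels to finish before stopping the clock.
    torch.cuda.synchronize()
    return time.perf_counter() - start

print(bench(original=True))
print(bench(original=False))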