#######################
# CODE BASED ON https://github.com/hyunwoongko/transformer/blob/master/README.md
#######################
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from functools import partial
from torch.cuda.amp import autocast
from einops import rearrange, repeat
from rotary_embedding_torch import RotaryEmbedding
#from torchtune.modules import RotaryPositionalEmbeddings as RotaryEmbedding
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
from flash_attn.utils.distributed import get_dim_for_local_rank
try:
    from flash_attn import (
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
        flash_attn_with_kvcache,
    )
except ImportError:
    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
    flash_attn_with_kvcache = None

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
except ImportError:
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None

#try:
#    from flash_attn.layers.rotary import RotaryEmbedding
#except ImportError:
#    RotaryEmbedding = None
# From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
def get_alibi_slopes(nheads):
    def get_slopes_power_of_2(nheads):
        start = 2 ** (-(2 ** -(math.log2(nheads) - 3)))
        ratio = start
        return [start * ratio**i for i in range(nheads)]

    if math.log2(nheads).is_integer():
        return get_slopes_power_of_2(nheads)
    else:
        closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
        return (
            get_slopes_power_of_2(closest_power_of_2)
            + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
        )
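
# Illustrative example (not in the original gist): for nheads=8 the slopes form the
# geometric sequence [2**-1, 2**-2, ..., 2**-8], i.e.
#   get_alibi_slopes(8) == [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
# For non-power-of-2 head counts, the list is extended with interleaved slopes taken from
# the next power of 2, following the linked ALiBi reference implementation.
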
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_head):
        super().__init__()
        self.n_head = n_head
        self.attention = Attention(d_model // n_head)
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_concat = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        # 1. dot product with weight matrices
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)

        # 2. split tensors by number of heads
        q, k, v = self.split(q), self.split(k), self.split(v)

        # 3. do scaled dot-product attention to compute similarity
        out, attention = self.attention(q, k, v, mask)

        # 4. concat heads and pass through the output projection
        out = self.concat(out)
        out = self.w_concat(out)
        return out

    def split(self, tensor):
        """
        split tensor by number of heads
        :param tensor: [batch_size, length, d_model]
        :return: [batch_size, head, length, d_tensor]
        """
        batch_size, length, d_model = tensor.size()
        d_tensor = d_model // self.n_head
        tensor = tensor.view(batch_size, length, self.n_head, d_tensor).transpose(1, 2)
        return tensor

    @staticmethod
    def concat(tensor):
        """
        inverse function of self.split(tensor : torch.Tensor)
        :param tensor: [batch_size, head, length, d_tensor]
        :return: [batch_size, length, d_model]
        """
        batch_size, head, length, d_tensor = tensor.size()
        d_model = head * d_tensor
        tensor = tensor.transpose(1, 2).contiguous().view(batch_size, length, d_model)
        return tensor
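
# Shape note (added for clarity, not in the original gist): split() maps
# [batch, seq_len, d_model] -> [batch, n_head, seq_len, d_model // n_head] and concat()
# inverts it, so with d_model=512 and n_head=8 a [2, 128, 512] input is attended over
# per head as [2, 8, 128, 64] and projected back to [2, 128, 512] by w_concat.
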
## ~147 hrs per epoch, 1B model, 4 GPUs - definitely faster in every setting tried than the
## vanilla matmul attention (which materialises the full score matrix) further below.
#class Attention(nn.Module):
#    def __init__(self, d_head, dropout=0.1):
#        super().__init__()
#        self.dropout = torch.nn.Dropout(dropout)
#        self.softmax = nn.Softmax(dim=-1)
#        self.rotary_embed = RotaryEmbedding(d_head//2)
#        self.first = True
#        self.dropout_value = dropout
#
#    def forward(self, q, k, v, mask=None):
#        # apply RoPE
#        q = self.rotary_embed.rotate_queries_or_keys(q)
#        k = self.rotary_embed.rotate_queries_or_keys(k)
#
#        d_k = k.size(-1)
#        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
#            scores = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout_value if self.training else 0, is_causal=True if mask is not None else False)
#
#        return scores, None
# Can't draw a general conclusion yet - for the 1B model, the FlashAttention version below is faster.
class Attention(nn.Module):
    """Scaled dot-product attention with softmax, backed by FlashAttention.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
            (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention weights
            (default: 0.1)
    """

    def __init__(
        self,
        d_head,
        dropout=0.1,  # kept for interface parity with the commented-out variants; unused here
        causal=False,
        softmax_scale=None,
        attention_dropout=0.1,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic
        self.rotary_embed = RotaryEmbedding(d_head // 2)
    def forward(self, q, k, v, mask=None, causal=None, cu_seqlens=None, max_seqlen=None):
        """Multihead softmax attention via FlashAttention.
        Arguments
        ---------
            q, k, v: (B, H, S, D) tensors; they are packed internally into a qkv tensor of
                shape (B, S, 3, H, D), or (total, 3, H, D) when cu_seqlens/max_seqlen are
                given (total is the sum of the sequence lengths in the batch).
            mask: any non-None mask switches on causal attention; the mask itself is not
                applied elementwise, FlashAttention's causal flag is used instead.
            causal: if passed, overrides both the mask-based default and self.causal.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            (out, None), where out is (B, H, S, D) in the padded case, matching the layout
            expected by MultiHeadAttention.concat.
        """
        # q.shape = B, H, S, D
        # Causal by default whenever a mask is supplied; an explicit `causal` argument takes
        # precedence, falling back to self.causal otherwise.
        if causal is None:
            causal = (mask is not None) or self.causal

        # apply RoPE to queries and keys
        q = self.rotary_embed.rotate_queries_or_keys(q)
        k = self.rotary_embed.rotate_queries_or_keys(k)
        ## print(q.shape, k.shape)  # torch.Size([1024, 16, 31, 64])
        ## q = q.to(v.dtype)
        ## k = k.to(v.dtype)
        ## q = self.rotary_embed(q)  # torchtune RotaryPositionalEmbeddings variant
        ## k = self.rotary_embed(k)

        # pack into (B, S, 3, H, D) as expected by the qkv-packed FlashAttention kernels
        qkv = torch.cat(
            [q.transpose(1, 2).unsqueeze(2), k.transpose(1, 2).unsqueeze(2), v.transpose(1, 2).unsqueeze(2)],
            dim=2,
        )
        assert qkv.dtype in [torch.float16, torch.bfloat16], f'{qkv.dtype=}'
        assert qkv.is_cuda

        unpadded = cu_seqlens is not None
        if self.alibi_slopes is not None:
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            out = flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_seqlen,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            out = flash_attn_qkvpacked_func(
                qkv,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        # back to (B, H, S, D) so that MultiHeadAttention.concat can undo the head split
        return out.transpose(1, 2), None
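
# Note (added for clarity, not in the original gist): flash_attn_qkvpacked_func expects
# qkv of shape (B, S, 3, H, D) in fp16/bf16 on CUDA and returns (B, S, H, D); the final
# transpose above restores the (B, H, S, D) layout for MultiHeadAttention.concat.
# A rough shape check, assuming a CUDA device with flash-attn installed:
#
#   attn = Attention(d_head=64).cuda()
#   q = k = v = torch.randn(2, 8, 128, 64, device='cuda', dtype=torch.bfloat16)
#   out, _ = attn(q, k, v, mask=torch.ones(128, 128, device='cuda'))
#   assert out.shape == (2, 8, 128, 64)
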
#class Attention(nn.Module):
#    def __init__(self, d_head, dropout=0.1):
#        super().__init__()
#        self.dropout = torch.nn.Dropout(dropout)
#        self.softmax = nn.Softmax(dim=-1)
#        self.rotary_embed = RotaryEmbedding(d_head//2)
#        self.first = True
#
#    def forward(self, q, k, v, mask=None):
#        # apply RoPE
#        q = self.rotary_embed.rotate_queries_or_keys(q)
#        k = self.rotary_embed.rotate_queries_or_keys(k)
#
#        d_k = k.size(-1)
##        if self.first:
##            print(f'{q.shape=} {k.shape=} {v.shape=}')
##            self.first = False
#
##        scores = flash_attn_func(q.transpose(1,2), k.transpose(1,2), v.transpose(1,2))
##        F.scaled_dot_product_attention(q.transpose(1,2), k.transpose(1,2), v.transpose(1,2))
#        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
#
##        raise Exception(f'{q.shape=} {k.shape=} {v.shape=} {scores.shape=} {mask.shape=} {mask=}')
#        if mask is not None:
#            #mask_value = 1e-9 if scores.dtype == torch.float32 else -1e4
#            scores = scores.masked_fill(mask == 0, -torch.inf)
#
#        p_attn = self.softmax(scores)
#        p_attn = self.dropout(p_attn)
#        scores = torch.matmul(p_attn, v)
#        #raise Exception(f'{q.shape=} {k.shape=} {v.shape=} {scores.shape=} {mask.shape=} {mask=}')
#        # q.shape=torch.Size([128, 20, 31, 80]) k.shape=torch.Size([128, 20, 31, 80]) v.shape=torch.Size([128, 20, 31, 80]) scores.shape=torch.Size([128, 20, 31, 80]) mask.shape=torch.Size([31, 31]) mask=tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
#        return scores, p_attn
## DISREGARD BELOW VERSIONS
#class Attention(nn.Module):
#    def __init__(self, d_head, dropout=0.1):
#        super().__init__()
#        self.dropout = torch.nn.Dropout(dropout)
#        self.softmax = nn.Softmax(dim=-1)
#        self.rotary_embed = RotaryEmbedding(d_head//2)
#        self.first = True
#        self.dropout_value = dropout
#
#    def forward(self, q, k, v, mask=None):
#        # apply RoPE
#        q = self.rotary_embed.rotate_queries_or_keys(q)
#        k = self.rotary_embed.rotate_queries_or_keys(k)
#
#        d_k = k.size(-1)
#        raise Exception(f'{type(q)=} {type(k)=} {type(v)=}')
#        scores = flash_attn_func(q.transpose(1,2), k.transpose(1,2), v.transpose(1,2), dropout_p=self.dropout_value, causal=True if mask is not None else False).transpose(1,2)
#
#        return scores, None
#
#class Attention2(nn.Module):
#    def __init__(self, d_head, dropout=0.1):
#        super().__init__()
#        self.dropout = torch.nn.Dropout(dropout)
#        self.softmax = nn.Softmax(dim=-1)
#        self.rotary_embed = RotaryEmbedding(d_head//2)  # Assuming this is a defined class elsewhere
#        self.first = True
#
#    def forward(self, q, k, v, mask=None):
#        with autocast(enabled=True, dtype=torch.float16):
#            # Apply RoPE (Rotary Positional Embedding)
#            q = self.rotary_embed.rotate_queries_or_keys(q)
#            k = self.rotary_embed.rotate_queries_or_keys(k)
#
#            # Transpose the tensors to match FlashAttention's expected input shape
#            q_transposed = q.transpose(1, 2).to(dtype=torch.float16)
#            k_transposed = k.transpose(1, 2).to(dtype=torch.float16)
#            v_transposed = v.transpose(1, 2).to(dtype=torch.float16)
#
#            if self.first:
#                print(f'{q_transposed.shape=} {k_transposed.shape=} {v_transposed.shape=} {mask.shape=}')
#
#            # Pass the transposed and cast tensors to the FlashAttention function
#            scores = flash_attn_func(q_transposed, k_transposed, v_transposed, dropout_p=self.dropout if self.training else 0, causal=True if mask is None else False)
#
#            if self.first:
#                print(f'{q_transposed.shape=} {k_transposed.shape=} {v_transposed.shape=} {scores.shape=} {mask.shape=}')
#                self.first = False
#
##            # You may want to keep the softmax operation in fp32 for numerical stability
##            if mask is not None:
##                scores = scores.float()  # Convert back to fp32 if necessary
##                scores = scores.masked_fill(mask == 0, float('-inf'))
#
##            if mask is not None:
##                # Ensure mask is broadcastable to the size of scores.
##                # This might involve unsqueezing dimensions or ensuring it has the right shape.
##                # For example, if mask should cover the sequence length which is the last dimension:
##                mask = mask.unsqueeze(1).unsqueeze(2)  # Adding dimensions to match scores shape
##                # The mask needs to be the same dtype as scores, and usually, the mask is not in fp16.
##                # Convert mask to the same dtype as scores, if necessary
##                mask = mask.to(dtype=scores.dtype)
##                # Use broadcasting to apply the mask
##                scores = scores.masked_fill(mask == 0, float('-inf'))
##
##            p_attn = self.softmax(scores)
##            p_attn = self.dropout(p_attn)
##
##            # Ensure that 'v' is in the correct dtype before matmul if needed
##            v = v.to(dtype=scores.dtype)
##
##            return torch.matmul(p_attn, v), p_attn
#        return scores, None
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, hidden, drop_prob=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()  # inplace=True)
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x
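
# Optional smoke test (a sketch, not part of the original gist; assumes a CUDA device
# with flash-attn and rotary-embedding-torch installed).
if __name__ == "__main__":
    d_model, n_head, B, S = 512, 8, 2, 128
    mha = MultiHeadAttention(d_model, n_head).cuda().to(torch.bfloat16)
    ffn = PositionwiseFeedForward(d_model, 4 * d_model).cuda().to(torch.bfloat16)
    x = torch.randn(B, S, d_model, device="cuda", dtype=torch.bfloat16)
    mask = torch.tril(torch.ones(S, S, device="cuda"))  # any non-None mask selects the causal path
    y = ffn(mha(x, x, x, mask=mask))
    print(y.shape)  # expected: torch.Size([2, 128, 512])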