Skip to content

Instantly share code, notes, and snippets.

@wolfecameron
Last active May 22, 2024 05:48
Show Gist options
  • Save wolfecameron/26863dbbc322b15d2e224a2569868256 to your computer and use it in GitHub Desktop.
Save wolfecameron/26863dbbc322b15d2e224a2569868256 to your computer and use it in GitHub Desktop.
Implementation of causal self-attention in PyTorch
"""
Source: https://github.com/karpathy/nanoGPT/blob/master/model.py
"""
import math
import torch
from torch import nn
import torch.nn.functional as F
class CausalSelfAttention(nn.Module):
    """Multi-head causal (masked) self-attention.

    Each token attends only to itself and to earlier positions in the
    sequence; attention to later positions is blocked by a lower-triangular
    mask registered as the buffer ``mask``.
    """

    def __init__(
        self,
        d,
        H,
        T,
        bias=False,
        dropout=0.2,
    ):
        """
        Arguments:
            d: size of embedding dimension
            H: number of attention heads (must divide d evenly)
            T: maximum length of input sequences (in tokens)
            bias: whether or not to use bias in linear layers
            dropout: probability of dropout
        """
        super().__init__()
        assert d % H == 0, "embedding dimension d must be divisible by number of heads H"
        # key, query, value projections for all heads, but in a batch
        # output is 3X the dimension because it includes key, query and value
        self.c_attn = nn.Linear(d, 3 * d, bias=bias)
        # projection of concatenated attention head outputs
        self.c_proj = nn.Linear(d, d, bias=bias)
        # dropout modules
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        self.H = H
        self.d = d
        # causal mask: entry [i, j] is 1 when j <= i, so each position can only
        # attend to itself and positions to its left. Shaped [1, 1, T, T] so it
        # broadcasts over the batch and head dimensions of the attention matrix.
        self.register_buffer(
            "mask", torch.tril(torch.ones(T, T)).view(1, 1, T, T)
        )

    def forward(self, x):
        """Apply causal multi-head self-attention.

        Arguments:
            x: input tensor of shape [B, T, d]; T must not exceed the maximum
               sequence length passed to the constructor.
        Returns:
            tensor of shape [B, T, d]
        Raises:
            ValueError: if the sequence length exceeds the maximum length.
        """
        B, T, _ = x.size()  # batch size, sequence length, embedding dimensionality
        max_T = self.mask.size(-1)
        if T > max_T:
            # the mask slice below would be smaller than the attention matrix
            raise ValueError(
                f"sequence length {T} exceeds maximum sequence length {max_T}"
            )
        head_dim = self.d // self.H
        # compute query, key, and value vectors for all heads in batch
        # split the output into separate query, key, and value tensors
        q, k, v = self.c_attn(x).split(self.d, dim=2)  # each [B, T, d]
        # reshape tensor into sequences of smaller token vectors for each head
        k = k.view(B, T, self.H, head_dim).transpose(1, 2)  # [B, H, T, d // H]
        q = q.view(B, T, self.H, head_dim).transpose(1, 2)
        v = v.view(B, T, self.H, head_dim).transpose(1, 2)
        # compute the attention matrix, perform masking, and apply dropout
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # [B, H, T, T]
        # use the registered causal mask (not `bias`, which is only a ctor arg)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        # compute output vectors for each token
        y = att @ v  # [B, H, T, d // H]
        # concatenate outputs from each attention head and linearly project
        y = y.transpose(1, 2).contiguous().view(B, T, self.d)
        y = self.resid_dropout(self.c_proj(y))
        return y
@meshed-potato
Copy link

Line 63 should use self.mask[:,:,:T,:T] instead of self.bias: the causal mask is registered as the buffer "mask", and bias is never stored as an attribute.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment