@paluchasz · Created February 18, 2024 18:10
Annotated attention module for CLIP
""" Annotated code for Transformers' CLIP implementation """
from typing import Optional, Tuple
import torch
from torch import nn
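
# --- Illustrative helper (my addition, not part of the Transformers source) -----------------------------------------
# CLIPAttention.forward below optionally takes a causal_attention_mask of shape (bs, 1, seq_length, seq_length). In
# Transformers this mask is built by the surrounding text model before the attention layer is called; the function
# below is only an illustrative stand-in showing what such a mask looks like: 0 on and below the diagonal (a token may
# attend to itself and to earlier tokens) and a very large negative number above it (future tokens), which the softmax
# turns into a ~0 attention weight.
def make_illustrative_causal_mask(bsz: int, seq_len: int, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)  # start with every position masked
    mask.masked_fill_(torch.tril(torch.ones(seq_len, seq_len)).bool(), 0.0)  # unmask current and past positions
    return mask[None, None, :, :].expand(bsz, 1, seq_len, seq_len)  # -> (bs, 1, seq_length, seq_length)
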
class CLIPAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
# Recall the formula is softmax(QK^T/√d_k)V. Note CLIP has a max sequence length of 77 in the original implementation so I use this as examples.
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size  # 512 in the original paper
        self.num_heads = config.num_attention_heads  # 8 heads in the original paper
        self.head_dim = self.embed_dim // self.num_heads  # 64 in the original paper
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads}).")
        self.scale = self.head_dim**-0.5  # This is 1/√d_k in the paper
        self.dropout = config.attention_dropout
        # Below are single linear layers acting as the projection matrices W_q, W_k, W_v. In the paper each of these
        # has dimensions (embed_dim, head_dim) = (512, 64) and there are num_heads = 8 of them. In this implementation
        # they are combined into one layer per projection, i.e. a single matrix of size (embed_dim, embed_dim) = (512, 512).
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        # The output projection W_o
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # This function takes in a tensor of size (bs, seq_length, embed_dim), first turns it into a tensor of size
        # (bs, seq_length, num_heads, head_dim) and then transposes it into shape (bs, num_heads, seq_length, head_dim).
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,  # of shape (batch_size, max_seq_length, embedding_dim)
        attention_mask: Optional[torch.Tensor] = None,  # when you have padding tokens you want to exclude them from
        # the attention calculation. See https://huggingface.co/docs/transformers/en/glossary for details.
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        bsz, tgt_len, embed_dim = hidden_states.size()
        # tgt_len is the max sequence length in the batch; for CLIP this cannot exceed 77. For the purpose of
        # illustration I use 77 in some of the examples below.

        # I made a slight modification to the code below to make it more readable. We apply the projections to get the
        # Q, K and V matrices out and use the _shape method to put them into the shape
        # (bs, num_heads, seq_length, head_dim) = (bs, 8, 77, 64). So although the linear layer stores all num_heads
        # projection matrices in one layer, here we extract the individual matrices for each head in order to apply
        # attention per head! (See the usage sketch at the bottom of the file for a quick shape check.)
        query_states = self._shape(self.q_proj(hidden_states) * self.scale, tgt_len, bsz)  # also divides by √d_k
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        # The below combines the bs and num_heads dimensions, turning each matrix (Q, K, V) from size
        # (bs, num_heads, seq_length, head_dim) into (bs * num_heads, seq_length, head_dim), essentially stacking the
        # matrices for all batches into one. This is done to be able to use the torch.bmm function
        # https://pytorch.org/docs/stable/generated/torch.bmm.html which expects 3D tensors.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = query_states.view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
        # bmm performs a batched matrix multiplication; this implements QK^T. Matrix multiply
        # (bs * num_heads, seq_length, head_dim) with (bs * num_heads, head_dim, seq_length), resulting in an output of
        # shape (bs * num_heads, seq_length, seq_length). I.e. for each text in the batch, for each head, compute an
        # n^2 attention score between every pair of tokens in the sequence.

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}")

        # Apply the causal_attention_mask first.
        # I believe the causal_attention_mask implements masked self-attention, as in the transformer decoder, to
        # preserve the autoregressive property of the model. As explained in the CLIP paper: "Masked self-attention was
        # used in the text encoder to preserve the ability to initialize with a pre-trained language model or add
        # language modeling as an auxiliary objective, though exploration of this is left as future work."
        # Why is it used? See: https://ai.stackexchange.com/questions/40917/what-if-we-drop-the-causal-mask-in-auto-regressive-transformer
        # (make_illustrative_causal_mask at the top of this file shows what such a mask looks like.)
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {causal_attention_mask.size()}")
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}")
            # This applies the attention mask. The purpose is to not attend to padding tokens, which are only used to
            # pad shorter sequences up to the maximum sequence length in the batch so that everything can be stored in
            # one tensor. The attention mask therefore contains 0 for all useful tokens that should be attended to, and
            # a very large negative number (effectively -inf) for padding tokens, which the softmax then ignores
            # (recall exp(-inf) = 0).
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        # Gives attn_weights = softmax(QK^T/√d_k). dim=-1 performs the softmax over the last dimension (e.g. the rows
        # in a 2D case) using negative indexing. The output shape remains (bs * num_heads, seq_length, seq_length).

        if output_attentions:
            # If output_attentions is True, the per-head attention weights are also returned to the caller (e.g. for
            # visualisation); it defaults to False.
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)  # Applies dropout, but only in training

        attn_output = torch.bmm(attn_probs, value_states)
        # Performs the final multiplication of the attn_weights with the value matrix V. So
        # (bs * num_heads, seq_length, seq_length) x (bs * num_heads, seq_length, head_dim), which results in an output
        # of shape (bs * num_heads, seq_length, head_dim).

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}")

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        # expands it back out into shape (bs, num_heads, seq_length, head_dim)
        attn_output = attn_output.transpose(1, 2)
        # transposes to (bs, seq_length, num_heads, head_dim)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
        # concatenates the heads back into (bs, seq_length, num_heads * head_dim) = (bs, seq_length, embed_dim) = (bs, 77, 512)
        attn_output = self.out_proj(attn_output)  # Final multiplication with the output matrix W_o. Keeps the shape as (bs, seq_length, embed_dim)

        return attn_output, attn_weights_reshaped
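
# --- Usage sketch (my addition, not part of the Transformers source) -------------------------------------------------
# A minimal smoke test for the shapes discussed in the comments above. In Transformers the module is built from a
# CLIPTextConfig; here I assume a SimpleNamespace with the three attributes the class actually reads (hidden_size,
# num_attention_heads, attention_dropout) is enough to stand in for it.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(hidden_size=512, num_attention_heads=8, attention_dropout=0.0)
    attn = CLIPAttention(config)

    bsz, seq_len = 2, 77
    hidden_states = torch.randn(bsz, seq_len, config.hidden_size)  # (bs, seq_length, embed_dim)
    causal_mask = make_illustrative_causal_mask(bsz, seq_len)      # (bs, 1, seq_length, seq_length)

    print(attn._shape(hidden_states, seq_len, bsz).shape)  # torch.Size([2, 8, 77, 64]) -> (bs, num_heads, seq_length, head_dim)
    attn_output, attn_weights = attn(hidden_states, causal_attention_mask=causal_mask, output_attentions=True)
    print(attn_output.shape)   # torch.Size([2, 77, 512])  -> (bs, seq_length, embed_dim)
    print(attn_weights.shape)  # torch.Size([2, 8, 77, 77]) -> (bs, num_heads, seq_length, seq_length)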