# Created February 18, 2024 18:10
# Annotated attention module for CLIP
""" Annotated code for Transformers' CLIP implementation """
from typing import Optional, Tuple
import torch
from torch import nn
class CLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Recall the formula is softmax(QK^T/√d_k)V. CLIP has a max sequence length of 77 in the original implementation, so 77 is used
    # in the shape examples below.

    def __init__(self, config):
        # nn.Module.__init__ must run before any submodule (the nn.Linear layers below) is assigned as an attribute; otherwise
        # nn.Module.__setattr__ raises "cannot assign module before Module.__init__() call".
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size  # e.g. 512 in the original CLIP text encoder
        self.num_heads = config.num_attention_heads  # e.g. 8 heads in the original CLIP text encoder
        self.head_dim = self.embed_dim // self.num_heads  # e.g. 512 // 8 = 64
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads}).")
        self.scale = self.head_dim**-0.5  # This is == 1/√(d_k) in the paper
        self.dropout = config.attention_dropout  # dropout probability applied to the attention probabilities
        # Below are single linear layers acting as the projection matrices W_k, W_v, W_q. In the paper each per-head matrix has
        # dimensions (embed_dim, head_dim) = (512, 64) and there are num_heads = 8 of them. Here all heads are combined into a single
        # layer, i.e. a single matrix of size (embed_dim, embed_dim) = (512, 512).
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        # The output projection W_o
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Split the last dimension into heads.

        Turns a (bs, seq_len, embed_dim) tensor into (bs, seq_len, num_heads, head_dim), then transposes it to
        (bs, num_heads, seq_len, head_dim). ``seq_len`` may be -1 to infer it. ``.contiguous()`` is needed so the caller
        can ``.view()`` the transposed result.
        """
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,  # of shape (batch_size, max_seq_length, embedding_dim)
        attention_mask: Optional[torch.Tensor] = None,  # additive mask used to ignore padding tokens in the attention calculation
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: input of shape (bsz, tgt_len, embed_dim).
            attention_mask: optional additive mask of shape (bsz, 1, tgt_len, src_len), added to the attention scores
                before the softmax (0 keeps a position, a large negative value such as -inf removes it).
            causal_attention_mask: optional additive mask of shape (bsz, 1, tgt_len, src_len), applied before
                ``attention_mask`` in the same additive way.
            output_attentions: if True, also return the post-softmax (pre-dropout) attention weights.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` has shape (bsz, tgt_len, embed_dim). ``attn_weights`` has shape
            (bsz, num_heads, tgt_len, src_len) when ``output_attentions`` is truthy, otherwise None.

        Raises:
            ValueError: if a mask or an intermediate tensor has an unexpected shape.
        """
        bsz, tgt_len, embed_dim = hidden_states.size()
        # tgt_len is the max sequence length in the batch; for CLIP this cannot exceed 77.
        # Apply the projections to get Q, K and V, then use _shape to put them into (bs, num_heads, seq_length, head_dim)
        # = (bs, 8, 77, 64). The single linear layer stores all num_heads matrices; here the per-head matrices are split out
        # so attention can be applied per head.
        query_states = self._shape(self.q_proj(hidden_states) * self.scale, tgt_len, bsz)  # also divides by √d_k
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
        # Merge the bs and num_heads dimensions: (bs, num_heads, seq_length, head_dim) -> (bs * num_heads, seq_length, head_dim),
        # because torch.bmm expects 3D tensors.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = query_states.view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)
        src_len = key_states.size(1)
        # Batched QK^T: (bs * num_heads, tgt_len, head_dim) @ (bs * num_heads, head_dim, src_len)
        # -> (bs * num_heads, tgt_len, src_len), i.e. an n^2 score matrix per text and per head.
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}")
        # Apply the causal_attention_mask first. This is masked self-attention as in a transformer decoder. The CLIP paper says:
        # "Masked self-attention was used in the text encoder to preserve the ability to initialize with a pre-trained language
        # model or add language modeling as an auxiliary objective, though exploration of this is left as future work."
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {causal_attention_mask.size()}")
            # The mask's singleton head dimension broadcasts over all heads.
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}")
            # Padding mask. Padding tokens only exist so shorter sequences can share one tensor, so they must not be attended
            # to. The mask is additive: 0 for tokens to attend to, a very large negative number (e.g. -inf) for padding. The
            # softmax then gives those positions ~0 probability (exp(-inf) = 0).
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        # softmax(QK^T/√d_k) over the last (key) dimension; the shape stays (bs * num_heads, tgt_len, src_len).
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        if output_attentions:
            # Return the per-head attention weights to the caller as (bs, num_heads, tgt_len, src_len).
            # This operation is a bit awkward, but it's required to make sure that attn_weights keeps its gradient.
            # To do so, attn_weights has to be reshaped twice and reused below.
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None
        # Dropout on the attention probabilities, only while the module is in training mode (module.train()).
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        # Multiply by V: (bs * num_heads, tgt_len, src_len) @ (bs * num_heads, src_len, head_dim)
        # -> (bs * num_heads, tgt_len, head_dim).
        attn_output = torch.bmm(attn_probs, value_states)
        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}")
        # Expand back out to (bs, num_heads, tgt_len, head_dim).
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        # Transpose to (bs, tgt_len, num_heads, head_dim).
        attn_output = attn_output.transpose(1, 2)
        # Concatenate the heads: (bs, tgt_len, num_heads * head_dim) = (bs, tgt_len, embed_dim) = (bs, 77, 512).
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
        # Final multiplication with the output matrix W_o; the shape stays (bs, tgt_len, embed_dim).
        attn_output = self.out_proj(attn_output)
        return attn_output, attn_weights_reshaped
