@jenkspt
Last active July 21, 2022 19:13
Demonstrate fix and parity of CLIP AttentionPool2d
"""
This gist demonstrates the equivalence between the existing CLIP `AttentionPool2d`
and the proposed `AttentionPool2dFix`, which only computes attention where needed.
"""
import torch
from torch import nn
import torch.nn.functional as F


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        # x has shape [(HW+1), B, C]
        print("X shape after attention:", x.shape)
        return x[0]
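
# The only functional change in AttentionPool2dFix is `query=x[:1]`: attention is
# computed for the pooled (mean) token alone, rather than for all HW+1 tokens of
# which the class above keeps only the first.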
class AttentionPool2dFix(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x[:1], key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
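        # x has shape [1, B, C]: attention was computed only for the pooled query token
        print("X shape after attention (fix):", x.shape)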
        return x.squeeze(0)


if __name__ == "__main__":
    batch_dim = 5
    spacial_dim = 8
    embed_dim = 16
    num_heads = 4
    output_dim = 2

    x = torch.randn(batch_dim, embed_dim, spacial_dim, spacial_dim)

    pool1 = AttentionPool2d(spacial_dim, embed_dim, num_heads, output_dim)
    y1 = pool1(x)
    assert y1.shape == (batch_dim, output_dim)

    pool2 = AttentionPool2dFix(spacial_dim, embed_dim, num_heads, output_dim)
    # Make sure parameter state is the same
    pool2.load_state_dict(pool1.state_dict())
    y2 = pool2(x)
    assert y2.shape == (batch_dim, output_dim)

    assert torch.allclose(y1, y2)
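
    # Rough, illustrative timing on a larger feature map (hypothetical sizes, not part
    # of the original parity check; results vary by hardware and are not a rigorous
    # benchmark): the fix should be cheaper because it builds a 1 x (HW+1) attention
    # matrix per head instead of (HW+1) x (HW+1).
    import time

    big_dim = 32
    big_x = torch.randn(batch_dim, embed_dim, big_dim, big_dim)
    slow_pool = AttentionPool2d(big_dim, embed_dim, num_heads, output_dim)
    fast_pool = AttentionPool2dFix(big_dim, embed_dim, num_heads, output_dim)
    fast_pool.load_state_dict(slow_pool.state_dict())

    with torch.no_grad():
        t0 = time.perf_counter()
        slow_pool(big_x)
        t1 = time.perf_counter()
        fast_pool(big_x)
        t2 = time.perf_counter()
    print(f"AttentionPool2d:    {t1 - t0:.4f} s")
    print(f"AttentionPool2dFix: {t2 - t1:.4f} s")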