import torch
import torch.nn.functional as F
from torch import nn, Tensor
from einops import rearrange

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        # fuse the queries, keys and values in one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        # project to queries, keys and values and split them into num_heads
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        # dot product between queries and keys over the head dimension
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  # batch, num_heads, query_len, key_len
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy = energy.masked_fill(~mask, fill_value)
        # scale the scores before the softmax
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        # attention-weighted sum of the values
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        # concatenate the heads back into the embedding dimension
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out
patches_embedded = PatchEmbedding()(x)
MultiHeadAttention()(patches_embedded).shape
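PatchEmbedding and x come from an earlier part of this series and are not defined in this gist; as a quick sanity check under that assumption, a random tensor with the expected (batch, num_tokens, emb_size) shape can stand in for the patch embeddings:

# minimal sanity check with a dummy patch-embedding tensor
# (batch=1, 197 tokens, emb_size=768 are illustrative values, not taken from the gist)
x = torch.randn(1, 197, 768)
out = MultiHeadAttention()(x)
print(out.shape)  # torch.Size([1, 197, 768])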