MultiHeadAttention Implementation using einops and PyTorch
'''
I am reading this amazing series (https://uvadlc-notebooks.readthedocs.io/en/latest). I always struggle when revisiting
my old code that has a lot of tensor manipulation, so I experimented with reimplementing their MultiHeadAttention layer
using einops syntax, which feels more human-readable.
'''
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, d_model):
        super().__init__()
        assert d_model % num_heads == 0, 'd_model must be divisible by num_heads'
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        # Single projection that produces queries, keys, and values in one matmul.
        self.w_qkv = nn.Linear(d_model, 3 * d_model)
        self.w_out = nn.Linear(d_model, d_model)
        # Apply the Xavier initialization defined below.
        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier initialization for the projection weights, zeros for the biases.
        nn.init.xavier_uniform_(self.w_qkv.weight)
        self.w_qkv.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.w_out.weight)
        self.w_out.bias.data.fill_(0)
    def forward(self, x, mask=None):
        # Project into query, key, and value space in one shot: (batch, seq_len, 3 * d_model).
        qkv = self.w_qkv(x)
        # Split the projection into heads; each head keeps its q, k, v chunks together.
        qkv = einops.rearrange(qkv,
                               'batch seq_len (num_heads head_dim) -> batch num_heads seq_len head_dim',
                               num_heads=self.num_heads,
                               head_dim=3 * self.d_head)
        # For each head, split back into query, key, and value.
        q, k, v = einops.rearrange(qkv,
                                   'batch num_heads seq_len (split head_dim) -> split batch num_heads seq_len head_dim',
                                   split=3)
        # Transpose the keys so the last dimension of q lines up with the second-to-last dimension of k.
        k = einops.rearrange(k,
                             'batch num_heads seq_len head_dim -> batch num_heads head_dim seq_len')
        # Scaled dot-product attention logits: (batch, num_heads, seq_len, seq_len).
        attention_logits = torch.matmul(q, k) / math.sqrt(q.size()[-1])
        if mask is not None:
            # Reshape so the mask broadcasts against the logits matrix.
            mask = einops.rearrange(mask, 'batch seq_len -> batch 1 1 seq_len')
            attention_logits = attention_logits.masked_fill(mask == 0, -9e15)
        attention_weights = F.softmax(attention_logits, dim=-1)
        attention = torch.matmul(attention_weights, v)
        # Merge the heads back into a single d_model-sized dimension.
        attention = einops.rearrange(attention,
                                     'batch num_heads seq_len head_dim -> batch seq_len (num_heads head_dim)')
        return self.w_out(attention), attention_weights
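
A quick sanity check, not part of the original gist: the sizes below are arbitrary placeholder values, and the mask simply marks the last position of every sequence as padding to show how broadcasting works.

# Minimal usage sketch (shapes and values here are illustrative assumptions).
batch, seq_len, d_model, num_heads = 2, 5, 128, 8

mha = MultiHeadAttention(num_heads=num_heads, d_model=d_model)
x = torch.randn(batch, seq_len, d_model)
mask = torch.ones(batch, seq_len, dtype=torch.long)  # 1 = attend, 0 = padding
mask[:, -1] = 0  # pretend the last position of each sequence is padding

out, weights = mha(x, mask=mask)
print(out.shape)      # torch.Size([2, 5, 128])
print(weights.shape)  # torch.Size([2, 8, 5, 5])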