@appliedml42
Last active January 7, 2022 15:39
MultiHeadAttention implementation using einops and PyTorch
'''
I am reading this amazing series (https://uvadlc-notebooks.readthedocs.io/en/latest). I always struggle with
revisiting my old code that has a lot of tensor manipulation, so I experimented with reimplementing their
MultiHeadAttention layer using einops syntax, which feels more human-readable.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, d_model):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0

        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.w_qkv = nn.Linear(d_model, 3 * d_model)
        self.w_out = nn.Linear(d_model, d_model)
        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier initialization for the projection weights, zeros for the biases.
        nn.init.xavier_uniform_(self.w_qkv.weight)
        self.w_qkv.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.w_out.weight)
        self.w_out.bias.data.fill_(0)

    def forward(self, x, mask=None):
        # Project into query, key, and value space in one shot.
        qkv = self.w_qkv(x)

        # Split into the different heads; each head keeps its query, key, and value chunks together for now.
        qkv = einops.rearrange(qkv,
                               'batch seq_len (num_heads head_dim) -> batch num_heads seq_len head_dim',
                               num_heads=self.num_heads,
                               head_dim=3 * self.d_head)

        # For each head, split back into query, key, and value.
        q, k, v = einops.rearrange(qkv,
                                   'batch num_heads seq_len (split head_dim) -> split batch num_heads seq_len head_dim',
                                   split=3)

        # Transpose the key so the last dimension of the query matches the second-to-last dimension of the key
        # for the dot product.
        k = einops.rearrange(k,
                             'batch num_heads seq_len head_dim -> batch num_heads head_dim seq_len')

        # Scaled dot-product attention logits.
        attention_logits = torch.matmul(q, k) / math.sqrt(q.size()[-1])

        if mask is not None:
            # Reshape so the mask broadcasts against the (batch, num_heads, seq_len, seq_len) logits.
            mask = einops.rearrange(mask,
                                    'batch seq_len -> batch 1 1 seq_len')
            attention_logits = attention_logits.masked_fill(mask == 0, -9e15)

        attention_weights = F.softmax(attention_logits, dim=-1)
        attention = torch.matmul(attention_weights, v)

        # Merge the heads back into a single d_model-sized dimension.
        attention = einops.rearrange(attention,
                                     'batch num_heads seq_len head_dim -> batch seq_len (num_heads head_dim)')

        return self.w_out(attention), attention_weights
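

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original gist: the sizes below are arbitrary assumptions,
    # chosen only to show the expected input and output shapes of the layer.
    batch_size, seq_len, d_model, num_heads = 2, 5, 64, 8

    mha = MultiHeadAttention(num_heads=num_heads, d_model=d_model)
    x = torch.randn(batch_size, seq_len, d_model)

    # Padding mask: 1 for real tokens, 0 for padded positions (here we pretend the last token is padding).
    mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    mask[:, -1] = 0

    out, weights = mha(x, mask=mask)
    print(out.shape)      # torch.Size([2, 5, 64])   -> (batch, seq_len, d_model)
    print(weights.shape)  # torch.Size([2, 8, 5, 5]) -> (batch, num_heads, seq_len, seq_len)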