@weilueluo
Created April 10, 2021 22:22
minimum multi-head attention implementation
import torch
from torch import nn


class MultiHeadAttention(nn.Module):
    def __init__(self, in_size, n_heads=1, scaled=True):
        super().__init__()
        # in_size = d_k in the paper
        self.scale = in_size ** 0.5 if scaled else 1
        self.n_heads = n_heads
        # one big linear layer computes all heads' projections at once
        self.q_linear = nn.Linear(in_size, in_size * n_heads)
        self.k_linear = nn.Linear(in_size, in_size * n_heads)
        self.v_linear = nn.Linear(in_size, in_size * n_heads)
        self.o_linear = nn.Linear(in_size * n_heads, in_size)

    def forward(self, x):
        batch_size, seq_len, in_size = x.shape
        # projection: (batch, seq, heads * d_k) -> (batch, heads, seq, d_k);
        # reshape to (batch, seq, heads, d_k) first, then swap the seq and
        # head axes -- reshaping straight to (batch, heads, seq, d_k) would
        # scramble the per-head features
        q = self.q_linear(x).reshape(batch_size, seq_len, self.n_heads, in_size).transpose(1, 2)
        k = self.k_linear(x).reshape(batch_size, seq_len, self.n_heads, in_size).transpose(1, 2)
        v = self.v_linear(x).reshape(batch_size, seq_len, self.n_heads, in_size).transpose(1, 2)
        # scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V
        attn = torch.matmul(q, k.transpose(2, 3)) / self.scale
        score = torch.softmax(attn, dim=-1)
        attn_out = torch.matmul(score, v)  # (batch, heads, seq, d_k)
        # concat heads: back to (batch, seq, heads * d_k)
        concatenated = attn_out.transpose(1, 2).reshape(batch_size, seq_len, -1)
        # output projection back to in_size
        projected = self.o_linear(concatenated)
        # residual connection
        out = x + projected
        return out
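A quick smoke test, as a minimal sketch (the sizes and names below are illustrative, not part of the gist): instantiate the module, push a random batch through it, and confirm the residual output keeps the input shape.

    # usage sketch: random batch of 2 sequences, length 10, feature size 64
    if __name__ == "__main__":
        torch.manual_seed(0)
        mha = MultiHeadAttention(in_size=64, n_heads=8)
        x = torch.randn(2, 10, 64)  # (batch, seq_len, in_size)
        out = mha(x)
        print(out.shape)  # torch.Size([2, 10, 64])

One design note: unlike the paper, which splits d_model evenly across heads, this implementation gives every head the full in_size dimensions, so the projection parameters grow linearly with n_heads.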