Skip to content

Instantly share code, notes, and snippets.

@hotbaby
Last active April 15, 2024 10:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hotbaby/0202783d3fcdd5e52b3509a6e3af4bb6 to your computer and use it in GitHub Desktop.
Save hotbaby/0202783d3fcdd5e52b3509a6e3af4bb6 to your computer and use it in GitHub Desktop.
PyTorch多头自注意力机制
# encoding: utf8
import math
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Projects the input into per-head query/key/value spaces, computes
    softmax(QK^T / sqrt(d)) V per head, then concatenates the heads and
    applies a final output projection.

    :param n_feat: total feature dimension (must be divisible by n_head)
    :param n_head: number of attention heads
    """

    def __init__(self, n_feat, n_head=4):
        super().__init__()
        if n_feat % n_head != 0:
            raise ValueError(f"n_feat ({n_feat}) must be divisible by n_head ({n_head})")
        # BUG FIX: was hard-coded to 4, silently ignoring the n_head argument.
        self.n_head = n_head
        self.n_dim = n_feat // n_head  # per-head dimension
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.out = nn.Linear(n_feat, n_feat)

    def forward(self, xs):
        """Apply self-attention.

        :param xs: tensor of shape (batch, seq_len, n_feat)
        :return: tensor of shape (batch, seq_len, n_feat)
        """
        batch = xs.size(0)
        # Project and split into heads: (batch, n_head, seq_len, n_dim)
        q = self.linear_q(xs).view(batch, -1, self.n_head, self.n_dim).transpose(1, 2)
        k = self.linear_k(xs).view(batch, -1, self.n_head, self.n_dim).transpose(1, 2)
        v = self.linear_v(xs).view(batch, -1, self.n_head, self.n_dim).transpose(1, 2)
        # BUG FIX: scale the logits BEFORE softmax (the original divided the
        # softmax output, so the weights no longer summed to 1 and the
        # sqrt(d) temperature had no effect on the distribution).
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.n_dim)
        attn = torch.softmax(scores, dim=-1)  # (batch, n_head, seq_len, seq_len)
        o = torch.matmul(attn, v)  # (batch, n_head, seq_len, n_dim)
        # Merge heads back: (batch, seq_len, n_feat)
        o = o.transpose(1, 2).contiguous().view(batch, -1, self.n_head * self.n_dim)
        return self.out(o)
@hotbaby
Copy link
Author

hotbaby commented Jan 17, 2023

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment