Last active
April 15, 2024 10:10
-
-
Save hotbaby/0202783d3fcdd5e52b3509a6e3af4bb6 to your computer and use it in GitHub Desktop.
PyTorch多头自注意力机制 (PyTorch multi-head self-attention mechanism)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# encoding: utf8
import math
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention (Vaswani et al., 2017, "Attention Is All You Need").

    Projects the input into per-head query/key/value spaces, computes
    scaled dot-product attention independently per head, concatenates the
    head outputs, and applies a final linear projection.

    Args:
        n_feat: total feature (model) dimension; must be divisible by n_head.
        n_head: number of attention heads (default 4).
    """

    def __init__(self, n_feat, n_head=4):
        super().__init__()
        if n_feat % n_head != 0:
            raise ValueError(
                f"n_feat ({n_feat}) must be divisible by n_head ({n_head})"
            )
        # BUG FIX: was hard-coded to 4, silently ignoring the n_head argument.
        self.n_head = n_head
        self.n_dim = n_feat // n_head  # per-head dimension
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.out = nn.Linear(n_feat, n_feat)

    def forward(self, xs):
        """Apply multi-head self-attention.

        Args:
            xs: input tensor of shape (batch, seq_len, n_feat).

        Returns:
            Tensor of shape (batch, seq_len, n_feat).
        """
        q = self.linear_q(xs)
        k = self.linear_k(xs)
        v = self.linear_v(xs)

        batch = xs.size(0)
        # Split features into heads: (batch, n_head, seq_len, n_dim)
        q = q.view(batch, -1, self.n_head, self.n_dim).transpose(1, 2)
        k = k.view(batch, -1, self.n_head, self.n_dim).transpose(1, 2)
        v = v.view(batch, -1, self.n_head, self.n_dim).transpose(1, 2)

        # BUG FIX: scale the logits BEFORE softmax (softmax(QK^T / sqrt(d))).
        # The original divided after softmax, so the attention weights no
        # longer summed to 1 and the scaling had no effect on the distribution.
        atten = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.n_dim)  # (batch, n_head, seq_len, seq_len)
        atten_score = torch.softmax(atten, dim=-1)  # rows sum to 1

        o = torch.matmul(atten_score, v)  # (batch, n_head, seq_len, n_dim)
        # Merge heads back: (batch, seq_len, n_feat)
        o = o.transpose(1, 2).contiguous().view(batch, -1, self.n_dim * self.n_head)
        return self.out(o)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
References