@Morriaty-The-Murderer
Last active January 9, 2024 11:08
tsalib demo
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from tsalib import dim_var

# Named dimension variables, declared as "Name(abbreviation):default_size"
Batch = dim_var("Batch(b):64")
Dimension = dim_var("Dimension(d):128")
Heads = dim_var("Heads(h):8")
MaxLength = dim_var("MaxLength(l):80")
SrcVocabSize = dim_var("SrcVocabSize(sv)")
TargetVocabSize = dim_var("TargetVocabSize(tv)")
class MultiHeadAttention(nn.Module):
    def __init__(self, heads, dimension, dropout=0.1):
        super().__init__()
        self.dimension = dimension
        self.d_k = dimension // heads
        self.heads = heads
        # Separate projections for queries, keys and values
        self.q_linear = nn.Linear(dimension, dimension)
        self.k_linear = nn.Linear(dimension, dimension)
        self.v_linear = nn.Linear(dimension, dimension)
        self.dropout = nn.Dropout(dropout)
        self.output = nn.Linear(dimension, dimension)

    def self_attention(self, q, k, v, mask=None):
        # Scaled dot-product attention over the per-head dimension d_k
        scores: (Batch, Heads, MaxLength, MaxLength) = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            mask: (Batch, 1, 1, MaxLength) = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)
        if self.dropout is not None:
            scores = self.dropout(scores)
        output: (Batch, Heads, MaxLength, Dimension // Heads) = torch.matmul(scores, v)
        return output
    def forward(self,
                q: (Batch, MaxLength, Dimension),
                k: (Batch, MaxLength, Dimension),
                v: (Batch, MaxLength, Dimension),
                mask: (Batch, 1, MaxLength) = None
                ):
        # Project, then split the model dimension into heads: (b, l, h*d) -> (b, h, l, d)
        q: (Batch, Heads, MaxLength, Dimension // Heads) = rearrange(self.q_linear(q),
                                                                     'b l (h d) -> b h l d', h=self.heads)
        k: (Batch, Heads, MaxLength, Dimension // Heads) = rearrange(self.k_linear(k),
                                                                     'b l (h d) -> b h l d', h=self.heads)
        v: (Batch, Heads, MaxLength, Dimension // Heads) = rearrange(self.v_linear(v),
                                                                     'b l (h d) -> b h l d', h=self.heads)
        scores: (Batch, Heads, MaxLength, Dimension // Heads) = self.self_attention(q, k, v, mask)
        # Merge the heads back into the model dimension: (b, h, l, d) -> (b, l, h*d)
        concat: (Batch, MaxLength, Dimension) = rearrange(scores, 'b h l d -> b l (h d)')
        output: (Batch, MaxLength, Dimension) = self.output(concat)
        return output
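
A minimal usage sketch (not part of the original gist), assuming the default sizes declared in the dim_vars above (Batch=64, MaxLength=80, Dimension=128, Heads=8); the variable names and the __main__ guard are illustrative only:

if __name__ == "__main__":
    batch, max_len, dimension, heads = 64, 80, 128, 8
    attention = MultiHeadAttention(heads=heads, dimension=dimension)
    # Random inputs shaped (Batch, MaxLength, Dimension)
    q = torch.randn(batch, max_len, dimension)
    k = torch.randn(batch, max_len, dimension)
    v = torch.randn(batch, max_len, dimension)
    # All-ones mask shaped (Batch, 1, MaxLength); it broadcasts against the
    # (Batch, Heads, MaxLength, MaxLength) score tensor after unsqueeze(1)
    mask = torch.ones(batch, 1, max_len)
    out = attention(q, k, v, mask)
    print(out.shape)  # expected: torch.Size([64, 80, 128])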