@rohitdavas
Last active January 22, 2023 04:57
Simple self-attention with vectorised form understanding
# %%
import torch
import torch.nn.functional as F # for using softmax
torch.manual_seed(0)  # for reproducibility
# somehow on my M1 Mac, the randomness is not reproducible. TODO: figure out why
b = 1 # mini batch size
t = 3 # sequence length
k = 2 # dimension of each vector in the sequence
# create data for consistent use in the notebook
X_batched = torch.randn(b, t, k)
X = torch.randn(t, k)
print(f"""
X_batched.shape: {X_batched.shape}
X.shape: {X.shape}
X_batched: {X_batched}
X: {X}
""")
# %% [markdown]
# # Simple self-attention
#
# ```python
#
# """
# ------------
# CONVENTIONS:
# ------------
#
# formula for the simple version of self-attention:
#
# 1. Given X = [x1,
#                x2,
#                ..,
#                ..,
#                ..,
#                xt]   where each xi is a vector of dimension k
#
#    shape of X is (t, k)
#
# 2. We want to compute the self-attention weights for each xi in X, where i goes from 0 to t-1.
#
# E.g.
#
#    y1,  y2,  y3,  y4,  y5,  y6
#     |    |    |    |    |    |
#    ----------------------------
#      self-attention weights
#    ----------------------------
#     |    |    |    |    |    |
#    x1,  x2,  x3,  x4,  x5,  x6
#
# y1 = W11 * x1 + W12 * x2 + W13 * x3 + W14 * x4 + W15 * x5 + W16 * x6
#
# W11 : weight for x1
# W12 : weight for x2
# W13 : weight for x3 and so on
#
# 3. Finding W
#
#    For a single output yi, we find the weight of each xj by taking the dot product of xi with
#    every time step of X, including xi itself, e.g. for the first row of W:
#
#    W1 = x1 . [x1, x2, x3, x4, x5, x6]^T
#
#    and then normalise the contribution of each weight with a softmax over the row.
#
# """
#
# ```
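# %% [markdown]
# Before writing any functions, here is a tiny worked example of the formula above (an illustration
# added for this write-up, not part of the computation below): for a hand-made sequence of t = 2
# vectors of dimension k = 2, the raw scores are the pairwise dot products, and a softmax over each
# row turns them into weights. `X_tiny` is a name introduced only for this example.
# %%
X_tiny = torch.tensor([[1.0, 0.0],
                       [0.0, 2.0]])   # t = 2 vectors, each of dimension k = 2
scores = X_tiny @ X_tiny.T            # [[x1.x1, x1.x2], [x2.x1, x2.x2]] = [[1, 0], [0, 4]]
weights = F.softmax(scores, dim=1)    # each row normalised so it sums to 1
print(scores)
print(weights)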
# %% [markdown]
# ## Simple for loop way
# %%
def calc_self_attention(X):
    # X is a sequence of vectors of shape (t, k)
    # Filling of the W matrix looks like this (shown here for t = 6):
    #   W11, W12, W13, W14, W15, W16   # weights for y1
    #   W21, W22, W23, W24, W25, W26   # weights for y2
    #   W31, W32, W33, W34, W35, W36   # weights for y3
    #   W41, W42, W43, W44, W45, W46   # weights for y4
    #   W51, W52, W53, W54, W55, W56   # weights for y5
    #   W61, W62, W63, W64, W65, W66   # weights for y6
    t = X.shape[0]  # sequence length
    W = torch.zeros(t, t)

    # calculate the raw W matrix: W[i][j] is the dot product of x_i and x_j
    for i in range(t):
        x_i = X[i]  # a vector of dimension k
        for j in range(t):
            W[i][j] = torch.dot(x_i, X[j])

    # normalise each row into self-attention weights with a softmax over dim 1
    W = F.softmax(W, dim=1)
    return W  # self-attention weights
W = calc_self_attention(X)

# now calculate the self-attention vectors: y_i is the W[i]-weighted sum of all time steps of X
Y = torch.zeros_like(X)
for i in range(t):
    for j in range(t):
        Y[i] += W[i][j] * X[j]

for i in range(t):
    inp = X
    weights = W[i]
    out = Y[i]
    print(f"input  : {inp}")
    print(f"weights: {weights}")
    print(f"output : {out}")
    print()
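# %% [markdown]
# A quick sanity check (added here as an aside, not from the original gist): since each row of `W`
# comes out of a softmax, every row should sum to 1, and each `Y[i]` should be the `W[i]`-weighted
# sum of the rows of `X`. The check below reuses the `W`, `X`, and `Y` computed above; `y0_manual`
# is a name introduced only for this check.
# %%
print(W.sum(dim=1))  # every row of the attention-weight matrix should sum to 1 (up to float error)
y0_manual = W[0, 0] * X[0] + W[0, 1] * X[1] + W[0, 2] * X[2]  # explicit weighted sum for time step 0
print(torch.allclose(Y[0], y0_manual))  # expected to print True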
# %% [markdown]
# ## Matrix multiplication way
# %%
# vectorized version
def calculate_y(X, W):
    """
    Parameters
    ----------
    X : torch.Tensor
        X is a sequence of vectors of shape t x k.
    W : torch.Tensor
        W is a matrix of shape t x t.

    Returns
    -------
    Y : torch.Tensor
        Y is a sequence of vectors of shape t x k.

    Notes
    -----
    W is the self-attention weight matrix, X is the input sequence of vectors,
    and Y is the output sequence of vectors, where
        t : sequence length
        k : dimension of each vector in the sequence

    W = [w11, w12, ..., w1t]
        [w21, w22, ..., w2t]
        [.. , .. , ..., .. ]
        [wt1, wt2, ..., wtt]

    Each row of W holds the self-attention weights for one vector in X,
    that is y1 = w11 * x1 + w12 * x2 + ... + w1t * xt, and so

    Y = [w11 * x1 + w12 * x2 + ... + w1t * xt]
        [w21 * x1 + w22 * x2 + ... + w2t * xt]
        [.. , .. , ..., .. ]
        [wt1 * x1 + wt2 * x2 + ... + wtt * xt]

    where each xi is a vector of dimension k, so Y is a sequence of vectors of shape t x k.
    Written out element by element (t rows, k columns):

    Y = [w11 * x11 + w12 * x21 + ... + w1t * xt1,  ...,  w11 * x1k + w12 * x2k + ... + w1t * xtk]
        [w21 * x11 + w22 * x21 + ... + w2t * xt1,  ...,  w21 * x1k + w22 * x2k + ... + w2t * xtk]
        [..                                      , ...,  ..                                     ]
        [wt1 * x11 + wt2 * x21 + ... + wtt * xt1,  ...,  wt1 * x1k + wt2 * x2k + ... + wtt * xtk]
    """
    Y = W @ X
    return Y  # self-attention vectors for a sequence
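# %% [markdown]
# To connect the matrix form back to the docstring: `W @ X` performs exactly the row-wise weighted
# sums written out above. A minimal sketch (an aside, not from the original gist) expressing the same
# product with `torch.einsum`, reusing the loop-built `W` from the previous section:
# %%
Y_matmul = calculate_y(X, W)                  # matrix-multiplication form
Y_einsum = torch.einsum('ij,jk->ik', W, X)    # the same contraction spelled out index by index
print(torch.allclose(Y_matmul, Y_einsum))     # expected to print True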
# %%
def calc_self_attention(X):
    """
    Parameters
    ----------
    X : torch.Tensor
        X is a sequence of vectors of shape t x k.

    Returns
    -------
    W : torch.Tensor
        W is a matrix of shape t x t.

    Notes
    -----
    W is the self-attention weight matrix and X is the input sequence of vectors, where
        t : sequence length
        k : dimension of each vector in the sequence

    The attention vectors are then computed as
        Y = W @ X
    so each row of W must hold the self-attention weights for one vector in X,
    that is y1 = w11 * x1 + w12 * x2 + ... + w1t * xt.

    Filling the W matrix (shown here for t = 6) looks like this:
        W11, W12, W13, W14, W15, W16   # weights for y1
            W11 : attention of x1 with x1
            W12 : attention of x1 with x2
            W13 : attention of x1 with x3
            and so on
        W21, W22, W23, W24, W25, W26   # weights for y2
            W21 : attention of x2 with x1
            W22 : attention of x2 with x2
            and so on

    Considering this, the whole W matrix can be calculated at once as
        W = X @ X.T
    and the contribution of each vector in X is normalised with a softmax over each row:
        W = F.softmax(W, dim=1)
    which gives
        y1 = W11 * x1 + W12 * x2 + ... + W1t * xt   # a k-dimensional vector
        y2 = W21 * x1 + W22 * x2 + ... + W2t * xt   # a k-dimensional vector
    """
    W = X @ X.T              # raw self-attention scores: W[i][j] = dot(x_i, x_j)
    W = F.softmax(W, dim=1)  # normalise each row of scores into weights
    return W
# %%
W_vectorised = calc_self_attention(X)
Y_vectorised = calculate_y(X, W_vectorised) # self-attention vectors for a sequence
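# %% [markdown]
# A small extra check (not in the original gist): the vectorised weight matrix should match the
# loop-built `W` from the earlier section, since both compute softmax-normalised dot products.
# %%
torch.allclose(W, W_vectorised)  # expected to evaluate to True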
# %%
Y_vectorised
# %%
Y
# %%
torch.allclose(Y, Y_vectorised)
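# %% [markdown]
# `X_batched` was created at the top of the notebook but never used. As a closing sketch (an
# extension added here, not part of the original gist), the same two steps apply per batch element:
# `@` on 3-D tensors performs a batched matrix multiplication, and `transpose(1, 2)` swaps the
# sequence and feature dimensions. `W_batched` and `Y_batched` are names introduced for this sketch.
# %%
W_batched = F.softmax(X_batched @ X_batched.transpose(1, 2), dim=2)  # shape (b, t, t)
Y_batched = W_batched @ X_batched                                    # shape (b, t, k)
print(W_batched.shape, Y_batched.shape)
# with b = 1, the batched result should match the unbatched functions applied to X_batched[0]
print(torch.allclose(Y_batched[0], calculate_y(X_batched[0], calc_self_attention(X_batched[0]))))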