Skip to content

Instantly share code, notes, and snippets.

@arif9799
Created August 9, 2023 21:33
Show Gist options
  • Save arif9799/0b652e1e3dcf7ec0f1db40b5a8f627ab to your computer and use it in GitHub Desktop.
Save arif9799/0b652e1e3dcf7ec0f1db40b5a8f627ab to your computer and use it in GitHub Desktop.
Attention is not enough! Multihead Attention in Transformers
from manim import *
from manim.utils.unit import Percent, Pixels
from colour import Color
import gensim
from gensim.models import Word2Vec
import numpy as np
import torch as t
# Logical frame size in scene units: a 16:9 canvas, eight units per aspect step.
config.frame_width = 16 * 8
config.frame_height = 9 * 8

# HIGH selects the output resolution: 3840x2160 when True, 1920x1080 otherwise.
HIGH = True
print("Its high" if HIGH else "Its low")
config.pixel_height, config.pixel_width = (2160, 3840) if HIGH else (1080, 1920)
################################################################################
# Creating Word Vectors #
################################################################################
def construct_Word_Embeddings(sentences, dim):
    """Train Word2Vec on `sentences` and return one embedding tensor per sentence.

    sentences: collection of tokenised sentences (each a sequence of words). It is
               iterated twice: once to train the model, once to look the words up.
    dim:       word-vector size, forwarded to Word2Vec as `vector_size`.

    Returns a list holding, for every sentence, the row-wise concatenation of its
    word vectors rounded to two decimals (one row per word).
    """
    model = Word2Vec(sentences, min_count=1, vector_size=dim, window=3)

    def embed(words):
        # Each vector is reshaped to a (1, -1) row so the rows can be concatenated.
        # NOTE(review): the positional `True` is presumably gensim's `norm` flag -- confirm.
        rows = [t.from_numpy(model.wv.get_vector(word, True).reshape(1, -1)) for word in words]
        return t.round(t.cat(rows, dim=0), decimals=2)

    return [embed(words) for words in sentences]
################################################################################
# Creating Positional Encoding Vector #
################################################################################
def construct_Positional_Encoding(input):
    """Add sinusoidal positional encodings to a tensor of embeddings.

    input: tensor whose last two axes are (sequence position, embedding dimension).

    Even embedding columns receive sin(pos / base**(2i/d)) and odd columns the
    matching cosine, with base = 10000. An odd embedding width is supported: the
    interleaved sin/cos table is cut back to `d` columns, so the last column is a
    sine without a cosine partner.

    Returns (x, pe, x_pe): the unchanged input, the encoding table rounded to two
    decimals, and their sum rounded to two decimals.
    """
    # The leading axis is added only so shape[-2] also exists for a 1-D input.
    x = input.unsqueeze(0)
    base = 10000
    input_dimension = x.shape[-1]
    seq_len = x.shape[-2]

    dimension_indices = t.arange(0, input_dimension, 2)
    denominator = t.pow(base, dimension_indices / input_dimension)
    positions = t.arange(0, seq_len, 1).reshape(seq_len, 1)
    angles = positions / denominator

    # stack(..., dim=2) pairs every sine with its cosine; flattening the last two
    # axes interleaves them as sin, cos, sin, cos, ...
    interleaved = t.stack([t.sin(angles), t.cos(angles)], dim=2)
    interleaved = interleaved.reshape(seq_len, 2 * dimension_indices.numel())
    # For odd widths the table has one column too many; drop the surplus cosine.
    pe = t.round(interleaved[:, :input_dimension], decimals=2)

    x = x.squeeze(0)
    x_pe = t.round(x + pe, decimals=2)
    return x, pe, x_pe
################################################################################
# FUNCTION TO BUILD MATRICES WITH DIMENSIONS AND CAPTIONS #
################################################################################
def _buildRowNames(names, matrixMobj, bracket, direction):
    """Create one MathTex label per matrix row, placed beside `bracket` towards `direction`.

    Every label gets an updater that re-anchors it to its row and bracket. The row
    and the name length are bound per label through default arguments, so the
    updater never falls back to the values of the last loop iteration.
    """
    labels = VGroup()
    for name, row in zip(names, matrixMobj.get_rows()):
        def place(label, row=row, nameLen=len(name)):
            anchor = [bracket.get_coord(0), row.get_coord(1), 0]
            # The offset grows with the label length and the current matrix height.
            return label.move_to(anchor).shift(direction * nameLen * matrixMobj.height * 0.06)

        label = place(MathTex(name))
        label.add_updater(place)
        labels.add(label)
    return labels


def buildMatrices(Rs,                    # Rs: number of rows of matrix to build
                  Cs,                    # Cs: number of cols of matrix to build
                  roundDecimals,         # roundDecimals: number of places to round off the values to
                  whereTo = None,        # whereTo: where to place the matrix near the reference mobject (DOWN, UP, LEFT or RIGHT)
                  referenceMobj = None,  # referenceMobj: the Mobject next to which the matrix is placed (and kept, via an updater)
                  arrayToMatrix = None,  # arrayToMatrix: the array (torch tensor or numpy array) to turn into a Matrix mobject
                  func = None,           # func: used when arrayToMatrix is None; called as func((Rs, Cs)) to generate the values
                  clr = WHITE,           # clr: color handed to the matrix through .set(color=clr)
                  cap = None,            # cap: caption or name of the matrix
                  hgt = 6,               # hgt: height factor of Matrix (vertical buffer)
                  wdt = 1.5,             # wdt: width factor of Matrix (horizontal buffer)
                  seed = 786,            # seed: torch seed, keeps func-generated values reproducible
                  lRowNames = None,      # row names on left side of the matrix -- if any
                  rRowNames = None,      # row names on right side of the matrix -- if any
                  colNames = None,       # column names of the matrix
                  ):
    """Build a Matrix mobject together with its dimension text, caption and labels.

    Returns None (after printing a message) when neither `arrayToMatrix` nor `func`
    is given. Otherwise returns `(group, values)`:
      * group  -- VGroup whose submobjects are, in order: the Matrix, the
                  "(rows, cols)" text, the caption (only if `cap` is given), then
                  the left row names, right row names and column names groups
                  (each only if given);
      * values -- the rounded tensor shown in the matrix, or the numpy array as
                  passed in.

    Raises TypeError when `arrayToMatrix` is neither a torch tensor nor a numpy array.
    """
    if arrayToMatrix is None and func is None:
        print("No Input to process and build a Matrix upon")
        return None

    # ------------> Set a random seed for weights initialization and build a matrix
    t.manual_seed(seed=seed)
    if arrayToMatrix is None:
        ini = t.round(func((Rs, Cs)), decimals=roundDecimals)
        entries = ini.numpy()
    elif isinstance(arrayToMatrix, t.Tensor):
        ini = t.round(arrayToMatrix, decimals=roundDecimals)
        entries = ini.numpy()
    elif isinstance(arrayToMatrix, np.ndarray):
        # numpy input (e.g. arrays of strings) is displayed as it is, without rounding.
        ini = arrayToMatrix
        entries = arrayToMatrix
    else:
        raise TypeError(
            "arrayToMatrix must be a torch.Tensor or a numpy.ndarray, got %s"
            % type(arrayToMatrix).__name__
        )
    iniMat = Matrix(matrix=entries, v_buff=hgt * Rs * 0.02, h_buff=wdt * Cs * 0.125).set(color=clr)

    # ------------> locate the matrix somewhere w.r.t reference mobject
    if referenceMobj is not None:
        # Without an explicit side, fall back to RIGHT instead of passing None on.
        side = RIGHT if whereTo is None else whereTo
        iniMat.next_to(referenceMobj, side)
        iniMat.add_updater(lambda x, y=referenceMobj: x.next_to(y, side))

    # ------------> Create the MASTER MATRIX, the VGroup and add the main matrix
    iniMatrix = VGroup()
    iniMatrix.add(iniMat)

    # ------------> Create dimensions text and append it to Matrix with updaters
    dimText = MathTex('(%d, %d)' % (Rs, Cs), font_size=40).move_to(iniMat.get_critical_point(DR) + DOWN/2)
    dimText.add_updater(lambda x, y=iniMat: x.move_to(y.get_critical_point(DR) + (DOWN * iniMat.height * 0.1)))
    iniMatrix.add(dimText)

    # ------------> Create caption if 'cap' is passed and append it to Matrix with updaters
    if cap is not None:
        capText = MathTex('%s' % (cap), font_size=64).move_to(iniMat.get_critical_point(DOWN) + DOWN*1.25)
        capText.add_updater(lambda x, y=iniMat: x.move_to(y.get_critical_point(DOWN) + DOWN*(iniMat.height * 0.4)))
        iniMatrix.add(capText)

    # ------------> Create left / right row names if passed and append them with updaters
    iniMatBracks = iniMat.get_brackets()
    if lRowNames is not None:
        iniMatrix.add(_buildRowNames(lRowNames, iniMat, iniMatBracks[0], LEFT))
    if rRowNames is not None:
        iniMatrix.add(_buildRowNames(rRowNames, iniMat, iniMatBracks[1], RIGHT))

    # ------------> Create column names if passed and append them with updaters
    if colNames is not None:
        # The initial buffer (height/5) and the updater buffer (height/7) differ in
        # the original as well; both values are kept unchanged.
        iniMatColNames = VGroup(*[
            MathTex(name)
            .next_to(col, UP, buff=iniMat.height/5)
            .add_updater(lambda x, y=col: x.next_to(y, UP, buff=iniMat.height/7))
            for name, col in zip(colNames, iniMat.get_columns())
        ])
        iniMatrix.add(iniMatColNames)

    return iniMatrix, ini
class Figure11(MovingCameraScene):
def construct(self):
# INTRO ANIMATION GOES HERE
def myAnimation(wordsForIntro: list):
    """Play the intro: a large blackboard-bold C, then each word written and unwritten after it.

    wordsForIntro: list of words; each is rendered character by character to the
                   right of the C, shown for half a second and then removed in
                   reverse order.

    The decorative SVG is loaded from a machine-specific absolute path; when that
    file is missing it is skipped instead of aborting the whole render.
    """
    from pathlib import Path

    fontHeight = config.frame_height//8
    fontColor = WHITE
    timePerChar = 0.1
    C = MathTex(r"\mathbb{C}", color=fontColor).scale(config.frame_height//3)
    self.play(Broadcast(C), run_time=1)
    self.add(C)

    svgPath = Path("/Users/arifwaghbakriwala/Desktop/Northeastern/Projects/Manimations/assets/svg/pl2.svg")
    svg2 = None
    if svgPath.is_file():
        svg2 = SVGMobject(str(svgPath),
                          height=config.frame_height,
                          width=config.frame_width,
                          stroke_color=MAROON,
                          stroke_width=7,
                          fill_color=BLUE,
                          fill_opacity=0
                          )
        self.play(Write(svg2))
        self.add(svg2)
    else:
        print("Intro SVG not found, skipping it: %s" % svgPath)

    # Building text objects of individual characters; each one is placed to the
    # right of the previous character (the first one right of C), bottom-aligned to C.
    wordsMath = VGroup()
    for word in wordsForIntro:
        charTex = VGroup()
        anchor = C
        for ch in word:
            chTex = MathTex(r"\mathbb{%s}" % ch, color=fontColor).scale(fontHeight)
            chTex.next_to(anchor, RIGHT, buff=0.05).align_to(C, DOWN)
            charTex.add(chTex)
            anchor = chTex
        wordsMath.add(charTex)

    # Succession and AnimationGroup both misbehaved here, hence one play() call per character.
    for wordTex in wordsMath:
        chars = list(wordTex)
        for chTex in chars:
            self.play(Write(chTex, run_time=timePerChar))
        self.wait(0.5)
        for chTex in reversed(chars):
            self.play(Unwrite(chTex, run_time=timePerChar))

    self.play(Circumscribe(C, color=MAROON_E, fade_out=False, time_width=2, shape=Circle, buff=1, stroke_width=config.frame_height//5, run_time=1.5))
    leaving = [C] if svg2 is None else [C, svg2]
    self.play(AnimationGroup(*[ShrinkToCenter(mob, run_time=0.75) for mob in leaving]))
    self.wait(0.5)
    self.remove(*leaving, wordsMath)

myAnimation(wordsForIntro=['OMPLEX', 'ONCEPTS', 'OMPREHENSIBLE'])
self.wait()
################################################################################
# FUNCTION TO ANIMATE MATRIX MULTIPLICATION #
################################################################################
def matrixMultiplicationAnimation(matrix_1: Matrix, matrix_2: Matrix, result: Matrix):
    """Animate `matrix_1 @ matrix_2` one result cell at a time.

    For every (row of matrix_1, column of matrix_2) pair, both operands are
    indicated while a copy of them is fade-transformed into the matching entry
    of `result`. Each cell takes 0.35 s.
    """
    # Grouping kept from the original; it only served the (disabled) camera move:
    # self.play(self.camera.frame.animate.move_to(matrices).set(width = matrices.width * 5, height = matrices.height*7))
    matrices = VGroup(matrix_1, matrix_2, result)
    lhsRows = matrix_1.get_rows()
    rhsCols = matrix_2.get_columns()
    outRows = result.get_rows()
    for r, lhsRow in enumerate(lhsRows):
        for c, rhsCol in enumerate(rhsCols):
            cellStep = AnimationGroup(
                Indicate(lhsRow),
                Indicate(rhsCol),
                FadeTransformPieces(mobject=VGroup(lhsRow, rhsCol).copy(), target_mobject=outRows[r][c]),
                lag_ratio=0,
            )
            self.play(cellStep, run_time=0.35)
################################################################################
# FUNCTION TO PARENTHESISE MATRICES #
################################################################################
def parenthesizeMatrix(matrix, encapsulateName: str, referenceMobject: Mobject = None, position = None, bf=0.5, initial_scale_factor: float = 1.5):
    """Wrap a matrix group in parentheses with a function name in front, e.g. `Softmax ( M )`.

    matrix:               VGroup whose first submobject is the Matrix (as built by buildMatrices)
    encapsulateName:      text placed left of the opening parenthesis (sine, Softmax, ...)
    referenceMobject,
    position:             accepted but currently unused (their animation is disabled)
    bf:                   buffer between matrix, parentheses and name
    initial_scale_factor: scale applied to the parentheses and the name

    Returns a new VGroup holding `matrix`, both parentheses and the name, in that order.
    """
    core = matrix[0]

    brackets = MathTex("(", ")")
    brackets.scale(initial_scale_factor)
    # The parentheses span the height of the whole group, not only of the Matrix.
    brackets.stretch_to_fit_height(matrix.height)
    opening, closing = brackets.split()
    opening.next_to(core, LEFT, buff=bf)
    closing.next_to(core, RIGHT, buff=bf)

    label = MathTex(encapsulateName)
    label.scale(initial_scale_factor)
    label.next_to(opening, LEFT, buff=bf)

    wrapped = VGroup()
    wrapped += matrix
    wrapped.add(opening, closing, label)
    return wrapped
##############################################################################################################################################################
# #
# #
# IT STARTS HERE!!!! #
# #
# #
##############################################################################################################################################################
# Model sizes and the toy corpus; only the first sentence is animated below.
embeddingDimension = 6
hiddenDimension = embeddingDimension
data = [
    ['Attention', 'is', 'not', 'enough'],
    ['Towards', 'Data', 'Science'],
    ['It', 'was', 'a', 'really', 'good', 'article'],
]
inputDimension = len(data[0])
sentences_Embedded = construct_Word_Embeddings(data, embeddingDimension)
# Third return value: embeddings with the positional encoding already added.
pos_encoded = construct_Positional_Encoding(sentences_Embedded[0])[2]
# ---------------> TITLE OF VIDEO
title = Title(
    'Multi-Head Attention (Exemplifying One Input)',
    match_underline_width_to_text=True,
    underline_buff=MED_LARGE_BUFF,
    font_size=256,
)
title.shift(DOWN*2)
self.play(DrawBorderThenFill(title), run_time=2)
# CREATING QUERY, KEY AND VALUE MATRICES (WITH CAPS)
# Three copies of the position-encoded input, labelled with words, positions and dimensions.
_lRowNames = data[0]
_rRowNames = ["position_%d"%i for i in range(embeddingDimension)]
_colNames = ["d_%d"%i for i in range(embeddingDimension)]
_colors = [RED_E, GREEN_E, BLUE_E]
_qkvCaps = ['query', 'key', 'value']
qkvWithCaps = VGroup(*[
    buildMatrices(Rs=pos_encoded.shape[0], Cs=pos_encoded.shape[1], roundDecimals=2,
                  arrayToMatrix=pos_encoded, clr=colour, cap=caption,
                  lRowNames=_lRowNames, rRowNames=_rRowNames, colNames=_colNames,
                  wdt=2, hgt=10)[0].scale(2)
    for colour, caption in zip(_colors, _qkvCaps)
])
qkvWithCaps.arrange(DOWN, buff=4).to_edge(LEFT*4)
self.play(Write(qkvWithCaps))
self.wait()
# CREATING QUERY, KEY AND VALUE MATRICES (WITHOUT CAPS - LESS DETAILS)
# Same three matrices, keeping only the dimension text and the caption.
_lRowNames = _rRowNames = _colNames = None
_colors = [RED_E, GREEN_E, BLUE_E]
_qkvCaps = ['query', 'key', 'value']
qkvWithoutCaps = VGroup(*[
    buildMatrices(Rs=pos_encoded.shape[0], Cs=pos_encoded.shape[1], roundDecimals=2,
                  arrayToMatrix=pos_encoded, clr=colour, cap=caption,
                  wdt=2, hgt=10)[0].scale(2)
    for colour, caption in zip(_colors, _qkvCaps)
])
qkvWithoutCaps.arrange(DOWN, buff=8).to_edge(LEFT*6)
self.play(FadeTransformPieces(qkvWithCaps, qkvWithoutCaps))
self.wait()
query, key, value = qkvWithoutCaps
# CREATING WEIGHT MATRICES TO PROJECT QUERY, KEY AND VALUE
# Square weight matrices sampled uniformly from [-1, 1); seeds 787, 788 and 789.
_lRowNames = _rRowNames = _colNames = None
_colors = [RED_E, GREEN_E, BLUE_E]
_qkvCaps = ['Query\:_{Weights}', 'Key\:_{Weights}', 'Value\:_{Weights}']
uniform = t.distributions.uniform.Uniform(low=t.tensor(-1.0),high=t.tensor(1.0))
weightsQKV = [
    buildMatrices(Rs=hiddenDimension, Cs=hiddenDimension, roundDecimals=2,
                  func=uniform.sample, clr=colour, cap=caption,
                  wdt=2, hgt=8, seed=786 + offset)
    for offset, (colour, caption) in enumerate(zip(_colors, _qkvCaps), start=1)
]
# Each weight matrix sits next to the input matrix it will multiply.
for weights, source in zip(weightsQKV, qkvWithoutCaps):
    weights[0].scale(2).next_to(source)
self.play(Write(VGroup(*[weights[0] for weights in weightsQKV])))
self.wait()
# PROJECTING QUERY, KEY AND VALUE USING WEIGHTS MATRICES, THEN PAINTING THE COLUMNS OF
# EVERY PROJECTED MATRIX IN ONE COLOR PER HEAD
dimPerHead = 2
_lRowNames = _rRowNames = None
_colNames = ["d_%d"%i for i in range(embeddingDimension)]
_colors = [RED_E, GREEN_E, BLUE_E]
# One color per head, repeated for each of the head's columns.
_multiColors = list(np.repeat([TEAL_E, YELLOW_E, ORANGE], dimPerHead))
_qkvCaps = ['(Q)', '(K)', '(V)']
_projQKV = [pos_encoded @ weights[1] for weights in weightsQKV]
projQKV = [
    buildMatrices(Rs=projection.shape[0], Cs=projection.shape[1], roundDecimals=2,
                  arrayToMatrix=projection, clr=colour, cap=caption,
                  colNames=_colNames, wdt=2, hgt=10)
    for projection, colour, caption in zip(_projQKV, _colors, _qkvCaps)
]
#- place projected matrices in their positions
for projected, weights in zip(projQKV, weightsQKV):
    projected[0].scale(2).next_to(weights[0], RIGHT*4)
#- transform (input, weights) into the projected matrix
for source, weights, projected in zip(qkvWithoutCaps, weightsQKV, projQKV):
    matrices = VGroup(source, weights[0], projected[0])
    self.play(FadeTransformPieces(matrices[0:2], matrices[2]))
#- move projected matrices to left
moveAG = AnimationGroup(*[projected[0].animate.to_edge(LEFT*8) for projected in projQKV], lag_ratio=0.2)
self.play(moveAG)
self.wait()
# ------ different colors to columns to represent multiple heads within a single projected matrix
agOverall = AnimationGroup(
    *[
        AnimationGroup(
            *[col.animate.set_color(colour)
              for col, colour in zip(projected[0][0].get_columns(), _multiColors)],
            lag_ratio=0.75,
        )
        for projected in projQKV
    ],
    lag_ratio=0,
)
self.play(agOverall)
self.wait()
# SPLITTING QUERY MATRICES
_lRowNames = _rRowNames = None
_colNames = ["d_%d"%i for i in range(embeddingDimension)]
_colors = [RED_E, GREEN_E, BLUE_E]
_qkvCaps = ['(Q_1)', '(Q_2)', '(Q_3)']
# Cut the projected query into chunks of dimPerHead columns and stack them per head.
_querySplit = t.stack(t.split(projQKV[0][1], dimPerHead, dim=-1))
querySplit = []
for head, part in enumerate(_querySplit):
    span = slice(head * dimPerHead, (head + 1) * dimPerHead)
    querySplit.append(buildMatrices(Rs=part.shape[0], Cs=part.shape[1], roundDecimals=2,
                                    arrayToMatrix=part, clr=_multiColors[span], cap=_qkvCaps[head],
                                    colNames=_colNames[span], wdt=6, hgt=10))
# The first head goes right of the projected matrix, every further head right of the previous one.
for i in range(len(querySplit)):
    anchor, gap = (projQKV[0][0], RIGHT*8) if i == 0 else (querySplit[i-1][0], RIGHT*4)
    querySplit[i][0].scale(2).next_to(anchor, gap)
    headCols = projQKV[0][0][0].get_columns()[i * dimPerHead : (i+1) * dimPerHead].copy()
    self.play(FadeTransformPieces(headCols, querySplit[i][0]), run_time=2)
queryMatrices = VGroup(*[q[0] for q in querySplit])
# SPLITTING KEY MATRICES
_lRowNames = _rRowNames = None
_colNames = ["d_%d"%i for i in range(embeddingDimension)]
_colors = [RED_E, GREEN_E, BLUE_E]
_qkvCaps = ['(K_1)', '(K_2)', '(K_3)']
# Cut the projected key into chunks of dimPerHead columns and stack them per head.
_keySplit = t.stack(t.split(projQKV[1][1], dimPerHead, dim=-1))
keySplit = []
for head, part in enumerate(_keySplit):
    span = slice(head * dimPerHead, (head + 1) * dimPerHead)
    keySplit.append(buildMatrices(Rs=part.shape[0], Cs=part.shape[1], roundDecimals=2,
                                  arrayToMatrix=part, clr=_multiColors[span], cap=_qkvCaps[head],
                                  colNames=_colNames[span], wdt=6, hgt=10))
# The first head goes right of the projected matrix, every further head right of the previous one.
for i in range(len(keySplit)):
    anchor, gap = (projQKV[1][0], RIGHT*8) if i == 0 else (keySplit[i-1][0], RIGHT*4)
    keySplit[i][0].scale(2).next_to(anchor, gap)
    headCols = projQKV[1][0][0].get_columns()[i * dimPerHead : (i+1) * dimPerHead].copy()
    self.play(FadeTransformPieces(headCols, keySplit[i][0]), run_time=2)
keyMatrices = VGroup(*[q[0] for q in keySplit])
# TRANSPOSING THE SPLIT KEY MATRICES
_lRowNames = ["d_%d "%i for i in range(embeddingDimension)]
_rRowNames = _colNames = None
_colors = [RED_E, GREEN_E, BLUE_E]
_qkvCaps = ['(K^T_{1})', '(K^T_{2})', '(K^T_{3})']
# Same per-head split as for the keys, with the last two axes swapped.
_keyTSplit = t.stack(t.split(projQKV[1][1], dimPerHead, dim=-1)).transpose(-2, -1)
keyTSplit = []
for head, part in enumerate(_keyTSplit):
    keyTSplit.append(buildMatrices(Rs=part.shape[0], Cs=part.shape[1], roundDecimals=2,
                                   arrayToMatrix=part, clr=_multiColors[head * dimPerHead], cap=_qkvCaps[head],
                                   lRowNames=_lRowNames[head * dimPerHead : (head + 1) * dimPerHead],
                                   wdt=4, hgt=18))
# The first transposed head goes right of all key heads, the others right of the previous one.
for i in range(len(querySplit)):
    anchor, gap = (keyMatrices, RIGHT*8) if i == 0 else (keyTSplit[i-1][0], RIGHT*4)
    keyTSplit[i][0].scale(2).next_to(anchor, gap)
    self.play(FadeTransformPieces(keySplit[i][0], keyTSplit[i][0]), run_time=2)
keyTMatrices = VGroup(*[q[0] for q in keyTSplit])
# ARRANGING THE SPLIT QUERY AND TRANSPOSED KEY MATRICES SIDE BY SIDE, ONE PAIR PER HEAD
allProjected = VGroup(projQKV[0][0], projQKV[1][0], projQKV[2][0])
self.play(queryMatrices.animate.arrange(DOWN, buff=2).next_to(allProjected, RIGHT*16))
for keyT, queryHead in zip(keyTMatrices, queryMatrices):
    self.play(keyT.animate.next_to(queryHead, RIGHT*8))
# CALCULATING ATTENTION HEADS
_lRowNames = _rRowNames = _colNames = None
_colors = [RED_E, GREEN_E, BLUE_E]
_qkvCaps = ['Attention \: Head\:_{%d} \: (Q_%d \: . \: K^{T}_{%d})'%(n, n, n) for n in range(1, len(querySplit) + 1)]
# Raw scores per head: (rounded) query head times (rounded) transposed key head.
_attentionHeads = [queryHead[1] @ keyT[1] for queryHead, keyT in zip(querySplit, keyTSplit)]
attentionHeads = [
    buildMatrices(Rs=scores.shape[0], Cs=scores.shape[1], roundDecimals=2,
                  arrayToMatrix=scores, clr=_multiColors[head * dimPerHead], cap=_qkvCaps[head],
                  wdt=4, hgt=10)
    for head, scores in enumerate(_attentionHeads)
]
for i, (headGroup, _scores) in enumerate(attentionHeads):
    headGroup.scale(2).next_to(keyTMatrices[i], RIGHT*8)
    # Draw the empty shell first (brackets, dimension text, caption); the entries
    # are filled in by the multiplication animation.
    headMatrix = headGroup[0]
    self.play(Write(VGroup(headMatrix.get_brackets(), headGroup[1], headGroup[2])), run_time=0.25)
    matrixMultiplicationAnimation(matrix_1=queryMatrices[i][0], matrix_2=keyTMatrices[i][0], result=headMatrix)
    self.play(FadeOut(queryMatrices[i], keyTMatrices[i]))
# SCALING ATTENTION HEADS v1
# Every score is shown symbolically as "<score>/\sqrt{d_{k}}".
_lRowNames = _rRowNames = _colNames = None
_qkvCaps = ['Scaled\: Attention \: Head\:_{%d} \: \\frac{(Q_{%d}\: . \:K^{T}_{%d})}{\sqrt{d_k}}'%(n, n, n) for n in range(1, len(querySplit) + 1)]
_attHeadsScaled = [t.round(head[1], decimals=2).numpy().astype('str') for head in attentionHeads]
_attHeadsScaledArray = [np.vectorize(lambda x: x+'/\sqrt{d_{k}}')(cells) for cells in _attHeadsScaled[:len(attentionHeads)]]
attHeadsScaled_v1 = [
    buildMatrices(Rs=cells.shape[0], Cs=cells.shape[1], roundDecimals=2,
                  arrayToMatrix=cells, clr=_multiColors[head * dimPerHead], cap=_qkvCaps[head],
                  wdt=7, hgt=12)
    for head, cells in enumerate(_attHeadsScaledArray)
]
for i in range(len(attHeadsScaled_v1)):
    builtGroup, cells = attHeadsScaled_v1[i]
    wrapped = parenthesizeMatrix(matrix=builtGroup, encapsulateName='Scaled', bf=0.9)
    attHeadsScaled_v1[i] = (wrapped, cells)
    # Put the wrapped matrix exactly where the raw head currently is, then morph.
    previous = attentionHeads[i][0]
    wrapped.scale(2).set_x(previous.get_x()).set_y(previous.get_y()).set_z(previous.get_z())
    self.play(FadeTransformPieces(previous, wrapped))
# SCALING ATTENTION HEADS v2
# Same as v1, with d_k replaced by its value: "<score>/\sqrt{<dimPerHead>}".
_lRowNames = _rRowNames = _colNames = None
_qkvCaps = ['Scaled \:Attention \: Head\:_{%d} \: \\frac{(Q_{%d}\: . \:K^{T}_{%d})}{\sqrt{d_k}}'%(n, n, n) for n in range(1, len(querySplit) + 1)]
_attHeadsScaled = [t.round(head[1], decimals=2).numpy().astype('str') for head in attentionHeads]
_attHeadsScaledArray = [np.vectorize(lambda x: x+'/\sqrt{%d}'%dimPerHead)(cells) for cells in _attHeadsScaled[:len(attentionHeads)]]
attHeadsScaled_v2 = [
    buildMatrices(Rs=cells.shape[0], Cs=cells.shape[1], roundDecimals=2,
                  arrayToMatrix=cells, clr=_multiColors[head * dimPerHead], cap=_qkvCaps[head],
                  wdt=7, hgt=12)
    for head, cells in enumerate(_attHeadsScaledArray)
]
for i in range(len(attHeadsScaled_v2)):
    builtGroup, cells = attHeadsScaled_v2[i]
    wrapped = parenthesizeMatrix(matrix=builtGroup, encapsulateName='Scaled', bf=0.9)
    attHeadsScaled_v2[i] = (wrapped, cells)
    # Put the wrapped matrix exactly where the v1 version currently is, then morph.
    previous = attHeadsScaled_v1[i][0]
    wrapped.scale(2).set_x(previous.get_x()).set_y(previous.get_y()).set_z(previous.get_z())
    self.play(FadeTransformPieces(previous, wrapped))
# SCALING RESULT OF MATMUL BLOCK v3
# The symbolic "<score>/sqrt(d_k)" entries are replaced by their numeric values.
_lRowNames = _rRowNames = _colNames = None
_qkvCaps = [r'\frac{(Q_{%d}\: . \:K^{T}_{%d})}{\sqrt{d_k}}' % (n, n) for n in range(1, len(querySplit) + 1)]
# d_k is the per-head width (dimPerHead): this is the divisor the captions and the
# previous frame ("/\sqrt{dimPerHead}") announce. Dividing by hiddenDimension, as
# the code did before, showed numbers that contradicted the displayed formula.
_attHeadsScaledArray = [t.round(head[1] / (dimPerHead ** 0.5), decimals=2) for head in attentionHeads]

def _buildScaledHead(head):
    # Build the numeric scaled-score matrix of one head.
    scaled = _attHeadsScaledArray[head]
    return buildMatrices(Rs=scaled.shape[0], Cs=scaled.shape[1], roundDecimals=2,
                         arrayToMatrix=scaled, clr=_multiColors[head * dimPerHead], cap=_qkvCaps[head],
                         wdt=4, hgt=10)

# First with the "Scaled( ... )" wrapper, morphing from the v2 version in place ...
attHeadsScaled_v3 = [_buildScaledHead(head) for head in range(len(_attHeadsScaledArray))]
for i in range(len(attHeadsScaled_v3)):
    wrapped = parenthesizeMatrix(matrix=attHeadsScaled_v3[i][0], encapsulateName='Scaled', bf=0.9)
    attHeadsScaled_v3[i] = (wrapped, attHeadsScaled_v3[i][1])
    previous = attHeadsScaled_v2[i][0]
    wrapped.scale(2).set_x(previous.get_x()).set_y(previous.get_y()).set_z(previous.get_z())
    self.play(FadeTransformPieces(previous, wrapped))

# ... then as a plain matrix without the wrapper; these are used from here on.
attHeadsScaled = [_buildScaledHead(head) for head in range(len(_attHeadsScaledArray))]
for i in range(len(attHeadsScaled_v3)):
    previous = attHeadsScaled_v3[i][0]
    attHeadsScaled[i][0].scale(2).set_x(previous.get_x()).set_y(previous.get_y()).set_z(previous.get_z())
    self.play(FadeTransformPieces(previous, attHeadsScaled[i][0]))

# Drop the intermediate versions of every head (previously only head 0 was removed).
self.remove(*[entry[0] for entry in attHeadsScaled_v1 + attHeadsScaled_v2 + attHeadsScaled_v3])
del attHeadsScaled_v1
del attHeadsScaled_v2
del attHeadsScaled_v3
# CONVERTING SCALED PROJECTED MATRICES TO SOFTMAX MATRICES
_lRowNames = _rRowNames = _colNames = None
_qkvCaps = ["Softmax \: \\biggl( \\frac{(Q_{%d}\: . \:K^{T}_{%d})}{\sqrt{d_k}} \\biggr)"%(n, n) for n in range(1, len(querySplit) + 1)]
# Row-wise softmax of the (rounded) scaled scores of every head.
scaledSoftmax = [t.softmax(scaledHead[1], dim=-1) for scaledHead in attHeadsScaled]
scaledQKSoftmax = [
    buildMatrices(Rs=probs.shape[0], Cs=probs.shape[1], roundDecimals=2,
                  arrayToMatrix=probs, clr=_multiColors[head * dimPerHead], cap=_qkvCaps[head],
                  wdt=4, hgt=10)
    for head, probs in enumerate(scaledSoftmax)
]
for i in range(len(scaledQKSoftmax)):
    scaledGroup = attHeadsScaled[i][0]
    # Step 1: show the scaled scores wrapped as "Softmax( ... )".
    parenSoft = parenthesizeMatrix(matrix=scaledGroup, encapsulateName='Softmax', bf=0.9, initial_scale_factor=2)
    parenSoft.set_x(scaledGroup.get_x()).set_y(scaledGroup.get_y()).set_z(scaledGroup.get_z())
    self.play(FadeTransformPieces(scaledGroup, parenSoft))
    self.wait()
    # Step 2: replace the wrapped scores by the softmax values at the same place.
    softmaxGroup = scaledQKSoftmax[i][0]
    softmaxGroup.scale(2).set_x(parenSoft.get_x()).set_y(parenSoft.get_y()).set_z(parenSoft.get_z())
    self.play(FadeTransformPieces(parenSoft, softmaxGroup))
    self.wait()
    self.remove(parenSoft)
    del parenSoft
tempscaledQKSoftmax = VGroup(*[built[0] for built in scaledQKSoftmax])
queryAndKey = VGroup(projQKV[0][0], projQKV[1][0])
self.play(tempscaledQKSoftmax.animate.arrange(DOWN, buff=4).next_to(queryAndKey, RIGHT*48))
# SPLITTING VALUE MATRICES
_lRowNames = _rRowNames = None
_colNames = ["d_%d"%i for i in range(embeddingDimension)]
_colors = [RED_E, GREEN_E, BLUE_E]
_qkvCaps = ['(V_1)', '(V_2)', '(V_3)']
# Cut the projected value into chunks of dimPerHead columns and stack them per head.
_valueSplit = t.stack(t.split(projQKV[2][1], dimPerHead, dim=-1))
valueSplit = []
for head, part in enumerate(_valueSplit):
    span = slice(head * dimPerHead, (head + 1) * dimPerHead)
    valueSplit.append(buildMatrices(Rs=part.shape[0], Cs=part.shape[1], roundDecimals=2,
                                    arrayToMatrix=part, clr=_multiColors[span], cap=_qkvCaps[head],
                                    colNames=_colNames[span], wdt=6, hgt=10))
# The first head goes right of the projected matrix, every further head right of the previous one.
for i in range(len(valueSplit)):
    anchor, gap = (projQKV[2][0], RIGHT*8) if i == 0 else (valueSplit[i-1][0], RIGHT*4)
    valueSplit[i][0].scale(2).next_to(anchor, gap)
    headCols = projQKV[2][0][0].get_columns()[i * dimPerHead : (i+1) * dimPerHead].copy()
    self.play(FadeTransformPieces(headCols, valueSplit[i][0]), run_time=2)
valueMatrices = VGroup(*[q[0] for q in valueSplit])
# Pair every value head with its softmax matrix, spread the softmax column out,
# shift it down, then re-attach the value heads to the moved softmax matrices.
for i in range(len(valueMatrices)):
    self.play(valueMatrices[i].animate.next_to(scaledQKSoftmax[i][0], RIGHT*8))
queryAndKey = VGroup(projQKV[0][0], projQKV[1][0])
self.play(tempscaledQKSoftmax.animate.arrange(DOWN, buff=7).next_to(queryAndKey, RIGHT*48))
self.play(tempscaledQKSoftmax.animate.shift(DOWN*10))
for i in range(len(valueMatrices)):
    self.play(valueMatrices[i].animate.next_to(scaledQKSoftmax[i][0], RIGHT*8))
# CALCULATING SOFTMAX RESULTS WITH VALUE MATRICES
_lRowNames = _rRowNames = None
_colNames = ["d_%d"%i for i in range(embeddingDimension)]
_colors = [RED_E, GREEN_E, BLUE_E]
_qkvCaps = ["Softmax \: \\biggl( \\frac{(Q_{%d}\: . \:K^{T}_{%d})}{\sqrt{d_k}} \\biggr) \: . \: V_{%d}"%(n, n, n) for n in range(1, len(querySplit) + 1)]
# Per head: (rounded) softmax matrix times (rounded) value head.
_sftDotV = [probs[1] @ valueHead[1] for probs, valueHead in zip(scaledQKSoftmax, valueSplit)]
sftDotV = [
    buildMatrices(Rs=product.shape[0], Cs=product.shape[1], roundDecimals=2,
                  arrayToMatrix=product, clr=_multiColors[head * dimPerHead], cap=_qkvCaps[head],
                  colNames=_colNames[head * dimPerHead : (head + 1) * dimPerHead],
                  wdt=6, hgt=16)
    for head, product in enumerate(_sftDotV)
]
# The index loop is kept on purpose: the loop variable `i` is read again further
# down in this method, after this loop has finished.
for i in range(len(sftDotV)):
    productGroup = sftDotV[i][0]
    productGroup.scale(2).next_to(valueMatrices[i], RIGHT*8)
    # Draw the empty shell first; the entries come from the multiplication animation.
    self.play(Write(VGroup(productGroup[0].get_brackets(), productGroup[1], productGroup[2])), run_time=0.25)
    matrixMultiplicationAnimation(matrix_1=scaledQKSoftmax[i][0][0], matrix_2=valueMatrices[i][0], result=productGroup[0])
    self.play(FadeOut(scaledQKSoftmax[i][0], valueMatrices[i]))
tempsftDotVT = VGroup(*[built[0] for built in sftDotV])
self.play(tempsftDotVT.animate.arrange(RIGHT, buff=4).move_to(ORIGIN))
self.wait()
# CONCATENATED RESULT
_lRowNames = _rRowNames = None
_colNames = ["d_%d "%i for i in range(embeddingDimension)]
_colors = [RED_E, GREEN_E, BLUE_E]
_qkvCaps = ["Softmax \: \\biggl( \\frac{(Q \: . \:K^{T})}{\sqrt{d_k}} \\biggr) \: . \: V"]
# Heads are joined side by side, i.e. along the column axis.
_concatSftDotV = t.hstack(_sftDotV)
concatSftDotV = buildMatrices(Rs=_concatSftDotV.shape[0], Cs=_concatSftDotV.shape[1], roundDecimals=2,
                              arrayToMatrix=_concatSftDotV, clr=BLUE_E, cap=_qkvCaps[0],
                              colNames=_colNames, wdt=2, hgt=10)
allProjected = VGroup(projQKV[0][0], projQKV[1][0], projQKV[2][0])
concatSftDotV[0].scale(2).next_to(allProjected, RIGHT*32)
self.play(FadeTransform(tempsftDotVT, concatSftDotV[0]))
# ------ different colors to columns to show which head every column came from
_multiColors = list(np.repeat([TEAL_E, YELLOW_E, ORANGE], dimPerHead))
cols = concatSftDotV[0][0].get_columns()
agOverall = AnimationGroup(
    *[col.animate.set_color(colour) for col, colour in zip(cols, _multiColors)],
    lag_ratio=0.5,
)
self.play(agOverall)
self.wait()
# CREATING CONCATENATION WEIGHTS
_lRowNames = _rRowNames = _colNames = None
_qkvCaps = [r'Concatenation \: ( \: or \: Output \: Weights \:)']
uniform = t.distributions.uniform.Uniform(low=t.tensor(-1.0),high=t.tensor(1.0))
# The seed is one past the last Query/Key/Value weight seed (786+1 .. 786+len(weightsQKV)).
# It used to be derived from a loop index left over from an earlier loop, which
# evaluated to the Value-weights seed and therefore reproduced that exact matrix.
outputSeed = 786 + len(weightsQKV) + 1
concatWeights = buildMatrices(Rs=hiddenDimension,
                              Cs=hiddenDimension,
                              roundDecimals=2,
                              func=uniform.sample,
                              clr=GOLD_E,
                              cap=_qkvCaps[0],
                              wdt=2,
                              hgt=8,
                              seed=outputSeed
                              )
concatWeights[0].scale(2).next_to(concatSftDotV[0], RIGHT*8)
self.play(Write(concatWeights[0]))
self.wait()
# FINAL ENCODED MATRIX MULTIPLICATION
_lRowNames = _rRowNames = None
_colNames = ["d_%d"%i for i in range(embeddingDimension)]
_qkvCaps = ['Final \: Encoder \: Result']
# (rounded) concatenated heads times (rounded) output weights.
_finalResult = concatSftDotV[1] @ concatWeights[1]
finalResult = buildMatrices(Rs=_finalResult.shape[0], Cs=_finalResult.shape[1], roundDecimals=2,
                            arrayToMatrix=_finalResult, clr=YELLOW, cap=_qkvCaps[0],
                            colNames=_colNames, wdt=2, hgt=10)
resultGroup = finalResult[0]
resultGroup.scale(2).next_to(concatWeights[0])
# Draw the empty shell first; the entries come from the multiplication animation.
self.play(Write(VGroup(resultGroup[0].get_brackets(), resultGroup[1], resultGroup[2])), run_time=0.25)
matrixMultiplicationAnimation(matrix_1=concatSftDotV[0][0], matrix_2=concatWeights[0][0], result=resultGroup[0])
self.play(FadeOut(concatSftDotV[0], concatWeights[0]))
# Centre, enlarge and highlight the final matrix, then hold the last frame.
self.play(resultGroup.animate.move_to(ORIGIN))
self.play(resultGroup.animate.scale(2.5))
self.play(Indicate(resultGroup), run_time=5)
self.wait(15)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment