Created
August 8, 2023 02:19
-
-
Save arif9799/d61e2d6911be344ba28a6088ec0f2705 to your computer and use it in GitHub Desktop.
Attention is not enough! Self-Attention via Scaled Dot Product Attention
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from manim import * | |
from manim.utils.unit import Percent, Pixels | |
from colour import Color | |
import gensim | |
from gensim.models import Word2Vec | |
import numpy as np | |
import torch as t | |
# Scene canvas: a 16:9 frame expressed in scene units, scaled up by 7.
config.frame_width = 16 * 7
config.frame_height = 9 * 7

# Output resolution switch: True renders 3840x2160, False renders 1920x1080.
HIGH = True
if not HIGH:
    print("Its low")
    config.pixel_height, config.pixel_width = 1080, 1920
else:
    print("Its high")
    config.pixel_height, config.pixel_width = 2160, 3840
################################################################################ | |
# Creating Word Vectors # | |
################################################################################ | |
def construct_Word_Embeddings(sentences, dim):
    """Train a small Word2Vec model on *sentences* and embed each sentence.

    The same sentences serve as the training corpus and as the text to embed,
    so every word is in the vocabulary (``min_count=1``).

    Args:
        sentences: tokenised sentences (each a list of words). Presumably this
            must be re-iterable (e.g. a list), since it is iterated by
            Word2Vec and again here.
        dim: length of each word vector (Word2Vec ``vector_size``).

    Returns:
        list of torch tensors, one per sentence, each of shape
        ``(len(sentence), dim)``. Row ``k`` is the vector of word ``k``, fetched
        with ``get_vector(word, True)`` (normalised) and rounded to 2 decimals.
    """
    model = Word2Vec(sentences, min_count=1, vector_size=dim, window=3)
    keyed_vectors = model.wv

    def embed(sentence):
        # Stack one (1, dim) row per word, then round the whole matrix.
        rows = [t.from_numpy(keyed_vectors.get_vector(word, True).reshape(1, -1))
                for word in sentence]
        return t.round(t.cat(rows, dim=0), decimals=2)

    return [embed(sentence) for sentence in sentences]
################################################################################ | |
# Creating Positional Encoding Vector # | |
################################################################################ | |
def construct_Positional_Encoding(input, base=10000):
    """Compute sinusoidal positional encodings and add them to *input*.

    For row (position) ``pos`` and even column ``2i`` the encoding is
    ``sin(pos / base**(2i/d))``; column ``2i + 1`` holds
    ``cos(pos / base**(2i/d))``. Here ``d`` is the size of the last dimension.

    Args:
        input: tensor whose last two dimensions are (sequence length, d).
            A 1-D tensor is treated as a single position. (The parameter name
            shadows the builtin but is kept for caller compatibility.)
        base: base of the wavelength progression (10000 in the original
            Transformer formulation, and the previous hard-coded value).

    Returns:
        tuple ``(x, pe, x_pe)``:
        * ``x`` is the input, unchanged;
        * ``pe`` is the ``(seq_len, d)`` encoding, rounded to 2 decimals;
        * ``x_pe`` is ``x + pe``, rounded to 2 decimals.
        Odd ``d`` is supported: the trailing cosine column is dropped.
    """
    # A leading axis lets 1-D input be read as a one-row sequence.
    x = input.unsqueeze(0)
    input_dimension = x.shape[-1]
    seq_len = x.shape[-2]

    even_indices = t.arange(0, input_dimension, 2)
    denominator = t.pow(base, even_indices / input_dimension)
    positions = t.arange(0, seq_len, 1).reshape(seq_len, 1)
    angles = positions / denominator  # (seq_len, ceil(d / 2))

    # Interleave column-wise: [sin_0, cos_0, sin_1, cos_1, ...].
    combined = t.stack([t.sin(angles), t.cos(angles)], dim=2)
    # The interleave has 2 * ceil(d / 2) columns (d + 1 when d is odd);
    # keep only the first d so the encoding matches the input width.
    pe = combined.reshape(seq_len, 2 * even_indices.numel())[:, :input_dimension]

    x = x.squeeze(0)
    pe = t.round(pe, decimals=2)
    x_pe = t.round(x + pe, decimals=2)
    return x, pe, x_pe
################################################################################ | |
# SCALED DOT PRODUCT FIGURE CODE BELOW # | |
################################################################################ | |
def scaledDotProductFigure():
    """Build the scaled dot-product attention flow chart.

    Layout, bottom to top:
    * Q and K;
    * "Matrix Multiplication 1", then "Scale", "Mask (Optional)" and "Softmax";
    * V, placed to the right of "Softmax";
    * "Matrix Multiplication 2" above both.

    Boxes and arrows are built with ``always_redraw``, so they follow their
    captions when the captions are moved.

    Returns:
        tuple ``(captions, boxes, arrows)`` of three VGroups:
        * captions: Q, K, MatMul 1, Scale, Mask, Softmax, V, MatMul 2 (in that order);
        * boxes: MatMul 1, Scale, Mask, Softmax, MatMul 2 surrounding rectangles;
        * arrows: Q→MatMul 1, K→MatMul 1, MatMul 1→Scale, Scale→Mask,
          Mask→Softmax, Softmax→MatMul 2, V→MatMul 2.
    """
    captions, boxes, arrows = VGroup(), VGroup(), VGroup()
    gap = 18
    padding = MED_LARGE_BUFF

    def label(text):
        # Every caption in the chart shares one font size and slant.
        return Text(text, font_size=96, slant=OBLIQUE)

    def framed(text_mob, colour):
        # Translucent rounded rectangle that keeps tracking text_mob.
        return always_redraw(lambda: SurroundingRectangle(
            mobject=text_mob, color=colour, fill_opacity=0.25, fill_color=colour,
            corner_radius=0.3, buff=padding))

    def rising_arrow(source, target_box):
        # Vertical arrow from source's top up to the level of target_box's bottom.
        return always_redraw(lambda: Arrow(
            start=source.get_top(),
            end=[source.get_top()[0], target_box.get_bottom()[1], 0],
            stroke_width=15, buff=0.2))

    def linking_arrow(lower_box, upper_box):
        # Arrow from the top of lower_box to the bottom of upper_box.
        return always_redraw(lambda: Arrow(
            start=lower_box.get_top(), end=upper_box.get_bottom(),
            stroke_width=15, buff=0.2))

    # Inputs Q and K side by side at the bottom.
    q_txt = label("Q").shift(DOWN * 10)
    captions.add(q_txt)
    k_txt = label("K").next_to(q_txt, RIGHT * 20)
    captions.add(k_txt)

    # First matrix multiplication, above Q and K.
    matmul1_txt = label("Matrix Multiplication 1").next_to(VGroup(q_txt, k_txt), UP * gap)
    matmul1_box = framed(matmul1_txt, PURPLE)
    captions.add(matmul1_txt)
    boxes.add(matmul1_box)

    # Scale -> Mask -> Softmax, each stacked above the previous caption.
    stage_texts, stage_boxes = [], []
    below = matmul1_txt
    for text, colour in (("Scale", YELLOW), ("Mask (Optional)", PINK), ("Softmax", GREEN)):
        stage_txt = label(text).next_to(below, UP * gap)
        stage_box = framed(stage_txt, colour)
        captions.add(stage_txt)
        boxes.add(stage_box)
        stage_texts.append(stage_txt)
        stage_boxes.append(stage_box)
        below = stage_txt
    scale_box, mask_box, softmax_box = stage_boxes
    softmax_txt = stage_texts[-1]

    # V enters beside the softmax; the second multiplication sits above both.
    v_txt = label("V").next_to(softmax_txt, RIGHT * 15)
    captions.add(v_txt)
    matmul2_txt = label("Matrix Multiplication 2").next_to(VGroup(v_txt, softmax_txt), UP * gap)
    matmul2_box = framed(matmul2_txt, PURPLE)
    captions.add(matmul2_txt)
    boxes.add(matmul2_box)

    arrows.add(
        rising_arrow(q_txt, matmul1_box),
        rising_arrow(k_txt, matmul1_box),
        linking_arrow(matmul1_box, scale_box),
        linking_arrow(scale_box, mask_box),
        linking_arrow(mask_box, softmax_box),
        rising_arrow(softmax_box, matmul2_box),
        rising_arrow(v_txt, matmul2_box),
    )
    return captions, boxes, arrows
################################################################################ | |
# FUNCTION TO BUILD MATRICES WITH DIMENSIONS AND CAPTIONS # | |
################################################################################ | |
def buildMatrices(Rs,                  # number of rows of the matrix to build
                  Cs,                  # number of columns of the matrix to build
                  roundDecimals,       # decimal places tensor values are rounded to
                  whereTo=None,        # direction (DOWN, UP, LEFT, RIGHT) from referenceMobj to place the matrix
                  referenceMobj=None,  # mobject the matrix is placed next to and follows via an updater
                  arrayToMatrix=None,  # torch tensor or numpy array whose values are displayed
                  func=None,           # used only when arrayToMatrix is None: called as func((Rs, Cs)) to create values
                  clr=WHITE,           # color of the matrix
                  cap=None,            # LaTeX caption shown under the matrix
                  hgt=6,               # height factor feeding the matrix v_buff
                  wdt=1.5,             # width factor feeding the matrix h_buff
                  seed=786,            # seed given to torch.manual_seed before building
                  lRowNames=None,      # LaTeX labels left of each row, if any
                  rRowNames=None,      # LaTeX labels right of each row, if any
                  colNames=None,       # LaTeX labels above each column, if any
                  ):
    """Build a manim ``Matrix`` plus its dimension label and optional captions.

    Every label gets an updater, so it keeps following the matrix when the
    matrix is moved or scaled.

    Returns:
        ``None`` (after printing a message) when neither ``arrayToMatrix`` nor
        ``func`` is given. Otherwise a tuple ``(iniMatrix, ini)``:

        * ``iniMatrix`` is a VGroup whose submobjects are, in order:
          - ``[0]`` the ``Matrix`` mobject;
          - ``[1]`` a ``MathTex`` "(Rs, Cs)" dimension label;
          - then, only when supplied, the caption, a VGroup of left row labels,
            a VGroup of right row labels and a VGroup of column labels.
        * ``ini`` holds the displayed values:
          - for tensor input or ``func`` output, a tensor rounded to
            ``roundDecimals``;
          - for numpy input, the array itself, unrounded.

    Raises:
        TypeError: if ``arrayToMatrix`` is neither a torch tensor, a numpy
            array nor ``None``.
    """
    if arrayToMatrix is None and func is None:
        print("No Input to process and build a Matrix upon")
        return None

    # Seed torch's global RNG so func-generated matrices are reproducible.
    t.manual_seed(seed=seed)
    v_buff = hgt * Rs * 0.02
    h_buff = wdt * Cs * 0.125
    if isinstance(arrayToMatrix, t.Tensor):
        ini = t.round(arrayToMatrix, decimals=roundDecimals)
        iniMat = Matrix(matrix=ini.numpy(), v_buff=v_buff, h_buff=h_buff).set(color=clr)
    elif isinstance(arrayToMatrix, np.ndarray):
        # numpy input (e.g. arrays of LaTeX strings) is shown as-is, unrounded.
        ini = arrayToMatrix
        iniMat = Matrix(matrix=arrayToMatrix, v_buff=v_buff, h_buff=h_buff).set(color=clr)
    elif arrayToMatrix is None:
        # func is guaranteed non-None by the guard above.
        ini = t.round(func((Rs, Cs)), decimals=roundDecimals)
        iniMat = Matrix(matrix=ini.numpy(), v_buff=v_buff, h_buff=h_buff).set(color=clr)
    else:
        raise TypeError("arrayToMatrix must be a torch.Tensor, numpy.ndarray or None, got %s"
                        % type(arrayToMatrix).__name__)

    # ------------> locate the matrix w.r.t. the reference mobject and keep it there
    if referenceMobj is not None:
        iniMat.next_to(referenceMobj, whereTo)
        iniMat.add_updater(lambda x, y=referenceMobj: x.next_to(y, whereTo))

    # ------------> MASTER group: the matrix itself is always element 0
    iniMatrix = VGroup()
    iniMatrix.add(iniMat)

    # ------------> dimension label "(Rs, Cs)" at the bottom-right corner
    dimText = MathTex('(%d, %d)' % (Rs, Cs), font_size=40).move_to(iniMat.get_critical_point(DR) + DOWN / 2)
    dimText.add_updater(lambda x, y=iniMat: x.move_to(y.get_critical_point(DR) + (DOWN * iniMat.height * 0.1)))
    iniMatrix.add(dimText)

    # ------------> optional caption under the matrix
    if cap is not None:
        capText = MathTex('%s' % (cap), font_size=64).move_to(iniMat.get_critical_point(DOWN) + DOWN * 1.25)
        capText.add_updater(lambda x, y=iniMat: x.move_to(y.get_critical_point(DOWN) + DOWN * (iniMat.height * 0.3)))
        iniMatrix.add(capText)

    iniMatRows = iniMat.get_rows()
    iniMatBracks = iniMat.get_brackets()

    def _row_label(text, row, bracket, direction):
        # Label level with `row`, pushed outward from `bracket` by an amount
        # proportional to the label length. All values are bound per call:
        # the old inline lambdas read the comprehension variable `i` late, so
        # every label ended up using the LAST row name's length.
        offset = len(text)

        def place(mob):
            mob.move_to([bracket.get_coord(0), row.get_coord(1), 0])
            mob.shift(direction * offset * iniMat.height * 0.06)
            return mob

        labelTex = place(MathTex(text))
        labelTex.add_updater(place)
        return labelTex

    # ------------> optional row labels on the left
    if lRowNames is not None:
        iniMatRowNamesL = VGroup(*[_row_label(lRowNames[i], iniMatRows[i], iniMatBracks[0], LEFT)
                                   for i in range(len(iniMatRows))])
        iniMatrix.add(iniMatRowNamesL)

    # ------------> optional row labels on the right
    if rRowNames is not None:
        iniMatRowNamesR = VGroup(*[_row_label(rRowNames[i], iniMatRows[i], iniMatBracks[1], RIGHT)
                                   for i in range(len(iniMatRows))])
        iniMatrix.add(iniMatRowNamesR)

    # ------------> optional column labels above the matrix
    if colNames is not None:
        iniMatCols = iniMat.get_columns()
        # NOTE(review): the initial buff (height/5) differs from the updater's
        # (height/7); the updater value is what persists once updaters run.
        iniMatColNames = VGroup(*[MathTex(colNames[i])
                                  .next_to(iniMatCols[i], UP, buff=iniMat.height / 5)
                                  .add_updater(lambda x, y=iniMatCols[i]: x.next_to(y, UP, buff=iniMat.height / 7))
                                  for i in range(len(iniMatCols))])
        iniMatrix.add(iniMatColNames)

    return iniMatrix, ini
class Figure10(MovingCameraScene):
    """Scene walking through self-attention (scaled dot-product attention).

    After an intro and the title, the scene shows the attention flow chart and
    then animates, step by step:
    * individual word vectors and the word-vector matrix;
    * adding positional encodings;
    * the Q/K/V projections;
    * Q·Kᵀ, scaled by √d;
    * the softmax and the final multiplication by V.
    """

    def construct(self):
        """Build and play every animation of the scene in order."""

        # ---------------> INTRO ANIMATION
        def myAnimation(wordsForIntro: list):
            """Play the intro sequence.

            The intro shows a large blackboard-bold C and an SVG outline. Each
            word in *wordsForIntro* is then written, and erased, character by
            character to the right of the C.
            """
            fontHeight = config.frame_height // 8
            fontColor = WHITE
            timePerChar = 0.1
            C = MathTex(r"\mathbb{C}", color=fontColor).scale(config.frame_height // 3)
            self.play(Broadcast(C), run_time=1)
            self.add(C)
            # NOTE(review): machine-specific absolute path; rendering fails unless
            # this SVG exists at exactly this location.
            svg2 = SVGMobject("/Users/arifwaghbakriwala/Desktop/Northeastern/Projects/Manimations/assets/svg/pl2.svg",
                              height=config.frame_height,
                              width=config.frame_width,
                              stroke_color=MAROON,
                              stroke_width=7,
                              fill_color=BLUE,
                              fill_opacity=0)
            self.play(Write(svg2))
            self.add(svg2)

            # One VGroup of per-character MathTex per word. Each character sits
            # right of the previous one (the first right of C), bottom-aligned with C.
            wordsMath = VGroup()
            for word in wordsForIntro:
                charTex = VGroup()
                for i, ch in enumerate(word):
                    chTex = MathTex(r"\mathbb{%s}" % ch, color=fontColor).scale(fontHeight)
                    if i != 0:
                        chTex.next_to(charTex[-1], RIGHT, buff=0.05).align_to(C, DOWN)
                    else:
                        chTex.next_to(C, RIGHT, buff=0.05).align_to(C, DOWN)
                    charTex.add(chTex)
                wordsMath.add(charTex)

            # Characters are played one by one because Succession / AnimationGroup
            # did not give the intended effect (slower, but works).
            for wInd in range(len(wordsMath)):
                for chInd in range(len(wordsMath[wInd])):
                    self.play(Write(wordsMath[wInd][chInd], run_time=timePerChar))
                self.wait(0.5)
                for chInd in reversed(range(len(wordsMath[wInd]))):
                    self.play(Unwrite(wordsMath[wInd][chInd], run_time=timePerChar))
            self.play(Circumscribe(C, color=MAROON_E, fade_out=False, time_width=2, shape=Circle, buff=1,
                                   stroke_width=config.frame_height // 5, run_time=1.5))
            self.play(AnimationGroup(*[ShrinkToCenter(C, run_time=0.75), ShrinkToCenter(svg2, run_time=0.75)]))
            self.wait(0.5)
            self.remove(C, svg2, wordsMath)

        myAnimation(wordsForIntro=['OMPLEX', 'ONCEPTS', 'OMPREHENSIBLE'])
        self.wait()

        # ---------------> TITLE OF VIDEO
        title = Title('Self Attention (Scaled Dot Product Attention)', match_underline_width_to_text=True,
                      underline_buff=MED_LARGE_BUFF, font_size=256).shift(DOWN * 2)
        self.play(DrawBorderThenFill(title), run_time=2)

        ################################################################################
        #                   FUNCTION TO ANIMATE MATRIX MULTIPLICATION                  #
        ################################################################################
        def matrixMultiplicationAnimation(matrix_1: Matrix, matrix_2: Matrix, result: Matrix):
            """Animate ``result = matrix_1 @ matrix_2`` one entry at a time.

            For every (row, column) pair, the row of matrix_1 and the column of
            matrix_2 are indicated. Meanwhile a copy of both fades into the
            matching entry of *result*.
            """
            M1_rows = matrix_1.get_rows()
            M2_cols = matrix_2.get_columns()
            M3_rows = result.get_rows()
            for rw_index in range(len(M1_rows)):
                for cl_index in range(len(M2_cols)):
                    ag_final = AnimationGroup(*[Indicate(M1_rows[rw_index]),
                                                Indicate(M2_cols[cl_index]),
                                                FadeTransformPieces(mobject=VGroup(*[M1_rows[rw_index], M2_cols[cl_index]]).copy(),
                                                                    target_mobject=M3_rows[rw_index][cl_index])],
                                              lag_ratio=0)
                    self.play(ag_final, run_time=0.35)

        ################################################################################
        #                       FUNCTION TO PARENTHESISE MATRICES                      #
        ################################################################################
        def parenthesizeMatrix(matrix, encapsulateName: str, referenceMobject: Mobject = None, position=None,
                               bf=0.5, initial_scale_factor: float = 1.5):
            """Wrap a buildMatrices VGroup as ``encapsulateName( matrix )``.

            Args:
                matrix: VGroup from buildMatrices. Its element 0 (the Matrix) is
                    what the parentheses are placed around.
                encapsulateName: LaTeX written left of the opening paren
                    (e.g. 'Scaled', 'Softmax').
                referenceMobject, position: currently unused.
                bf: buffer between the parens / name and the matrix.
                initial_scale_factor: scale applied to the parens and the name.

            Returns:
                VGroup holding *matrix* itself (added, not copied), the two
                parens and the name.
            """
            parenMat = VGroup()
            parens = MathTex("(", ")")
            parens.scale(initial_scale_factor)
            parens.stretch_to_fit_height(matrix.height)
            mt = matrix[0]
            l_paren, r_paren = parens.split()
            l_paren.next_to(mt, LEFT, buff=bf)
            r_paren.next_to(mt, RIGHT, buff=bf)
            tx = MathTex(encapsulateName)
            tx.scale(initial_scale_factor)
            tx.next_to(l_paren, LEFT, buff=bf)
            parenMat += matrix
            parenMat.add(*[l_paren, r_paren, tx])
            return parenMat

        # CALL TO FETCH THE FIGURE OF SCALED DOT PRODUCT ATTENTION FLOW CHART
        scaleCaps, scaleBoxes, scaleArrows = scaledDotProductFigure()
        unifiedMobject = VGroup(*[scaleCaps, scaleBoxes, scaleArrows])
        self.play(Write(unifiedMobject))
        self.play(scaleCaps.animate.move_to(ORIGIN).to_edge(RIGHT * 5))
        self.wait()

        embeddingDimension = 6
        hiddenDimension = embeddingDimension
        data = [['Attention', 'is', 'not', 'enough'], ['Towards', 'Data', 'Science'],
                ['It', 'was', 'a', 'really', 'good', 'article']]
        inputDimension = len(data[0])
        sentences_Embedded = construct_Word_Embeddings(data, embeddingDimension)
        sent, pos_enc, pos_encoded = construct_Positional_Encoding(sentences_Embedded[0])

        # CREATING INDIVIDUAL WORD VECTORS (one column vector per word)
        _lRowNames = ["dim_%d" % i for i in range(embeddingDimension)]
        _rRowNames = None
        _colNames = data[0]
        wordVectors = VGroup(*[
            buildMatrices(Rs=word.reshape(-1, 1).shape[0],
                          Cs=word.reshape(-1, 1).shape[1],
                          roundDecimals=2,
                          arrayToMatrix=word.reshape(-1, 1),
                          clr=GREEN_B,
                          hgt=6,
                          lRowNames=_lRowNames,
                          rRowNames=_rRowNames,
                          colNames=[_cName]
                          )[0]
            for _cName, word in zip(_colNames, sent)]).scale(2).arrange(RIGHT, buff=5)
        self.play(Write(wordVectors))
        self.wait()

        # CREATING WORD VECTORS MATRIX AND TRANSFORMING INDIVIDUAL VECTORS INTO IT
        _lRowNames = data[0]
        _rRowNames = None
        _colNames = ["d_%d" % i for i in range(embeddingDimension)]
        wordMatrix = buildMatrices(Rs=sent.shape[0],
                                   Cs=sent.shape[1],
                                   roundDecimals=2,
                                   whereTo=None,
                                   referenceMobj=None,
                                   arrayToMatrix=sent,
                                   clr=GREEN_C,
                                   cap=r'Word \: Vectors \: Matrix',
                                   lRowNames=_lRowNames,
                                   rRowNames=None,
                                   colNames=_colNames,
                                   wdt=2,
                                   hgt=10
                                   )[0].scale(3).shift(LEFT * 3)
        self.play(FadeTransformPieces(wordVectors, wordMatrix), run_time=2)
        self.play(wordMatrix[0].animate.to_edge(LEFT * 15))
        self.wait()

        # CREATING ADDITION SIGN AND PLACING IT NEXT TO WORD VECTORS MATRIX
        plusSign = MathTex("+", font_size=300).scale(2).next_to(wordMatrix, RIGHT * 3)

        # CREATING POSITION ENCODING MATRIX AND PLACING IT NEXT TO PLUS SIGN
        _lRowNames = ["position_%d" % i for i in range(embeddingDimension)]
        _rRowNames = None
        _colNames = ["d_%d" % i for i in range(embeddingDimension)]
        posEncMatrix = buildMatrices(Rs=pos_enc.shape[0],
                                     Cs=pos_enc.shape[1],
                                     roundDecimals=2,
                                     whereTo=None,
                                     referenceMobj=None,
                                     arrayToMatrix=pos_enc,
                                     clr=GREEN_D,
                                     cap=r'Position \: Vectors \: Matrix',
                                     lRowNames=_lRowNames,
                                     rRowNames=None,
                                     colNames=_colNames,
                                     wdt=2,
                                     hgt=10
                                     )[0].scale(3).next_to(plusSign, RIGHT * 3)
        self.play(Write(plusSign), Write(posEncMatrix))
        self.play(VGroup(*[wordMatrix, plusSign, posEncMatrix]).animate.move_to(ORIGIN))
        self.wait()

        # CREATING POSITION AWARE EMBEDDINGS FROM THE TWO PRIOR MATRICES
        _lRowNames = data[0]
        _rRowNames = ["position_%d" % i for i in range(embeddingDimension)]
        _colNames = ["d_%d" % i for i in range(embeddingDimension)]
        posEncodedMatrix = buildMatrices(Rs=pos_encoded.shape[0],
                                         Cs=pos_encoded.shape[1],
                                         roundDecimals=2,
                                         whereTo=None,
                                         referenceMobj=None,
                                         arrayToMatrix=pos_encoded,
                                         clr=GREEN_E,
                                         cap=r'Position \: Aware \:Embeddings \: Matrix',
                                         lRowNames=_lRowNames,
                                         rRowNames=_rRowNames,
                                         colNames=_colNames,
                                         wdt=2,
                                         hgt=10
                                         )[0].scale(3)
        self.play(FadeTransformPieces(VGroup(*[wordMatrix, plusSign, posEncMatrix]), posEncodedMatrix), run_time=2)
        self.wait()

        # CREATING QUERY, KEY AND VALUE MATRICES (with row and column names)
        _lRowNames = data[0]
        _rRowNames = ["position_%d" % i for i in range(embeddingDimension)]
        _colNames = ["d_%d" % i for i in range(embeddingDimension)]
        _colors = [RED_E, GREEN_E, BLUE_E]
        _qkvCaps = ['query', 'key', 'value']
        qkvWithCaps = VGroup(*[buildMatrices(Rs=pos_encoded.shape[0],
                                             Cs=pos_encoded.shape[1],
                                             roundDecimals=2,
                                             whereTo=None,
                                             referenceMobj=None,
                                             arrayToMatrix=pos_encoded,
                                             clr=_colors[i],
                                             cap=_qkvCaps[i],
                                             lRowNames=_lRowNames,
                                             rRowNames=_rRowNames,
                                             colNames=_colNames,
                                             wdt=2,
                                             hgt=10
                                             )[0].scale(2)
                               for i in range(len(_qkvCaps))]).arrange(DOWN, buff=4).to_edge(LEFT * 4)
        self.play(FadeTransformPieces(posEncodedMatrix, qkvWithCaps), run_time=2)
        self.wait()

        # CREATING QUERY, KEY AND VALUE MATRICES (only dimensions and name)
        _lRowNames = None
        _rRowNames = None
        _colNames = None
        _colors = [RED_E, GREEN_E, BLUE_E]
        _qkvCaps = ['query', 'key', 'value']
        qkvWithoutCaps = VGroup(*[buildMatrices(Rs=pos_encoded.shape[0],
                                                Cs=pos_encoded.shape[1],
                                                roundDecimals=2,
                                                whereTo=None,
                                                referenceMobj=None,
                                                arrayToMatrix=pos_encoded,
                                                clr=_colors[i],
                                                cap=_qkvCaps[i],
                                                lRowNames=_lRowNames,
                                                rRowNames=_rRowNames,
                                                colNames=_colNames,
                                                wdt=2,
                                                hgt=10
                                                )[0].scale(2)
                                  for i in range(len(_qkvCaps))]).arrange(DOWN, buff=8).to_edge(LEFT * 6)
        self.play(FadeTransformPieces(qkvWithCaps, qkvWithoutCaps))
        self.wait()
        query, key, value = qkvWithoutCaps

        # CREATING WEIGHT MATRICES TO PROJECT QUERY, KEY AND VALUE
        _lRowNames = None
        _rRowNames = None
        _colNames = None
        _colors = [RED_E, GREEN_E, BLUE_E]
        _qkvCaps = [r'Query\:_{Weights}', r'Key\:_{Weights}', r'Value\:_{Weights}']
        uniform = t.distributions.uniform.Uniform(low=t.tensor(-1.0), high=t.tensor(1.0))
        # Each weight matrix gets its own seed so Q, K and V weights differ.
        weightsQKV = [buildMatrices(Rs=hiddenDimension,
                                    Cs=hiddenDimension,
                                    roundDecimals=2,
                                    whereTo=None,
                                    referenceMobj=None,
                                    func=uniform.sample,
                                    arrayToMatrix=None,
                                    clr=_colors[i],
                                    cap=_qkvCaps[i],
                                    lRowNames=_lRowNames,
                                    rRowNames=_rRowNames,
                                    colNames=_colNames,
                                    wdt=2,
                                    hgt=8,
                                    seed=786 + i + 1
                                    )
                      for i in range(len(_qkvCaps))]
        for i in range(len(weightsQKV)):
            weightsQKV[i][0].scale(2)
        weightsQKV[0][0].next_to(query, RIGHT * 4)
        weightsQKV[1][0].next_to(key, RIGHT * 4)
        weightsQKV[2][0].next_to(value, RIGHT * 4)
        self.play(Write(VGroup(*[weightsQKV[0][0], weightsQKV[1][0], weightsQKV[2][0]])))
        self.wait()

        # PROJECTING QUERY, KEY AND VALUE USING WEIGHTS MATRICES
        _lRowNames = None
        _rRowNames = None
        _colNames = None
        _colors = [RED_E, GREEN_E, BLUE_E]
        _qkvCaps = ['(Q)', '(K)', '(V)']
        # Index 0 = Q, 1 = K, 2 = V.
        _projQKV = [pos_encoded @ weightsQKV[0][1], pos_encoded @ weightsQKV[1][1], pos_encoded @ weightsQKV[2][1]]
        projQKV = [buildMatrices(Rs=_projQKV[i].shape[0],
                                 Cs=_projQKV[i].shape[1],
                                 roundDecimals=2,
                                 whereTo=None,
                                 referenceMobj=None,
                                 func=None,
                                 arrayToMatrix=_projQKV[i],
                                 clr=_colors[i],
                                 cap=_qkvCaps[i],
                                 lRowNames=_lRowNames,
                                 rRowNames=_rRowNames,
                                 colNames=_colNames,
                                 wdt=2,
                                 hgt=10
                                 )
                   for i in range(len(_projQKV))]
        for i in range(len(projQKV)):
            projQKV[i][0].scale(2).next_to(weightsQKV[i][0], RIGHT * 4)
        for i in range(3):
            matrices = VGroup(*[qkvWithoutCaps[i], weightsQKV[i][0], projQKV[i][0]])
            self.play(FadeTransformPieces(matrices[0:2], matrices[2]))
        moveAG = AnimationGroup(*[proj[0].animate.to_edge(LEFT * 8) for proj in projQKV], lag_ratio=0.2)
        self.play(moveAG)
        self.wait()
        self.remove(wordVectors, wordMatrix[0], plusSign, posEncMatrix[0], posEncodedMatrix[0], qkvWithCaps,
                    qkvWithoutCaps, weightsQKV[0][0], weightsQKV[1][0], weightsQKV[2][0])

        # MATMUL OF PROJECTED QUERY AND KEY TRANSPOSE (Q and K^T)
        # K^T must come from the KEY projection (index 1). Previously the loop
        # variable `i` (== 2, i.e. VALUE) leaked in here.
        _kT = t.transpose(_projQKV[1], dim0=1, dim1=0)
        _pKT = buildMatrices(Rs=_kT.shape[0],
                             Cs=_kT.shape[1],
                             roundDecimals=2,
                             whereTo=None,
                             referenceMobj=None,
                             func=None,
                             arrayToMatrix=_kT,
                             clr=_colors[1],
                             cap='(K^T)',
                             lRowNames=_lRowNames,
                             rRowNames=_rRowNames,
                             colNames=_colNames,
                             wdt=4,
                             hgt=8
                             )
        _pKT[0].scale(2).next_to(projQKV[1][0], RIGHT).to_edge(LEFT * 8)
        self.play(FadeTransformPieces(projQKV[1][0], _pKT[0]))
        self.wait()
        qkT = buildMatrices(Rs=projQKV[0][1].shape[0],
                            Cs=_kT.shape[1],
                            roundDecimals=2,
                            whereTo=None,
                            referenceMobj=None,
                            func=None,
                            arrayToMatrix=projQKV[0][1] @ _kT,
                            clr=ORANGE,
                            cap=r'(Q\: . \:K^T)',
                            lRowNames=_lRowNames,
                            rRowNames=_rRowNames,
                            colNames=_colNames,
                            wdt=4,
                            hgt=8
                            )
        qkT[0].scale(2).next_to(VGroup(*[projQKV[0][0], _pKT[0]]), RIGHT * 8)
        # Draw only brackets + labels; entries appear via the multiplication animation.
        self.play(Write(VGroup(*[qkT[0][0].get_brackets(), qkT[0][1], qkT[0][2]])), run_time=0.25)
        matrixMultiplicationAnimation(matrix_1=projQKV[0][0][0], matrix_2=_pKT[0][0], result=qkT[0][0])
        self.wait()

        # SCALING RESULT OF MATMUL BLOCK v1 (symbolic: value / sqrt(d_k))
        _lRowNames = None
        _rRowNames = None
        _colNames = None
        _qkTSCaled = t.round(qkT[1], decimals=2).numpy().astype('str')
        _qkTSCaledArray = np.vectorize(lambda x: x + r'/\sqrt{d_{k}}')(_qkTSCaled)
        qkTScaled_v1 = buildMatrices(Rs=_qkTSCaledArray.shape[0],
                                     Cs=_qkTSCaledArray.shape[1],
                                     roundDecimals=2,
                                     whereTo=None,
                                     referenceMobj=None,
                                     func=None,
                                     arrayToMatrix=_qkTSCaledArray,
                                     clr=YELLOW,
                                     cap=r"\frac{(Q\: . \:K^T)}{\sqrt{d}}",
                                     lRowNames=_lRowNames,
                                     rRowNames=_rRowNames,
                                     colNames=_colNames,
                                     wdt=6,
                                     hgt=10
                                     )
        qkTScaled_v1 = (parenthesizeMatrix(matrix=qkTScaled_v1[0], encapsulateName='Scaled', bf=0.9), qkTScaled_v1[1])
        qkTScaled_v1[0].scale(2).next_to(VGroup(*[projQKV[0][0], _pKT[0]]), RIGHT * 8)
        self.play(FadeTransformPieces(qkT[0], qkTScaled_v1[0]))
        self.wait()

        # SCALING RESULT OF MATMUL BLOCK v2 (d substituted: value / sqrt(hiddenDimension))
        _lRowNames = None
        _rRowNames = None
        _colNames = None
        _paren_qkTSCaled = t.round(qkT[1], decimals=2).numpy().astype('str')
        _paren_qkTSCaledArray = np.vectorize(lambda x: x + r'/\sqrt{%d}' % hiddenDimension)(_paren_qkTSCaled)
        qkTScaled_v2 = buildMatrices(Rs=_paren_qkTSCaledArray.shape[0],
                                     Cs=_paren_qkTSCaledArray.shape[1],
                                     roundDecimals=2,
                                     whereTo=None,
                                     referenceMobj=None,
                                     func=None,
                                     arrayToMatrix=_paren_qkTSCaledArray,
                                     clr=YELLOW,
                                     cap=r"\frac{(Q\: . \:K^T)}{\sqrt{d}}",
                                     lRowNames=_lRowNames,
                                     rRowNames=_rRowNames,
                                     colNames=_colNames,
                                     wdt=4,
                                     hgt=10
                                     )
        qkTScaled_v2 = (parenthesizeMatrix(matrix=qkTScaled_v2[0], encapsulateName='Scaled', bf=0.9), qkTScaled_v2[1])
        qkTScaled_v2[0].scale(2).next_to(VGroup(*[projQKV[0][0], _pKT[0]]), RIGHT * 8)
        self.play(FadeTransformPieces(qkTScaled_v1[0], qkTScaled_v2[0]))
        self.wait()

        # SCALING RESULT OF MATMUL BLOCK v3 (numeric). Divide by sqrt(d), as
        # the captions state; the old code divided by d itself.
        _lRowNames = None
        _rRowNames = None
        _colNames = None
        _paren_qkTSCaledArray = t.round(qkT[1] / hiddenDimension ** 0.5, decimals=2)
        qkTScaled_v3 = buildMatrices(Rs=_paren_qkTSCaledArray.shape[0],
                                     Cs=_paren_qkTSCaledArray.shape[1],
                                     roundDecimals=2,
                                     whereTo=None,
                                     referenceMobj=None,
                                     func=None,
                                     arrayToMatrix=_paren_qkTSCaledArray,
                                     clr=YELLOW,
                                     cap=r"\frac{(Q\: . \:K^T)}{\sqrt{d}}",
                                     lRowNames=_lRowNames,
                                     rRowNames=_rRowNames,
                                     colNames=_colNames,
                                     wdt=4,
                                     hgt=10
                                     )
        qkTScaled_v3 = (parenthesizeMatrix(matrix=qkTScaled_v3[0], encapsulateName='Scaled', bf=0.9,
                                           initial_scale_factor=2), qkTScaled_v3[1])
        qkTScaled_v3[0].scale(2).next_to(VGroup(*[projQKV[0][0], _pKT[0]]), RIGHT * 8)
        self.play(FadeTransformPieces(qkTScaled_v2[0], qkTScaled_v3[0]))
        self.wait()
        qkTScaled = buildMatrices(Rs=_paren_qkTSCaledArray.shape[0],
                                  Cs=_paren_qkTSCaledArray.shape[1],
                                  roundDecimals=2,
                                  whereTo=None,
                                  referenceMobj=None,
                                  func=None,
                                  arrayToMatrix=_paren_qkTSCaledArray,
                                  clr=YELLOW,
                                  cap=r"\frac{(Q\: . \:K^T)}{\sqrt{d}}",
                                  lRowNames=_lRowNames,
                                  rRowNames=_rRowNames,
                                  colNames=_colNames,
                                  wdt=4,
                                  hgt=10
                                  )
        qkTScaled[0].scale(2).next_to(VGroup(*[projQKV[0][0], _pKT[0]]), RIGHT * 12)
        self.play(FadeTransformPieces(qkTScaled_v3[0], qkTScaled[0]))
        self.wait()

        # CONVERTING SCALED SCORES TO SOFTMAX (row-wise, over the last dim)
        _lRowNames = None
        _rRowNames = None
        _colNames = None
        scaledSoftmax = t.softmax(qkTScaled_v3[1], dim=-1)
        scaledQKVSoftmax = buildMatrices(Rs=scaledSoftmax.shape[0],
                                         Cs=scaledSoftmax.shape[1],
                                         roundDecimals=2,
                                         whereTo=None,
                                         referenceMobj=None,
                                         func=None,
                                         arrayToMatrix=scaledSoftmax,
                                         clr=PURPLE,
                                         cap=r"Softmax \: \biggl( \frac{(Q\: . \:K^T)}{\sqrt{d}} \biggr)",
                                         lRowNames=_lRowNames,
                                         rRowNames=_rRowNames,
                                         colNames=_colNames,
                                         wdt=4,
                                         hgt=10
                                         )
        parenSoft = (parenthesizeMatrix(matrix=qkTScaled[0], encapsulateName='Softmax', bf=0.9,
                                        initial_scale_factor=2), qkTScaled[1])
        parenSoft[0].next_to(VGroup(*[projQKV[0][0], _pKT[0]]), RIGHT * 8)
        self.play(FadeTransformPieces(qkTScaled[0], parenSoft[0]))
        self.wait()
        scaledQKVSoftmax[0].scale(2).next_to(VGroup(*[projQKV[0][0], _pKT[0]]), RIGHT * 8)
        self.play(FadeTransformPieces(parenSoft[0], scaledQKVSoftmax[0]))
        self.wait()
        self.remove(qkT[0], qkTScaled[0], qkTScaled_v1[0], qkTScaled_v2[0], qkTScaled_v3[0], parenSoft[0])

        # FINAL MATMUL BLOCK: softmax(Q K^T / sqrt(d)) @ V
        _lRowNames = None
        _rRowNames = None
        _colNames = None
        _finalResult = scaledQKVSoftmax[1] @ projQKV[2][1]
        finalResult = buildMatrices(Rs=_finalResult.shape[0],
                                    Cs=_finalResult.shape[1],
                                    roundDecimals=2,
                                    whereTo=None,
                                    referenceMobj=None,
                                    func=None,
                                    arrayToMatrix=_finalResult,
                                    clr=TEAL,
                                    cap=r"Softmax \: \biggl( \frac{(Q\: . \:K^T)}{\sqrt{d}} \biggr) \: . \: V",
                                    lRowNames=_lRowNames,
                                    rRowNames=_rRowNames,
                                    colNames=_colNames,
                                    wdt=3,
                                    hgt=10
                                    )
        projQKV[2][0].save_state()
        self.play(projQKV[2][0].animate.next_to(scaledQKVSoftmax[0], RIGHT * 4))
        finalResult[0].scale(2).next_to(VGroup(*[scaledQKVSoftmax[0], projQKV[2][0]]), DOWN * 10)
        self.play(Write(VGroup(*[finalResult[0][0].get_brackets(), finalResult[0][1], finalResult[0][2]])), run_time=0.25)
        matrixMultiplicationAnimation(matrix_1=scaledQKVSoftmax[0][0], matrix_2=projQKV[2][0][0], result=finalResult[0][0])
        self.play(Restore(projQKV[2][0]))
        self.play(FadeOut(scaledQKVSoftmax[0]))
        self.play(finalResult[0].animate.move_to(ORIGIN).scale(2))
        self.play(Indicate(finalResult[0]), run_time=4)
        self.wait(10)
        self.remove(qkT[0], qkTScaled[0], qkTScaled_v1[0], qkTScaled_v2[0], qkTScaled_v3[0], parenSoft[0],
                    scaledQKVSoftmax[0], projQKV[0][0], projQKV[1][0], _pKT[0], projQKV[2][0])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment