Skip to content

Instantly share code, notes, and snippets.

@arif9799
Created August 8, 2023 02:19
Show Gist options
  • Save arif9799/d61e2d6911be344ba28a6088ec0f2705 to your computer and use it in GitHub Desktop.
Save arif9799/d61e2d6911be344ba28a6088ec0f2705 to your computer and use it in GitHub Desktop.
Attention is not enough! Self-Attention via Scaled Dot Product Attention
from manim import *
from manim.utils.unit import Percent, Pixels
from colour import Color
import gensim
from gensim.models import Word2Vec
import numpy as np
import torch as t
# Logical canvas size in scene units: a 16:9 frame scaled up by a factor of 7.
config.frame_height = 9 * 7
config.frame_width = 16 * 7

# HIGH selects the output resolution: 3840x2160 when True, 1920x1080 otherwise.
HIGH = True
if not HIGH:
    print("Its low")
    config.pixel_width = 1920
    config.pixel_height = 1080
else:
    print("Its high")
    config.pixel_width = 3840
    config.pixel_height = 2160
################################################################################
# Creating Word Vectors #
################################################################################
def construct_Word_Embeddings(sentences, dim):
    """Train a Word2Vec model on ``sentences`` and embed every sentence.

    Args:
        sentences: Iterable of tokenised sentences (each a sequence of words).
            It is passed to ``Word2Vec`` and then iterated again here, so it
            must be re-iterable.
        dim: Size of each word vector (forwarded as ``vector_size``).

    Returns:
        A list with one tensor per sentence. Each tensor stacks the word
        vectors of that sentence row by row and is rounded to 2 decimals.
    """
    w2v = Word2Vec(sentences,
                   min_count=1,
                   vector_size=dim,
                   window=3)

    def _embed_sentence(tokens):
        # One (1, dim) row per token; the second positional argument of
        # get_vector is passed as True (presumably gensim's `norm` flag --
        # confirm against the installed gensim version).
        rows = [
            t.from_numpy(w2v.wv.get_vector(token, True).reshape(1, -1))
            for token in tokens
        ]
        return t.round(t.cat(rows, dim=0), decimals=2)

    return [_embed_sentence(tokens) for tokens in sentences]
################################################################################
# Creating Positional Encoding Vector #
################################################################################
def construct_Positional_Encoding(input):
    """Compute sinusoidal positional encodings and add them to ``input``.

    Column ``2i`` of the encoding holds ``sin(pos / 10000**(2i/d))`` and column
    ``2i + 1`` holds ``cos(pos / 10000**(2i/d))``, where ``d`` is the size of
    the last axis of ``input`` and ``pos`` indexes the second-to-last axis.
    When ``d`` is odd the final cosine column is dropped so the encoding has
    exactly ``d`` columns.

    Args:
        input: Tensor whose last axis is the embedding dimension; for a 2-D
            tensor the first axis is the position in the sequence.

    Returns:
        A tuple ``(x, pe, x_pe)``: the input itself, the ``(seq_len, d)``
        positional encoding rounded to 2 decimals, and their sum rounded to
        2 decimals.
    """
    # The temporary leading axis guarantees that shape[-2] exists even for a
    # 1-D input; it is removed again before the addition below.
    x = input.unsqueeze(0)
    base = 10000
    input_dimension = x.shape[-1]
    seq_len = x.shape[-2]

    dimension_indices = t.arange(0, input_dimension, 2)
    denominator = t.pow(base, dimension_indices / input_dimension)
    positions = t.arange(0, seq_len, 1).reshape(seq_len, 1)
    angles = positions / denominator

    # Interleave sin/cos as (sin_0, cos_0, sin_1, cos_1, ...). There are
    # 2 * ceil(d / 2) interleaved columns, which is d + 1 for odd d, so the
    # result is truncated to d columns instead of reshaping straight to d.
    interleaved = t.stack([t.sin(angles), t.cos(angles)], dim=2)
    pe = interleaved.reshape(seq_len, 2 * dimension_indices.numel())
    pe = pe[:, :input_dimension]

    x = x.squeeze(0)
    pe = t.round(pe, decimals=2)
    x_pe = t.round(x + pe, decimals=2)
    return x, pe, x_pe
################################################################################
# SCALED DOT PRODUCT FIGURE CODE BELOW #
################################################################################
def scaledDotProductFigure():
    """Build the flow chart of scaled dot-product attention.

    Returns:
        A tuple ``(sDPCaps, sDPBoxes, sDPArrows)`` of VGroups:
        the text captions (Q, K, Matrix Multiplication 1, Scale, Mask,
        Softmax, V, Matrix Multiplication 2), the rectangles surrounding the
        five operation captions, and the seven arrows connecting them. Boxes
        and arrows are wrapped in ``always_redraw`` so they are rebuilt from
        the current caption positions.
    """
    vBuff = 18
    boxBuff = MED_LARGE_BUFF

    def _caption(label):
        # Every caption in the figure shares the same font settings.
        return Text(label, font_size=96, slant=OBLIQUE)

    def _boxAround(caption, colour):
        # Rounded, lightly filled rectangle that keeps hugging `caption`.
        return always_redraw(lambda: SurroundingRectangle(mobject=caption, color=colour, fill_opacity=0.25, fill_color=colour, corner_radius=0.3, buff=boxBuff))

    def _verticalArrow(source, targetBox):
        # Arrow rising straight up from the top of `source` to the height of
        # the bottom edge of `targetBox` (x is taken from `source`).
        return always_redraw(lambda: Arrow(start=source.get_top(), end=[source.get_top()[0], targetBox.get_bottom()[1], 0], stroke_width=15, buff=0.2))

    def _boxToBox(lowerBox, upperBox):
        # Arrow from the top-centre of one box to the bottom-centre of the next.
        return always_redraw(lambda: Arrow(start=lowerBox.get_top(), end=upperBox.get_bottom(), stroke_width=15, buff=0.2))

    # Captions, laid out bottom-up relative to one another.
    _Q = _caption("Q").shift(DOWN*10)
    _K = _caption("K").next_to(_Q, RIGHT*20)
    _matMulTxt1 = _caption("Matrix Multiplication 1").next_to(VGroup(_Q, _K), UP*vBuff)
    _scaleTxt = _caption("Scale").next_to(_matMulTxt1, UP*vBuff)
    _maskTxt = _caption("Mask (Optional)").next_to(_scaleTxt, UP*vBuff)
    _sFTxt = _caption("Softmax").next_to(_maskTxt, UP*vBuff)
    _V = _caption("V").next_to(_sFTxt, RIGHT*15)
    _matMulTxt2 = _caption("Matrix Multiplication 2").next_to(VGroup(_V, _sFTxt), UP*vBuff)

    # Boxes around the five operation captions.
    _matMulBox1 = _boxAround(_matMulTxt1, PURPLE)
    _scaleBox = _boxAround(_scaleTxt, YELLOW)
    _maskBox = _boxAround(_maskTxt, PINK)
    _sFBox = _boxAround(_sFTxt, GREEN)
    _matMulBox2 = _boxAround(_matMulTxt2, PURPLE)

    sDPCaps = VGroup(_Q, _K, _matMulTxt1, _scaleTxt, _maskTxt, _sFTxt, _V, _matMulTxt2)  # scaled dot product captions
    sDPBoxes = VGroup(_matMulBox1, _scaleBox, _maskBox, _sFBox, _matMulBox2)  # scaled dot product boxes
    sDPArrows = VGroup(  # scaled dot product arrows
        _verticalArrow(_Q, _matMulBox1),        # Q -> matmul 1
        _verticalArrow(_K, _matMulBox1),        # K -> matmul 1
        _boxToBox(_matMulBox1, _scaleBox),      # matmul 1 -> scale
        _boxToBox(_scaleBox, _maskBox),         # scale -> mask
        _boxToBox(_maskBox, _sFBox),            # mask -> softmax
        _verticalArrow(_sFBox, _matMulBox2),    # softmax -> matmul 2
        _verticalArrow(_V, _matMulBox2),        # V -> matmul 2
    )
    return sDPCaps, sDPBoxes, sDPArrows
################################################################################
# FUNCTION TO BUILD MATRICES WITH DIMENSIONS AND CAPTIONS #
################################################################################
def buildMatrices(Rs,                     # Rs: number of rows of matrix to build
                  Cs,                     # Cs: number of cols of matrix to build
                  roundDecimals,          # roundDecimals: number of places to round off the values to
                  whereTo = None,         # whereTo: direction in which to place the matrix next to the reference mobject (DOWN, UP, LEFT or RIGHT)
                  referenceMobj = None,   # referenceMobj: the mobject the matrix is placed next to (and kept next to via an updater)
                  arrayToMatrix = None,   # arrayToMatrix: torch tensor or numpy array whose matrix mobject is to be built
                  func = None,            # func: used only when arrayToMatrix is None; called as func((Rs, Cs)) to generate the values
                  clr = WHITE,            # clr: color of the matrix
                  cap = None,             # cap: caption or name of the matrix
                  hgt = 6,                # hgt: height factor of the matrix (scales v_buff)
                  wdt = 1.5,              # wdt: width factor of the matrix (scales h_buff)
                  seed = 786,             # seed: torch random seed set before the values are generated
                  lRowNames = None,       # row names on left side of the matrix -- if any
                  rRowNames = None,       # row names on right side of the matrix -- if any
                  colNames = None,        # column names of the matrix
                  ):
    """Build a Matrix mobject decorated with its dimensions, caption and labels.

    The returned VGroup is filled in this order (later entries shift down when
    an optional one is absent):

        0   the Matrix mobject itself
        1   the "(Rs, Cs)" dimension text
        2   the caption                  (only if ``cap`` is given)
        ..  left row names VGroup        (only if ``lRowNames`` is given)
        ..  right row names VGroup       (only if ``rRowNames`` is given)
        ..  column names VGroup          (only if ``colNames`` is given)

    Every decoration carries an updater that re-positions it relative to the
    matrix.

    Returns:
        ``(iniMatrix, ini)`` where ``iniMatrix`` is the VGroup described above
        and ``ini`` is the data shown in the matrix (the rounded tensor, or the
        numpy array exactly as passed in). Returns ``None`` when neither
        ``arrayToMatrix`` nor ``func`` is supplied.

    Raises:
        TypeError: if ``arrayToMatrix`` is neither a torch tensor nor a numpy
            array.
    """
    if arrayToMatrix is None and func is None:
        print("No Input to process and build a Matrix upon")
        return None

    # ------------> Set a random seed for weights initialization and pick the values to display
    t.manual_seed(seed=seed)
    if arrayToMatrix is None:
        ini = t.round(func((Rs, Cs)), decimals=roundDecimals)
        entries = ini.numpy()
    elif isinstance(arrayToMatrix, t.Tensor):
        ini = t.round(arrayToMatrix, decimals=roundDecimals)
        entries = ini.numpy()
    elif isinstance(arrayToMatrix, np.ndarray):
        # numpy arrays (e.g. arrays of TeX strings) are displayed as they are.
        ini = arrayToMatrix
        entries = arrayToMatrix
    else:
        # Fail here with a clear message instead of later on a None matrix.
        raise TypeError(
            "arrayToMatrix must be a torch.Tensor or numpy.ndarray, got %s"
            % type(arrayToMatrix).__name__
        )
    iniMat = Matrix(matrix=entries, v_buff=hgt * Rs * 0.02, h_buff=wdt * Cs * 0.125).set(color=clr)

    # ------------> locate the matrix somewhere w.r.t reference mobject
    if referenceMobj is not None:
        # Fall back to RIGHT (next_to's own default direction) when no
        # direction was given, rather than passing None as a direction.
        direction = RIGHT if whereTo is None else whereTo
        iniMat.next_to(referenceMobj, direction)
        iniMat.add_updater(lambda x, y=referenceMobj: x.next_to(y, direction))

    # ------------> Create the MASTER MATRIX, the VGroup and add the main matrix
    iniMatrix = VGroup()
    iniMatrix.add(iniMat)

    # NOTE(review): for the dimension text, caption and column names the
    # initial offset differs from the offset used by the updater; both are
    # kept exactly as authored.

    # ------------> Create dimensions text and append it to Matrix with updaters
    dimText = MathTex('(%d, %d)' % (Rs, Cs), font_size=40).move_to(iniMat.get_critical_point(DR) + DOWN / 2)
    dimText.add_updater(lambda x, y=iniMat: x.move_to(y.get_critical_point(DR) + (DOWN * y.height * 0.1)))
    iniMatrix.add(dimText)

    # ------------> Create caption if 'cap' was passed and append it to Matrix with updaters
    if cap is not None:
        capText = MathTex('%s' % (cap), font_size=64).move_to(iniMat.get_critical_point(DOWN) + DOWN * 1.25)
        capText.add_updater(lambda x, y=iniMat: x.move_to(y.get_critical_point(DOWN) + DOWN * (y.height * 0.3)))
        iniMatrix.add(capText)

    iniMatRows = iniMat.get_rows()
    iniMatBracks = iniMat.get_brackets()

    def _rowNameGroup(names, bracket, side):
        """Label each matrix row with the matching name, outside `bracket`."""
        labels = VGroup()
        # zip pairs names with rows and stops at the shorter of the two.
        for name, row in zip(names, iniMatRows):
            # `row` and the name length are bound as defaults so that each
            # updater keeps its own values instead of the last loop values.
            def _place(label, row=row, nameLength=len(name)):
                return (label
                        .move_to([bracket.get_coord(0), row.get_coord(1), 0])
                        .shift(side * nameLength * iniMat.height * 0.06))
            labels.add(_place(MathTex(name)).add_updater(_place))
        return labels

    # ------------> Left row names, if passed into 'lRowNames'
    if lRowNames is not None:
        iniMatrix.add(_rowNameGroup(lRowNames, iniMatBracks[0], LEFT))

    # ------------> Right row names, if passed into 'rRowNames'
    if rRowNames is not None:
        iniMatrix.add(_rowNameGroup(rRowNames, iniMatBracks[1], RIGHT))

    # ------------> Column names, if passed into 'colNames'
    if colNames is not None:
        iniMatColNames = VGroup()
        for name, column in zip(colNames, iniMat.get_columns()):
            label = MathTex(name).next_to(column, UP, buff=iniMat.height / 5)
            label.add_updater(lambda x, y=column: x.next_to(y, UP, buff=iniMat.height / 7))
            iniMatColNames.add(label)
        iniMatrix.add(iniMatColNames)

    return iniMatrix, ini
class Figure10(MovingCameraScene):
    """Animated walkthrough of self-attention via scaled dot-product attention.

    The scene plays an intro, draws the attention flow chart, builds word
    embeddings plus positional encodings for the first sample sentence,
    projects them to Q, K and V with randomly initialised weights, and then
    animates Q.K^T, the scaling step, the softmax and the final product with V.
    """

    def construct(self):
        """Create and play every animation of the scene, in order."""

        # INTRO ANIMATION GOES HERE
        def myAnimation(wordsForIntro: list):
            """Play the intro: a large 'C' followed by each word written and unwritten."""
            fontHeight = config.frame_height // 8
            fontColor = WHITE
            timePerChar = 0.1
            C = MathTex(r"\mathbb{C}", color=fontColor).scale(config.frame_height // 3)
            self.play(Broadcast(C), run_time=1)
            self.add(C)
            # NOTE(review): absolute, machine-specific asset path -- the scene
            # only renders where this file exists.
            svg2 = SVGMobject("/Users/arifwaghbakriwala/Desktop/Northeastern/Projects/Manimations/assets/svg/pl2.svg",
                              height=config.frame_height,
                              width=config.frame_width,
                              stroke_color=MAROON,
                              stroke_width=7,
                              fill_color=BLUE,
                              fill_opacity=0
                              )
            self.play(Write(svg2))
            self.add(svg2)

            # Building text objects of individual characters.
            wordsMath = VGroup()
            for word in wordsForIntro:
                charTex = VGroup()
                for i, ch in enumerate(word):
                    chTex = MathTex(r"\mathbb{%s}" % ch, color=fontColor).scale(fontHeight)
                    # The first character sits next to the big C, every other
                    # one next to its predecessor; all share C's baseline.
                    anchor = C if i == 0 else charTex[-1]
                    chTex.next_to(anchor, RIGHT, buff=0.05).align_to(C, DOWN)
                    charTex.add(chTex)
                wordsMath.add(charTex)

            # Characters are played one by one (Succession / AnimationGroup
            # did not work here according to the original author).
            for wInd in range(len(wordsMath)):
                for chInd in range(len(wordsMath[wInd])):
                    self.play(Write(wordsMath[wInd][chInd], run_time=timePerChar))
                self.wait(0.5)
                for chInd in reversed(range(len(wordsMath[wInd]))):
                    self.play(Unwrite(wordsMath[wInd][chInd], run_time=timePerChar))

            self.play(Circumscribe(C, color=MAROON_E, fade_out=False, time_width=2, shape=Circle, buff=1, stroke_width=config.frame_height // 5, run_time=1.5))
            self.play(AnimationGroup(*[ShrinkToCenter(C, run_time=0.75), ShrinkToCenter(svg2, run_time=0.75)]))
            self.wait(0.5)
            self.remove(C, svg2, wordsMath)

        myAnimation(wordsForIntro=['OMPLEX', 'ONCEPTS', 'OMPREHENSIBLE'])
        self.wait()

        # ---------------> TITLE OF VIDEO
        title = Title('Self Attention (Scaled Dot Product Attention)', match_underline_width_to_text=True, underline_buff=MED_LARGE_BUFF, font_size=256).shift(DOWN*2)
        self.play(DrawBorderThenFill(title), run_time=2)

        ################################################################################
        #                  FUNCTION TO ANIMATE MATRIX MULTIPLICATION                   #
        ################################################################################
        def matrixMultiplicationAnimation(matrix_1: Matrix, matrix_2: Matrix, result: Matrix):
            """Animate matrix_1 @ matrix_2 cell by cell into `result`.

            For every (row, column) pair the row of `matrix_1` and the column
            of `matrix_2` are indicated while a copy of both fades into the
            corresponding entry of `result`.
            """
            M1_rows = matrix_1.get_rows()
            M2_cols = matrix_2.get_columns()
            M3_rows = result.get_rows()
            for rw_index in range(0, len(M1_rows)):
                for cl_index in range(0, len(M2_cols)):
                    ag_final = AnimationGroup(*[Indicate(M1_rows[rw_index]),
                                                Indicate(M2_cols[cl_index]),
                                                FadeTransformPieces(mobject=VGroup(*[M1_rows[rw_index], M2_cols[cl_index]]).copy(), target_mobject=M3_rows[rw_index][cl_index])
                                                ], lag_ratio=0)
                    self.play(ag_final, run_time=0.35)

        ################################################################################
        #                      FUNCTION TO PARENTHESISE MATRICES                       #
        ################################################################################
        def parenthesizeMatrix(matrix, encapsulateName: str, referenceMobject: Mobject = None, position=None, bf=0.5, initial_scale_factor: float = 1.5):
            """Wrap a matrix group as ``name( matrix )``.

            matrix: VGroup that contains matrix, dimension and caption
            encapsulateName: name of the function the matrix is presented in (e.g. 'Scaled', 'Softmax')
            referenceMobject, position: currently unused
            bf: gap between the matrix, the parentheses and the name
            initial_scale_factor: scale applied to the parentheses and the name

            Returns a VGroup holding `matrix`, the two parentheses and the name.
            """
            parenMat = VGroup()
            parens = MathTex("(", ")")
            parens.scale(initial_scale_factor)
            parens.stretch_to_fit_height(matrix.height)
            mt = matrix[0]
            l_paren, r_paren = parens.split()
            l_paren.next_to(mt, LEFT, buff=bf)
            r_paren.next_to(mt, RIGHT, buff=bf)
            tx = MathTex(encapsulateName)
            tx.scale(initial_scale_factor)
            tx.next_to(l_paren, LEFT, buff=bf)
            parenMat += matrix
            parenMat.add(*[l_paren, r_paren, tx])
            return parenMat

        # CALL TO FETCH THE FIGURE OF SCALED DOT PRODUCT ATTENTION FLOW CHART
        scaleCaps, scaleBoxes, scaleArrows = scaledDotProductFigure()
        unifiedMobject = VGroup(*[scaleCaps, scaleBoxes, scaleArrows])
        self.play(Write(unifiedMobject))
        self.play(scaleCaps.animate.move_to(ORIGIN).to_edge(RIGHT*5))
        self.wait()

        # DATA: EMBEDDINGS AND POSITIONAL ENCODINGS OF THE FIRST SENTENCE
        embeddingDimension = 6
        hiddenDimension = embeddingDimension
        data = [['Attention', 'is', 'not', 'enough'], ['Towards', 'Data', 'Science'], ['It', 'was', 'a', 'really', 'good', 'article']]
        sentences_Embedded = construct_Word_Embeddings(data, embeddingDimension)
        sent, pos_enc, pos_encoded = construct_Positional_Encoding(sentences_Embedded[0])

        words = data[0]
        dimNames = ["d_%d" % i for i in range(embeddingDimension)]
        # One position label per word of the sentence (rows are positions).
        positionNames = ["position_%d" % i for i in range(len(words))]
        _colors = [RED_E, GREEN_E, BLUE_E]

        # CREATING INDIVIDUAL WORD VECTORS
        wordVectors = VGroup(*[
            buildMatrices(Rs=word.reshape(-1, 1).shape[0],
                          Cs=word.reshape(-1, 1).shape[1],
                          roundDecimals=2,
                          arrayToMatrix=word.reshape(-1, 1),
                          clr=GREEN_B,
                          hgt=6,
                          lRowNames=["dim_%d" % i for i in range(embeddingDimension)],
                          rRowNames=None,
                          colNames=[_cName]
                          )[0]
            for _cName, word in zip(words, sent)]).scale(2).arrange(RIGHT, buff=5)
        self.play(Write(wordVectors))
        self.wait()

        # CREATING WORD VECTORS MATRIX AND THEN TRANSFORMING INDIVIDUAL VECTORS TO WORD VECTORS MATRIX
        wordMatrix = buildMatrices(Rs=sent.shape[0],
                                   Cs=sent.shape[1],
                                   roundDecimals=2,
                                   arrayToMatrix=sent,
                                   clr=GREEN_C,
                                   cap=r'Word \: Vectors \: Matrix',
                                   lRowNames=words,
                                   rRowNames=None,
                                   colNames=dimNames,
                                   wdt=2,
                                   hgt=10
                                   )[0].scale(3).shift(LEFT*3)
        self.play(FadeTransformPieces(wordVectors, wordMatrix), run_time=2)
        self.play(wordMatrix[0].animate.to_edge(LEFT*15))
        self.wait()

        # CREATING ADDITION SIGN AND PLACING IT NEXT TO WORD VECTORS MATRIX
        plusSign = MathTex("+", font_size=300).scale(2).next_to(wordMatrix, RIGHT*3)

        # CREATING POSITION ENCODING MATRIX AND PLACING IT NEXT TO PLUS SIGN
        posEncMatrix = buildMatrices(Rs=pos_enc.shape[0],
                                     Cs=pos_enc.shape[1],
                                     roundDecimals=2,
                                     arrayToMatrix=pos_enc,
                                     clr=GREEN_D,
                                     cap=r'Position \: Vectors \: Matrix',
                                     lRowNames=positionNames,
                                     rRowNames=None,
                                     colNames=dimNames,
                                     wdt=2,
                                     hgt=10
                                     )[0].scale(3).next_to(plusSign, RIGHT*3)
        self.play(Write(plusSign), Write(posEncMatrix))
        self.play(VGroup(*[wordMatrix, plusSign, posEncMatrix]).animate.move_to(ORIGIN))
        self.wait()

        # CREATING POSITION AWARE EMBEDDINGS AND TRANSFORMING THE TWO PRIOR MATRICES INTO THIS POSITION AWARE MATRIX
        posEncodedMatrix = buildMatrices(Rs=pos_encoded.shape[0],
                                         Cs=pos_encoded.shape[1],
                                         roundDecimals=2,
                                         arrayToMatrix=pos_encoded,
                                         clr=GREEN_E,
                                         cap=r'Position \: Aware \:Embeddings \: Matrix',
                                         lRowNames=words,
                                         rRowNames=positionNames,
                                         colNames=dimNames,
                                         wdt=2,
                                         hgt=10
                                         )[0].scale(3)
        self.play(FadeTransformPieces(VGroup(*[wordMatrix, plusSign, posEncMatrix]), posEncodedMatrix), run_time=2)
        self.wait()

        # CREATING QUERY, KEY AND VALUE MATRICES (with row and column names)
        _qkvCaps = ['query', 'key', 'value']
        qkvWithCaps = VGroup(*[buildMatrices(Rs=pos_encoded.shape[0],
                                             Cs=pos_encoded.shape[1],
                                             roundDecimals=2,
                                             arrayToMatrix=pos_encoded,
                                             clr=_colors[i],
                                             cap=_qkvCaps[i],
                                             lRowNames=words,
                                             rRowNames=positionNames,
                                             colNames=dimNames,
                                             wdt=2,
                                             hgt=10
                                             )[0].scale(2)
                               for i in range(len(_qkvCaps))]).arrange(DOWN, buff=4).to_edge(LEFT*4)
        self.play(FadeTransformPieces(posEncodedMatrix, qkvWithCaps), run_time=2)
        self.wait()

        # CREATING QUERY, KEY AND VALUE MATRICES (LESS DETAILS: only dimensions and name)
        qkvWithoutCaps = VGroup(*[buildMatrices(Rs=pos_encoded.shape[0],
                                                Cs=pos_encoded.shape[1],
                                                roundDecimals=2,
                                                arrayToMatrix=pos_encoded,
                                                clr=_colors[i],
                                                cap=_qkvCaps[i],
                                                wdt=2,
                                                hgt=10
                                                )[0].scale(2)
                                  for i in range(len(_qkvCaps))]).arrange(DOWN, buff=8).to_edge(LEFT*6)
        self.play(FadeTransformPieces(qkvWithCaps, qkvWithoutCaps))
        self.wait()
        query, key, value = qkvWithoutCaps

        # CREATING WEIGHT MATRICES TO PROJECT QUERY, KEY AND VALUE
        _weightCaps = [r'Query\:_{Weights}', r'Key\:_{Weights}', r'Value\:_{Weights}']
        uniform = t.distributions.uniform.Uniform(low=t.tensor(-1.0), high=t.tensor(1.0))
        # A different seed per matrix so the three weight matrices differ.
        weightsQKV = [buildMatrices(Rs=hiddenDimension,
                                    Cs=hiddenDimension,
                                    roundDecimals=2,
                                    func=uniform.sample,
                                    arrayToMatrix=None,
                                    clr=_colors[i],
                                    cap=_weightCaps[i],
                                    wdt=2,
                                    hgt=8,
                                    seed=786 + i + 1
                                    )
                      for i in range(len(_weightCaps))]
        for weights in weightsQKV:
            weights[0].scale(2)
        weightsQKV[0][0].next_to(query, RIGHT*4)
        weightsQKV[1][0].next_to(key, RIGHT*4)
        weightsQKV[2][0].next_to(value, RIGHT*4)
        self.play(Write(VGroup(*[weightsQKV[0][0], weightsQKV[1][0], weightsQKV[2][0]])))
        self.wait()

        # PROJECTING QUERY, KEY AND VALUE USING WEIGHTS MATRICES (ALSO SOME ANIMATIONS)
        _projCaps = ['(Q)', '(K)', '(V)']
        _projQKV = [pos_encoded @ weights[1] for weights in weightsQKV]
        projQKV = [buildMatrices(Rs=_projQKV[i].shape[0],
                                 Cs=_projQKV[i].shape[1],
                                 roundDecimals=2,
                                 arrayToMatrix=_projQKV[i],
                                 clr=_colors[i],
                                 cap=_projCaps[i],
                                 wdt=2,
                                 hgt=10
                                 )
                   for i in range(len(_projQKV))]
        for projected, weights in zip(projQKV, weightsQKV):
            projected[0].scale(2).next_to(weights[0], RIGHT*4)
        for source, weights, projected in zip(qkvWithoutCaps, weightsQKV, projQKV):
            self.play(FadeTransformPieces(VGroup(source, weights[0]), projected[0]))
        moveAG = AnimationGroup(*[projected[0].animate.to_edge(LEFT*8) for projected in projQKV], lag_ratio=0.2)
        self.play(moveAG)
        self.wait()
        # All three weight matrices are removed (index 2 included).
        self.remove(wordVectors, wordMatrix[0], plusSign, posEncMatrix[0], posEncodedMatrix[0], qkvWithCaps, qkvWithoutCaps, weightsQKV[0][0], weightsQKV[1][0], weightsQKV[2][0])

        # MATMUL OF PROJECTED QUERY AND KEY TRANSPOSE (Q and K)
        # The transpose is taken from the KEY projection (index 1) explicitly,
        # not from whatever a previous loop variable happens to hold.
        _keyT = t.transpose(_projQKV[1], dim0=1, dim1=0)
        _pKT = buildMatrices(Rs=_keyT.shape[0],
                             Cs=_keyT.shape[1],
                             roundDecimals=2,
                             arrayToMatrix=_keyT,
                             clr=_colors[1],
                             cap='(K^T)',
                             wdt=4,
                             hgt=8
                             )
        _pKT[0].scale(2).next_to(projQKV[1][0], RIGHT).to_edge(LEFT*8)
        self.play(FadeTransformPieces(projQKV[1][0], _pKT[0]))
        self.wait()

        # Q and K^T together serve as the anchor for everything placed to their right.
        qkPair = VGroup(*[projQKV[0][0], _pKT[0]])

        qkT = buildMatrices(Rs=projQKV[0][1].shape[0],
                            Cs=_keyT.shape[1],
                            roundDecimals=2,
                            arrayToMatrix=projQKV[0][1] @ _keyT,
                            clr=ORANGE,
                            cap=r'(Q\: . \:K^T)',
                            wdt=4,
                            hgt=8
                            )
        qkT[0].scale(2).next_to(qkPair, RIGHT*8)
        self.play(Write(VGroup(*[qkT[0][0].get_brackets(), qkT[0][1], qkT[0][2]])), run_time=0.25)
        matrixMultiplicationAnimation(matrix_1=projQKV[0][0][0], matrix_2=_pKT[0][0], result=qkT[0][0])
        self.wait()

        # SCALING RESULT OF MATMUL BLOCK
        scaledCaption = r"\frac{(Q\: . \:K^T)}{\sqrt{d}}"
        # Attention scores are divided by sqrt(d), matching the caption.
        scaleFactor = hiddenDimension ** 0.5

        def _showScaledStage(entries, previous, wdt, parenScale=1.5):
            """Build one 'Scaled( ... )' stage, morph `previous` into it and return it."""
            built = buildMatrices(Rs=entries.shape[0],
                                  Cs=entries.shape[1],
                                  roundDecimals=2,
                                  arrayToMatrix=entries,
                                  clr=YELLOW,
                                  cap=scaledCaption,
                                  wdt=wdt,
                                  hgt=10
                                  )
            stage = (parenthesizeMatrix(matrix=built[0], encapsulateName='Scaled', bf=0.9, initial_scale_factor=parenScale), built[1])
            stage[0].scale(2).next_to(qkPair, RIGHT*8)
            self.play(FadeTransformPieces(previous, stage[0]))
            self.wait()
            return stage

        # v1: every entry shown symbolically as "value / sqrt(d_k)"
        _qkTStrings = t.round(qkT[1], decimals=2).numpy().astype('str')
        _qkTSymbolic = np.vectorize(lambda x: x + r'/\sqrt{d_{k}}')(_qkTStrings)
        qkTScaled_v1 = _showScaledStage(_qkTSymbolic, qkT[0], wdt=6)

        # v2: d_k replaced by its numeric value
        _qkTNumericDivisor = np.vectorize(lambda x: x + r'/\sqrt{%d}' % hiddenDimension)(_qkTStrings)
        qkTScaled_v2 = _showScaledStage(_qkTNumericDivisor, qkTScaled_v1[0], wdt=4)

        # v3: the division carried out
        _paren_qkTSCaledArray = t.round(qkT[1] / scaleFactor, decimals=2)
        qkTScaled_v3 = _showScaledStage(_paren_qkTSCaledArray, qkTScaled_v2[0], wdt=4, parenScale=2)

        # Same values without the surrounding "Scaled( )".
        qkTScaled = buildMatrices(Rs=_paren_qkTSCaledArray.shape[0],
                                  Cs=_paren_qkTSCaledArray.shape[1],
                                  roundDecimals=2,
                                  arrayToMatrix=_paren_qkTSCaledArray,
                                  clr=YELLOW,
                                  cap=scaledCaption,
                                  wdt=4,
                                  hgt=10
                                  )
        qkTScaled[0].scale(2).next_to(qkPair, RIGHT*12)
        self.play(FadeTransformPieces(qkTScaled_v3[0], qkTScaled[0]))
        self.wait()

        # CONVERTING SCALED PROJECTED MATRICES TO SOFTMAX MATRICES
        scaledSoftmax = t.softmax(qkTScaled_v3[1], dim=-1)
        scaledQKVSoftmax = buildMatrices(Rs=scaledSoftmax.shape[0],
                                         Cs=scaledSoftmax.shape[1],
                                         roundDecimals=2,
                                         arrayToMatrix=scaledSoftmax,
                                         clr=PURPLE,
                                         cap=r"Softmax \: \biggl( \frac{(Q\: . \:K^T)}{\sqrt{d}} \biggr)",
                                         wdt=4,
                                         hgt=10
                                         )
        parenSoft = (parenthesizeMatrix(matrix=qkTScaled[0], encapsulateName='Softmax', bf=0.9, initial_scale_factor=2), qkTScaled[1])
        parenSoft[0].next_to(qkPair, RIGHT*8)
        self.play(FadeTransformPieces(qkTScaled[0], parenSoft[0]))
        self.wait()
        scaledQKVSoftmax[0].scale(2).next_to(qkPair, RIGHT*8)
        self.play(FadeTransformPieces(parenSoft[0], scaledQKVSoftmax[0]))
        self.wait()
        self.remove(qkT[0], qkTScaled[0], qkTScaled_v1[0], qkTScaled_v2[0], qkTScaled_v3[0], parenSoft[0])

        # FINAL MATMUL BLOCK
        _finalResult = scaledQKVSoftmax[1] @ projQKV[2][1]
        finalResult = buildMatrices(Rs=_finalResult.shape[0],
                                    Cs=_finalResult.shape[1],
                                    roundDecimals=2,
                                    arrayToMatrix=_finalResult,
                                    clr=TEAL,
                                    cap=r"Softmax \: \biggl( \frac{(Q\: . \:K^T)}{\sqrt{d}} \biggr) \: . \: V",
                                    wdt=3,
                                    hgt=10
                                    )
        # V is moved next to the softmax matrix for the product and restored afterwards.
        projQKV[2][0].save_state()
        self.play(projQKV[2][0].animate.next_to(scaledQKVSoftmax[0], RIGHT * 4))
        finalResult[0].scale(2).next_to(VGroup(*[scaledQKVSoftmax[0], projQKV[2][0]]), DOWN*10)
        self.play(Write(VGroup(*[finalResult[0][0].get_brackets(), finalResult[0][1], finalResult[0][2]])), run_time=0.25)
        matrixMultiplicationAnimation(matrix_1=scaledQKVSoftmax[0][0], matrix_2=projQKV[2][0][0], result=finalResult[0][0])
        self.play(Restore(projQKV[2][0]))
        self.play(FadeOut(scaledQKVSoftmax[0]))
        self.play(finalResult[0].animate.move_to(ORIGIN).scale(2))
        self.play(Indicate(finalResult[0]), run_time=4)
        self.wait(10)
        self.remove(qkT[0], qkTScaled[0], qkTScaled_v1[0], qkTScaled_v2[0], qkTScaled_v3[0], parenSoft[0], scaledQKVSoftmax[0], projQKV[0][0], projQKV[1][0], _pKT[0], projQKV[2][0])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment