Created
August 9, 2023 21:33
-
-
Save arif9799/0b652e1e3dcf7ec0f1db40b5a8f627ab to your computer and use it in GitHub Desktop.
Attention is not enough! Multihead Attention in Transformers
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from manim import * | |
from manim.utils.unit import Percent, Pixels | |
from colour import Color | |
import gensim | |
from gensim.models import Word2Vec | |
import numpy as np | |
import torch as t | |
# Logical scene size in Manim units: a 16:9 canvas, eight units per ratio step.
config.frame_width, config.frame_height = 16 * 8, 9 * 8

# Output resolution switch: True renders 3840x2160, False renders 1920x1080.
HIGH = True
print("Its high" if HIGH else "Its low")
config.pixel_height, config.pixel_width = (2160, 3840) if HIGH else (1080, 1920)
################################################################################ | |
# Creating Word Vectors # | |
################################################################################ | |
def construct_Word_Embeddings(sentences, dim):
    """Train a Word2Vec model on ``sentences`` and embed every sentence.

    Args:
        sentences: iterable of tokenised sentences (each a sequence of words).
            It is iterated twice: once to train the model, once to embed.
        dim: size of each word vector (passed to Word2Vec as ``vector_size``).

    Returns:
        A list with one tensor per sentence. Each tensor stacks the word
        vectors of that sentence row by row and is rounded to 2 decimals.
    """
    model = Word2Vec(sentences, min_count=1, vector_size=dim, window=3)

    embedded = []
    for sentence in sentences:
        # The second positional argument of get_vector is passed as True
        # (presumably gensim's ``norm`` flag -- confirm against the installed
        # gensim version).
        rows = [
            t.from_numpy(model.wv.get_vector(token, True).reshape(1, -1))
            for token in sentence
        ]
        embedded.append(t.round(t.cat(rows, dim=0), decimals=2))
    return embedded
################################################################################ | |
# Creating Positional Encoding Vector # | |
################################################################################ | |
def construct_Positional_Encoding(input):
    """Build sinusoidal positional encodings and add them to ``input``.

    Even feature indices receive ``sin(pos / base**(i / d))`` and odd feature
    indices receive the matching ``cos`` term, where ``pos`` is the row
    (position) index, ``i`` the even feature index and ``d`` the feature size.

    Args:
        input: tensor whose last dimension is the feature size and whose
            second-to-last dimension is the sequence length.

    Returns:
        A tuple ``(x, pe, x_pe)``: the input itself, the positional encoding
        rounded to 2 decimals, and their sum rounded to 2 decimals.
    """
    x = input.unsqueeze(0)
    base = 10000  # frequency base of the sinusoids
    input_dimension = x.shape[-1]
    seq_len = x.shape[-2]

    dimension_indices = t.arange(0, input_dimension, 2)
    denominator = t.pow(base, dimension_indices / input_dimension)
    positions = t.arange(0, seq_len, 1).reshape(seq_len, 1)
    angles = positions / denominator

    # Interleave sin/cos per frequency: (seq_len, n_freq, 2) -> (seq_len, 2 * n_freq).
    interleaved = t.stack([t.sin(angles), t.cos(angles)], dim=2)
    pe = interleaved.reshape(seq_len, 2 * denominator.shape[0])
    # For an odd feature size the stack holds one surplus cos column; trim it
    # so the encoding always matches the input width.
    pe = pe[:, :input_dimension]

    x = x.squeeze(0)
    pe = t.round(pe, decimals=2)
    x_pe = t.round(x + pe, decimals=2)
    return x, pe, x_pe
################################################################################ | |
# FUNCTION TO BUILD MATRICES WITH DIMENSIONS AND CAPTIONS # | |
################################################################################ | |
def buildMatrices(Rs,                    # Rs: number of rows of matrix to build
                  Cs,                    # Cs: number of cols of matrix to build
                  roundDecimals,         # roundDecimals: number of places to round off the values to
                  whereTo = None,        # whereTo: direction (DOWN, UP, LEFT or RIGHT) in which to place the matrix next to referenceMobj
                  referenceMobj = None,  # referenceMobj: mobject the matrix is placed next to and keeps following via an updater
                  arrayToMatrix = None,  # arrayToMatrix: tensor or numpy array to draw
                  func = None,           # func: callable taking a (Rs, Cs) shape, used to generate values when arrayToMatrix is None
                  clr = WHITE,           # clr: color of the matrix
                  cap = None,            # cap: caption or name of the matrix
                  hgt = 6,               # hgt: height factor of Matrix
                  wdt = 1.5,             # wdt: width factor of Matrix
                  seed = 786,            # seed: torch seed set before values are generated
                  lRowNames = None,      # row names on left side of the matrix -- if any
                  rRowNames = None,      # row names on right side of the matrix -- if any
                  colNames = None,       # column names of the matrix
                  ):
    """Build a labelled Matrix mobject group and return it with its values.

    The values come from ``arrayToMatrix`` when given (a torch tensor is
    rounded to ``roundDecimals``; a numpy array is used as is), otherwise from
    ``func((Rs, Cs))`` after seeding torch with ``seed``.

    Returns:
        ``None`` when neither ``arrayToMatrix`` nor ``func`` is supplied.
        Otherwise ``(group, values)`` where ``group`` is a VGroup holding, in
        order: the Matrix, the ``(Rs, Cs)`` dimension label, then (only when
        the matching argument is given) the caption, left row names, right
        row names and column names. ``values`` is the tensor/array drawn.

    Raises:
        TypeError: if ``arrayToMatrix`` is neither a torch tensor nor a numpy
            array.
    """
    if arrayToMatrix is None and func is None:
        print("No Input to process and build a Matrix upon")
        return None

    # ------------> Set a random seed for weights initialization and build a matrix
    t.manual_seed(seed=seed)
    v_buff = hgt * Rs * 0.02
    h_buff = wdt * Cs * 0.125
    if arrayToMatrix is None:
        ini = t.round(func((Rs, Cs)), decimals=roundDecimals)
        cells = ini.numpy()
    elif isinstance(arrayToMatrix, t.Tensor):
        ini = t.round(arrayToMatrix, decimals=roundDecimals)
        cells = ini.numpy()
    elif isinstance(arrayToMatrix, np.ndarray):
        ini = arrayToMatrix
        cells = arrayToMatrix
    else:
        # Previously this fell through with no matrix built and crashed later
        # on an unrelated attribute access.
        raise TypeError(
            "arrayToMatrix must be a torch.Tensor or numpy.ndarray, got %s"
            % type(arrayToMatrix).__name__
        )
    iniMat = Matrix(matrix=cells, v_buff=v_buff, h_buff=h_buff).set(color=clr)

    # ------------> locate the matrix somewhere w.r.t reference mobject
    if referenceMobj is not None:
        iniMat.next_to(referenceMobj, whereTo)
        iniMat.add_updater(lambda x, y=referenceMobj: x.next_to(y, whereTo))

    # ------------> Create the MASTER MATRIX, the VGroup and add the main matrix
    iniMatrix = VGroup()
    iniMatrix.add(iniMat)

    # ------------> Dimension label, kept below the bottom-right corner by an updater
    dimText = MathTex('(%d, %d)' % (Rs, Cs), font_size=40)
    dimText.move_to(iniMat.get_critical_point(DR) + DOWN / 2)
    dimText.add_updater(
        lambda x, y=iniMat: x.move_to(y.get_critical_point(DR) + (DOWN * y.height * 0.1))
    )
    iniMatrix.add(dimText)

    # ------------> Caption (only when 'cap' is passed), kept below the matrix by an updater
    if cap is not None:
        capText = MathTex('%s' % (cap), font_size=64)
        capText.move_to(iniMat.get_critical_point(DOWN) + DOWN * 1.25)
        capText.add_updater(
            lambda x, y=iniMat: x.move_to(y.get_critical_point(DOWN) + DOWN * (y.height * 0.4))
        )
        iniMatrix.add(capText)

    iniMatRows = iniMat.get_rows()
    iniMatBracks = iniMat.get_brackets()

    def rowNameLabels(names, bracket, direction):
        """Create one label per matrix row, offset from ``bracket`` along ``direction``."""
        labels = []
        for i in range(len(iniMatRows)):
            # Bind the row, bracket and this row's name length as defaults so
            # every updater keeps its own values. Reading ``names[i]`` inside
            # the updater would see only the final loop index.
            def follow(x, y=iniMatRows[i], b=bracket, n=len(names[i])):
                x.move_to([b.get_coord(0), y.get_coord(1), 0]).shift(
                    direction * n * iniMat.height * 0.06
                )

            label = MathTex(names[i])
            follow(label)
            label.add_updater(follow)
            labels.append(label)
        return VGroup(*labels)

    # ------------> Left row names (only when 'lRowNames' is passed)
    if lRowNames is not None:
        iniMatrix.add(rowNameLabels(lRowNames, iniMatBracks[0], LEFT))

    # ------------> Right row names (only when 'rRowNames' is passed)
    if rRowNames is not None:
        iniMatrix.add(rowNameLabels(rRowNames, iniMatBracks[1], RIGHT))

    # ------------> Column names (only when 'colNames' is passed)
    if colNames is not None:
        iniMatCols = iniMat.get_columns()
        iniMatColNames = VGroup(*[
            MathTex(colNames[i])
            .next_to(iniMatCols[i], UP, buff=iniMat.height / 5)
            .add_updater(lambda x, y=iniMatCols[i]: x.next_to(y, UP, buff=iniMat.height / 7))
            for i in range(len(iniMatCols))
        ])
        iniMatrix.add(iniMatColNames)

    return iniMatrix, ini
class Figure11(MovingCameraScene):
    """Scene that animates multi-head attention on one example sentence.

    The first sentence of a small corpus is embedded with
    ``construct_Word_Embeddings`` and positionally encoded with
    ``construct_Positional_Encoding``. The scene then shows projection,
    splitting into heads, scaled dot-product scores, softmax, the product with
    the values, concatenation and the output projection. Every matrix on
    screen is drawn with ``buildMatrices``.
    """

    def construct(self):
        """Play the intro, then each stage of the attention computation."""
        embeddingDimension = 6
        hiddenDimension = embeddingDimension
        dimPerHead = 2
        baseSeed = 786
        matrixColors = [RED_E, GREEN_E, BLUE_E]
        # One color per head, repeated for each column that belongs to it.
        headColors = list(np.repeat([TEAL_E, YELLOW_E, ORANGE], dimPerHead))
        dimNames = ["d_%d" % i for i in range(embeddingDimension)]

        # INTRO ANIMATION GOES HERE
        def myAnimation(wordsForIntro):
            """Write and unwrite each word next to a large 'C', then clear the intro."""
            fontHeight = config.frame_height // 8
            fontColor = WHITE
            timePerChar = 0.1

            C = MathTex(r"\mathbb{C}", color=fontColor).scale(config.frame_height // 3)
            self.play(Broadcast(C), run_time=1)
            self.add(C)

            # NOTE(review): absolute, machine-specific asset path; rendering
            # elsewhere requires this file to exist at the same location.
            svg2 = SVGMobject(
                "/Users/arifwaghbakriwala/Desktop/Northeastern/Projects/Manimations/assets/svg/pl2.svg",
                height=config.frame_height,
                width=config.frame_width,
                stroke_color=MAROON,
                stroke_width=7,
                fill_color=BLUE,
                fill_opacity=0,
            )
            self.play(Write(svg2))
            self.add(svg2)

            # Building text objects of individual characters.
            wordsMath = VGroup()
            for word in wordsForIntro:
                charTex = VGroup()
                for i, ch in enumerate(word):
                    # The first character sits next to the big C, later ones
                    # next to the previous character.
                    anchor = charTex[-1] if i != 0 else C
                    chTex = MathTex(r"\mathbb{%s}" % ch, color=fontColor).scale(fontHeight)
                    chTex.next_to(anchor, RIGHT, buff=0.05).align_to(C, DOWN)
                    charTex.add(chTex)
                wordsMath.add(charTex)

            # Characters are played one by one (Succession / AnimationGroup
            # did not behave as wanted here).
            for wordTex in wordsMath:
                chars = list(wordTex)
                for chTex in chars:
                    self.play(Write(chTex, run_time=timePerChar))
                self.wait(0.5)
                for chTex in reversed(chars):
                    self.play(Unwrite(chTex, run_time=timePerChar))

            self.play(Circumscribe(C, color=MAROON_E, fade_out=False, time_width=2, shape=Circle,
                                   buff=1, stroke_width=config.frame_height // 5, run_time=1.5))
            self.play(AnimationGroup(ShrinkToCenter(C, run_time=0.75),
                                     ShrinkToCenter(svg2, run_time=0.75)))
            self.wait(0.5)
            self.remove(C, svg2, wordsMath)

        ################################################################################
        #                 FUNCTION TO ANIMATE MATRIX MULTIPLICATION                    #
        ################################################################################
        def matrixMultiplicationAnimation(matrix_1: Matrix, matrix_2: Matrix, result: Matrix):
            """Indicate each row/column pair and fade a copy of it into the result cell."""
            M1_rows = matrix_1.get_rows()
            M2_cols = matrix_2.get_columns()
            M3_rows = result.get_rows()
            for rw_index in range(len(M1_rows)):
                for cl_index in range(len(M2_cols)):
                    ag_final = AnimationGroup(
                        Indicate(M1_rows[rw_index]),
                        Indicate(M2_cols[cl_index]),
                        FadeTransformPieces(
                            mobject=VGroup(M1_rows[rw_index], M2_cols[cl_index]).copy(),
                            target_mobject=M3_rows[rw_index][cl_index],
                        ),
                        lag_ratio=0,
                    )
                    self.play(ag_final, run_time=0.35)

        ################################################################################
        #                     FUNCTION TO PARENTHESISE MATRICES                        #
        ################################################################################
        def parenthesizeMatrix(matrix, encapsulateName: str, referenceMobject: Mobject = None,
                               position=None, bf=0.5, initial_scale_factor: float = 1.5):
            """Wrap a matrix group as ``encapsulateName( matrix )``.

            matrix: VGroup that contains matrix, dimension and caption
            encapsulateName: name written in front of the parentheses (e.g. Softmax)
            referenceMobject, position: accepted but currently unused
            bf: gap between the matrix, the parentheses and the name
            initial_scale_factor: scale applied to the parentheses and the name

            Returns a VGroup holding the matrix group, both parentheses and the name.
            """
            parens = MathTex("(", ")")
            parens.scale(initial_scale_factor)
            parens.stretch_to_fit_height(matrix.height)
            mt = matrix[0]
            l_paren, r_paren = parens.split()
            l_paren.next_to(mt, LEFT, buff=bf)
            r_paren.next_to(mt, RIGHT, buff=bf)

            tx = MathTex(encapsulateName)
            tx.scale(initial_scale_factor)
            tx.next_to(l_paren, LEFT, buff=bf)

            parenMat = VGroup()
            parenMat.add(matrix)
            parenMat.add(l_paren, r_paren, tx)
            return parenMat

        def swapInPlace(old, new):
            """Double ``new``, centre it on ``old`` and fade-transform ``old`` into it."""
            new.scale(2).set_x(old.get_x()).set_y(old.get_y()).set_z(old.get_z())
            self.play(FadeTransformPieces(old, new))

        def animateProduct(left, right, product):
            """Draw the frame of ``product``, animate ``left @ right`` into it, fade the factors.

            All three arguments are mobject groups from buildMatrices
            (index 0: Matrix, 1: dimension label, 2: caption).
            """
            self.play(Write(VGroup(product[0].get_brackets(), product[1], product[2])), run_time=0.25)
            matrixMultiplicationAnimation(matrix_1=left[0], matrix_2=right[0], result=product[0])
            self.play(FadeOut(left, right))

        def splitIntoHeads(projected, caps):
            """Cut one projected matrix into per-head matrices and animate the split.

            projected: (mobject group, tensor) pair returned by buildMatrices
            caps: one caption per head
            Returns (list of buildMatrices results, VGroup of their mobject groups).
            """
            headTensors = t.stack(t.split(projected[1], dimPerHead, dim=-1))
            heads = []
            for i in range(len(headTensors)):
                span = slice(i * dimPerHead, (i + 1) * dimPerHead)
                heads.append(buildMatrices(Rs=headTensors[i].shape[0],
                                           Cs=headTensors[i].shape[1],
                                           roundDecimals=2,
                                           arrayToMatrix=headTensors[i],
                                           clr=headColors[span],
                                           cap=caps[i],
                                           colNames=dimNames[span],
                                           wdt=6,
                                           hgt=10))
            for i, head in enumerate(heads):
                span = slice(i * dimPerHead, (i + 1) * dimPerHead)
                if i == 0:
                    head[0].scale(2).next_to(projected[0], RIGHT * 8)
                else:
                    head[0].scale(2).next_to(heads[i - 1][0], RIGHT * 4)
                sourceColumns = projected[0][0].get_columns()[span].copy()
                self.play(FadeTransformPieces(sourceColumns, head[0]), run_time=2)
            return heads, VGroup(*[head[0] for head in heads])

        def buildSymbolicScaledHeads(heads, caps, suffix):
            """Build matrices whose cells are the rounded scores followed by ``suffix``."""
            built = []
            for i, head in enumerate(heads):
                cells = t.round(head[1], decimals=2).numpy().astype('str')
                cells = np.vectorize(lambda x: x + suffix)(cells)
                built.append(buildMatrices(Rs=cells.shape[0],
                                           Cs=cells.shape[1],
                                           roundDecimals=2,
                                           arrayToMatrix=cells,
                                           clr=headColors[i * dimPerHead],
                                           cap=caps[i],
                                           wdt=7,
                                           hgt=12))
            return built

        def wrapAndSwap(built, previous):
            """Wrap every entry of ``built`` as Scaled( ... ) and morph ``previous[i]`` into it."""
            for i in range(len(built)):
                wrapped = parenthesizeMatrix(matrix=built[i][0], encapsulateName='Scaled', bf=0.9)
                built[i] = (wrapped, built[i][1])
                swapInPlace(previous[i], wrapped)

        myAnimation(wordsForIntro=['OMPLEX', 'ONCEPTS', 'OMPREHENSIBLE'])
        self.wait()

        ##########################################################################
        #                           IT STARTS HERE!!!!                           #
        ##########################################################################
        data = [['Attention', 'is', 'not', 'enough'],
                ['Towards', 'Data', 'Science'],
                ['It', 'was', 'a', 'really', 'good', 'article']]
        sentences_Embedded = construct_Word_Embeddings(data, embeddingDimension)
        _, _, pos_encoded = construct_Positional_Encoding(sentences_Embedded[0])
        numRows, numCols = pos_encoded.shape

        # ---------------> TITLE OF VIDEO
        title = Title('Multi-Head Attention (Exemplifying One Input)',
                      match_underline_width_to_text=True,
                      underline_buff=MED_LARGE_BUFF,
                      font_size=256).shift(DOWN * 2)
        self.play(DrawBorderThenFill(title), run_time=2)

        # CREATING QUERY, KEY AND VALUE MATRICES
        # All three start out as copies of the positionally encoded input.
        qkvNames = ['query', 'key', 'value']

        def buildInputCopies(**labels):
            """Build the query/key/value copies of the encoded input, doubled in size."""
            return VGroup(*[
                buildMatrices(Rs=numRows,
                              Cs=numCols,
                              roundDecimals=2,
                              arrayToMatrix=pos_encoded,
                              clr=matrixColors[i],
                              cap=qkvNames[i],
                              wdt=2,
                              hgt=10,
                              **labels)[0].scale(2)
                for i in range(len(qkvNames))
            ])

        # Detailed version: word labels on the left, positions on the right,
        # dimension names on top.
        qkvWithCaps = buildInputCopies(
            lRowNames=data[0],
            rRowNames=["position_%d" % i for i in range(numRows)],
            colNames=dimNames,
        ).arrange(DOWN, buff=4).to_edge(LEFT * 4)
        self.play(Write(qkvWithCaps))
        self.wait()

        # Compact version: only the caption and the dimensions.
        qkvWithoutCaps = buildInputCopies().arrange(DOWN, buff=8).to_edge(LEFT * 6)
        self.play(FadeTransformPieces(qkvWithCaps, qkvWithoutCaps))
        self.wait()

        # CREATING WEIGHT MATRICES TO PROJECT QUERY, KEY AND VALUE
        weightCaps = [r'Query\:_{Weights}', r'Key\:_{Weights}', r'Value\:_{Weights}']
        uniform = t.distributions.uniform.Uniform(low=t.tensor(-1.0), high=t.tensor(1.0))
        weightsQKV = [buildMatrices(Rs=hiddenDimension,
                                    Cs=hiddenDimension,
                                    roundDecimals=2,
                                    func=uniform.sample,
                                    clr=matrixColors[i],
                                    cap=weightCaps[i],
                                    wdt=2,
                                    hgt=8,
                                    seed=baseSeed + i + 1)
                      for i in range(len(weightCaps))]
        for i in range(len(weightsQKV)):
            weightsQKV[i][0].scale(2).next_to(qkvWithoutCaps[i])
        self.play(Write(VGroup(*[w[0] for w in weightsQKV])))
        self.wait()

        # PROJECTING QUERY, KEY AND VALUE USING THE WEIGHT MATRICES
        projCaps = ['(Q)', '(K)', '(V)']
        _projQKV = [pos_encoded @ w[1] for w in weightsQKV]
        projQKV = [buildMatrices(Rs=_projQKV[i].shape[0],
                                 Cs=_projQKV[i].shape[1],
                                 roundDecimals=2,
                                 arrayToMatrix=_projQKV[i],
                                 clr=matrixColors[i],
                                 cap=projCaps[i],
                                 colNames=dimNames,
                                 wdt=2,
                                 hgt=10)
                   for i in range(len(_projQKV))]
        # - place projected matrices in their positions
        for i in range(len(projQKV)):
            projQKV[i][0].scale(2).next_to(weightsQKV[i][0], RIGHT * 4)
        # - transform input and weights into the projected matrices
        for i, w, p in zip(range(3), weightsQKV, projQKV):
            self.play(FadeTransformPieces(VGroup(qkvWithoutCaps[i], w[0]), p[0]))
        # - move projected matrices to left
        self.play(AnimationGroup(*[p[0].animate.to_edge(LEFT * 8) for p in projQKV], lag_ratio=0.2))
        self.wait()

        # ------ different colors to columns to represent multiple heads within a single projected matrix
        perMatrix = []
        for projected in projQKV:
            cols = projected[0][0].get_columns()
            recolor = [cols[c].animate.set_color(headColors[c]) for c in range(len(cols))]
            perMatrix.append(AnimationGroup(*recolor, lag_ratio=0.75))
        self.play(AnimationGroup(*perMatrix, lag_ratio=0))
        self.wait()

        # SPLITTING QUERY AND KEY MATRICES
        querySplit, queryMatrices = splitIntoHeads(projQKV[0], ['(Q_1)', '(Q_2)', '(Q_3)'])
        keySplit, keyMatrices = splitIntoHeads(projQKV[1], ['(K_1)', '(K_2)', '(K_3)'])

        # TRANSPOSING SPLITTED KEY MATRICES
        keyTRowNames = ["d_%d " % i for i in range(embeddingDimension)]
        keyTCaps = ['(K^T_{1})', '(K^T_{2})', '(K^T_{3})']
        _keyTSplit = t.transpose(t.stack(t.split(projQKV[1][1], dimPerHead, dim=-1)), dim0=-2, dim1=-1)
        keyTSplit = [buildMatrices(Rs=_keyTSplit[i].shape[0],
                                   Cs=_keyTSplit[i].shape[1],
                                   roundDecimals=2,
                                   arrayToMatrix=_keyTSplit[i],
                                   clr=headColors[i * dimPerHead],
                                   cap=keyTCaps[i],
                                   lRowNames=keyTRowNames[i * dimPerHead: (i + 1) * dimPerHead],
                                   wdt=4,
                                   hgt=18)
                     for i in range(len(_keyTSplit))]
        keyTSplit[0][0].scale(2).next_to(keyMatrices, RIGHT * 8)
        self.play(FadeTransformPieces(keySplit[0][0], keyTSplit[0][0]), run_time=2)
        for i in range(1, len(keyTSplit)):
            keyTSplit[i][0].scale(2).next_to(keyTSplit[i - 1][0], RIGHT * 4)
            self.play(FadeTransformPieces(keySplit[i][0], keyTSplit[i][0]), run_time=2)
        keyTMatrices = VGroup(*[k[0] for k in keyTSplit])

        # ARRANGING THE SPLITTED QUERY AND KEY TRANSPOSED MATRICES CLOSER TO TRANSFORM THEM INTO ONE
        self.play(queryMatrices.animate.arrange(DOWN, buff=2).next_to(
            VGroup(projQKV[0][0], projQKV[1][0], projQKV[2][0]), RIGHT * 16))
        for i in range(len(keyTMatrices)):
            self.play(keyTMatrices[i].animate.next_to(queryMatrices[i], RIGHT * 8))

        # CALCULATING ATTENTION HEADS
        numHeads = len(querySplit)
        headCaps = [r'Attention \: Head\:_{%d} \: (Q_%d \: . \: K^{T}_{%d})' % (i + 1, i + 1, i + 1)
                    for i in range(numHeads)]
        _attentionHeads = [querySplit[i][1] @ keyTSplit[i][1] for i in range(numHeads)]
        attentionHeads = [buildMatrices(Rs=_attentionHeads[i].shape[0],
                                        Cs=_attentionHeads[i].shape[1],
                                        roundDecimals=2,
                                        arrayToMatrix=_attentionHeads[i],
                                        clr=headColors[i * dimPerHead],
                                        cap=headCaps[i],
                                        wdt=4,
                                        hgt=10)
                          for i in range(len(_attentionHeads))]
        for i in range(len(attentionHeads)):
            attentionHeads[i][0].scale(2).next_to(keyTMatrices[i], RIGHT * 8)
            animateProduct(queryMatrices[i], keyTMatrices[i], attentionHeads[i][0])

        # SCALING ATTENTION HEADS v1: every cell shown as "value / sqrt(d_k)"
        scaledCaps_v1 = [r'Scaled\: Attention \: Head\:_{%d} \: \frac{(Q_{%d}\: . \:K^{T}_{%d})}{\sqrt{d_k}}'
                         % (i + 1, i + 1, i + 1) for i in range(numHeads)]
        attHeadsScaled_v1 = buildSymbolicScaledHeads(attentionHeads, scaledCaps_v1, r'/\sqrt{d_{k}}')
        wrapAndSwap(attHeadsScaled_v1, [head[0] for head in attentionHeads])

        # SCALING ATTENTION HEADS v2: d_k replaced by its value (dimPerHead)
        scaledCaps_v2 = [r'Scaled \:Attention \: Head\:_{%d} \: \frac{(Q_{%d}\: . \:K^{T}_{%d})}{\sqrt{d_k}}'
                         % (i + 1, i + 1, i + 1) for i in range(numHeads)]
        attHeadsScaled_v2 = buildSymbolicScaledHeads(attentionHeads, scaledCaps_v2,
                                                     r'/\sqrt{%d}' % dimPerHead)
        wrapAndSwap(attHeadsScaled_v2, [head[0] for head in attHeadsScaled_v1])

        # SCALING RESULT OF MATMUL BLOCK v3: the evaluated numbers
        scaledCaps = [r'\frac{(Q_{%d}\: . \:K^{T}_{%d})}{\sqrt{d_k}}' % (i + 1, i + 1)
                      for i in range(numHeads)]
        # Divide by sqrt(d_k) with d_k = dimPerHead so the numbers match the
        # formula in the captions and the symbolic step shown just before.
        scaleFactor = dimPerHead ** 0.5
        _attHeadsScaledArray = [t.round(head[1] / scaleFactor, decimals=2) for head in attentionHeads]

        def buildScaledHeads():
            """Build the numeric scaled-score matrices (one per head)."""
            return [buildMatrices(Rs=_attHeadsScaledArray[i].shape[0],
                                  Cs=_attHeadsScaledArray[i].shape[1],
                                  roundDecimals=2,
                                  arrayToMatrix=_attHeadsScaledArray[i],
                                  clr=headColors[i * dimPerHead],
                                  cap=scaledCaps[i],
                                  wdt=4,
                                  hgt=10)
                    for i in range(len(_attHeadsScaledArray))]

        attHeadsScaled_v3 = buildScaledHeads()
        wrapAndSwap(attHeadsScaled_v3, [head[0] for head in attHeadsScaled_v2])

        # Same numbers once more, without the "Scaled( )" wrapper.
        attHeadsScaled = buildScaledHeads()
        for i in range(len(attHeadsScaled_v3)):
            swapInPlace(attHeadsScaled_v3[i][0], attHeadsScaled[i][0])
        self.remove(attHeadsScaled_v1[0][0], attHeadsScaled_v2[0][0], attHeadsScaled_v3[0][0])
        del attHeadsScaled_v1
        del attHeadsScaled_v2
        del attHeadsScaled_v3

        # CONVERTING SCALED PROJECTED MATRICES TO SOFTMAX MATRICES
        softmaxCaps = [r"Softmax \: \biggl( \frac{(Q_{%d}\: . \:K^{T}_{%d})}{\sqrt{d_k}} \biggr)"
                       % (i + 1, i + 1) for i in range(numHeads)]
        scaledSoftmax = [t.softmax(head[1], dim=-1) for head in attHeadsScaled]
        scaledQKSoftmax = [buildMatrices(Rs=scaledSoftmax[i].shape[0],
                                         Cs=scaledSoftmax[i].shape[1],
                                         roundDecimals=2,
                                         arrayToMatrix=scaledSoftmax[i],
                                         clr=headColors[i * dimPerHead],
                                         cap=softmaxCaps[i],
                                         wdt=4,
                                         hgt=10)
                           for i in range(len(scaledSoftmax))]
        for i in range(len(scaledQKSoftmax)):
            scaled = attHeadsScaled[i][0]
            # First show Softmax( scaled scores ), then replace it by the values.
            parenSoft = parenthesizeMatrix(matrix=scaled, encapsulateName='Softmax', bf=0.9,
                                           initial_scale_factor=2)
            parenSoft.set_x(scaled.get_x()).set_y(scaled.get_y()).set_z(scaled.get_z())
            self.play(FadeTransformPieces(scaled, parenSoft))
            self.wait()
            swapInPlace(parenSoft, scaledQKSoftmax[i][0])
            self.wait()
            self.remove(parenSoft)
        tempscaledQKSoftmax = VGroup(*[s[0] for s in scaledQKSoftmax])
        self.play(tempscaledQKSoftmax.animate.arrange(DOWN, buff=4).next_to(
            VGroup(projQKV[0][0], projQKV[1][0]), RIGHT * 48))

        # SPLITTING VALUE MATRICES
        valueSplit, valueMatrices = splitIntoHeads(projQKV[2], ['(V_1)', '(V_2)', '(V_3)'])
        for i in range(len(valueMatrices)):
            self.play(valueMatrices[i].animate.next_to(scaledQKSoftmax[i][0], RIGHT * 8))
        self.play(tempscaledQKSoftmax.animate.arrange(DOWN, buff=7).next_to(
            VGroup(projQKV[0][0], projQKV[1][0]), RIGHT * 48))
        self.play(tempscaledQKSoftmax.animate.shift(DOWN * 10))
        for i in range(len(valueMatrices)):
            self.play(valueMatrices[i].animate.next_to(scaledQKSoftmax[i][0], RIGHT * 8))

        # CALCULATING SOFTMAX RESULTS WITH VALUE MATRICES
        sftDotVCaps = [r"Softmax \: \biggl( \frac{(Q_{%d}\: . \:K^{T}_{%d})}{\sqrt{d_k}} \biggr) \: . \: V_{%d}"
                       % (i + 1, i + 1, i + 1) for i in range(numHeads)]
        _sftDotV = [scaledQKSoftmax[i][1] @ valueSplit[i][1] for i in range(len(scaledQKSoftmax))]
        sftDotV = [buildMatrices(Rs=_sftDotV[i].shape[0],
                                 Cs=_sftDotV[i].shape[1],
                                 roundDecimals=2,
                                 arrayToMatrix=_sftDotV[i],
                                 clr=headColors[i * dimPerHead],
                                 cap=sftDotVCaps[i],
                                 colNames=dimNames[i * dimPerHead: (i + 1) * dimPerHead],
                                 wdt=6,
                                 hgt=16)
                   for i in range(len(_sftDotV))]
        for i in range(len(sftDotV)):
            sftDotV[i][0].scale(2).next_to(valueMatrices[i], RIGHT * 8)
            animateProduct(scaledQKSoftmax[i][0], valueMatrices[i], sftDotV[i][0])
        tempsftDotVT = VGroup(*[s[0] for s in sftDotV])
        self.play(tempsftDotVT.animate.arrange(RIGHT, buff=4).move_to(ORIGIN))
        self.wait()

        # CONCATENATED RESULT
        _concatSftDotV = t.hstack(_sftDotV)
        concatSftDotV = buildMatrices(Rs=_concatSftDotV.shape[0],
                                      Cs=_concatSftDotV.shape[1],
                                      roundDecimals=2,
                                      arrayToMatrix=_concatSftDotV,
                                      clr=BLUE_E,
                                      cap=r"Softmax \: \biggl( \frac{(Q \: . \:K^{T})}{\sqrt{d_k}} \biggr) \: . \: V",
                                      colNames=["d_%d " % i for i in range(embeddingDimension)],
                                      wdt=2,
                                      hgt=10)
        concatSftDotV[0].scale(2).next_to(
            VGroup(projQKV[0][0], projQKV[1][0], projQKV[2][0]), RIGHT * 32)
        self.play(FadeTransform(tempsftDotVT, concatSftDotV[0]))

        # ------ different colors to columns to represent multiple heads within the concatenated matrix
        cols = concatSftDotV[0][0].get_columns()
        recolor = [cols[c].animate.set_color(headColors[c]) for c in range(len(cols))]
        self.play(AnimationGroup(*recolor, lag_ratio=0.5))
        self.wait()

        # CREATING CONCATENATION WEIGHTS
        # Use the next seed after the Q/K/V weights. The earlier expression
        # relied on a leftover loop index and reproduced the Value-weights
        # seed, which made the output weights identical to the Value weights.
        concatSeed = baseSeed + len(weightsQKV) + 1
        concatWeights = buildMatrices(Rs=hiddenDimension,
                                      Cs=hiddenDimension,
                                      roundDecimals=2,
                                      func=uniform.sample,
                                      clr=GOLD_E,
                                      cap=r'Concatenation \: ( \: or \: Output \: Weights \:)',
                                      wdt=2,
                                      hgt=8,
                                      seed=concatSeed)
        concatWeights[0].scale(2).next_to(concatSftDotV[0], RIGHT * 8)
        self.play(Write(concatWeights[0]))
        self.wait()

        # FINAL ENCODED MATRIX MULTIPLICATION
        _finalResult = concatSftDotV[1] @ concatWeights[1]
        finalResult = buildMatrices(Rs=_finalResult.shape[0],
                                    Cs=_finalResult.shape[1],
                                    roundDecimals=2,
                                    arrayToMatrix=_finalResult,
                                    clr=YELLOW,
                                    cap=r'Final \: Encoder \: Result',
                                    colNames=dimNames,
                                    wdt=2,
                                    hgt=10)
        finalResult[0].scale(2).next_to(concatWeights[0])
        animateProduct(concatSftDotV[0], concatWeights[0], finalResult[0])
        self.play(finalResult[0].animate.move_to(ORIGIN))
        self.play(finalResult[0].animate.scale(2.5))
        self.play(Indicate(finalResult[0]), run_time=5)
        self.wait(15)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment