# Attention is not enough! Recurrent Neural Nets with Luong's Dot Product Attention Worked-out Example
from manim import *
from manim.utils.unit import Percent, Pixels
from colour import Color
import gensim
from gensim.models import Word2Vec
import numpy as np
import torch as t

config.frame_width = 146
config.frame_height = 85.5

HIGH = True

if HIGH:
    config.pixel_height = 2160
    config.pixel_width = 3840
else:
    config.pixel_height = 1080
    config.pixel_width = 1920
#************************************************************************************************************************************************************************************
#                                                        STRUCTURE OF THIS CODE FILE
#---- def construct_Word_Embeddings()       : function to create word embeddings or vectors for the animations using gensim
#---- def getCell()                         : function to build the rectangle shapes that portray the encoder/decoder cells
#---- def getArrow()                        : function that builds arrows to and from mobjects, indicating flow
#---- def buildMatrices()                   : function that takes the appropriate input (tensors, or functions to generate tensors) and returns a fully packaged matrix with dimension and caption as a mobject
#---- def ellipsisDenseLayer()              : static function that builds a dense-layer classifier VGroup of neurons and weights that can simply be written on top of a vector to indicate classification
#---- class Figure8(MovingCameraScene)      : the main block of the RNN animation code
#----     def construct()                     : function that the Scene class calls automatically to render
#----     def myAnimation()                   : COMPLEX CONCEPTS COMPREHENSIBLE intro animation
#----     def writeUnwrite()                  : simulates the effect of ShowCreation followed by FadeOut
#----     def calcAndBuild()                  : function that does the matrix multiplication of weights, indicates the mobjects of interest and transforms them or their copies into results
#----     def parenthesizeMatrix()            : creates brackets around a mobject to present it as the argument of tanh, softmax or any other function
#----     def matrixMultiplicationAnimation() : takes two matrix mobjects and displays the detailed operations of corresponding rows and columns
#----     def payAttention()                  : Luong's attention implemented from scratch, displaying all the nitty-gritty details of the calculation
#************************************************************************************************************************************************************************************
################################################################################
#                            Creating Word Vectors                             #
################################################################################
def construct_Word_Embeddings(sentences, dim):
    data = sentences
    dimension = dim
    model = Word2Vec(data, min_count=1, vector_size=dimension, window=3)
    sequences = []
    for seq in data:
        sequence = []
        for word in seq:
            word_emb = model.wv.get_vector(word, norm=True).reshape(1, -1)
            word_emb = t.from_numpy(word_emb)
            sequence.append(word_emb)
        sequences.append(t.round(t.cat(sequence, dim=0), decimals=2))
    return sequences
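
# ------------> Illustrative usage (a minimal sketch, never called during rendering;
# ------------> the helper name _demo_word_embeddings is hypothetical):
def _demo_word_embeddings():
    # one (sequence_length, embedding_dim) tensor comes back per input sentence
    sents = construct_Word_Embeddings([['Attention', 'is', 'not', 'enough']], dim=3)
    print(sents[0].shape)  # expected: torch.Size([4, 3])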
################################################################################
#                       FUNCTION TO CREATE ENCODER CELLS                       #
################################################################################
def getCell(w, h, c, fo = 0.1):
    """
    w: width of the rectangle cell
    h: height of the rectangle cell
    c: color of the rectangle cell
    fo: fill opacity of the rectangle cell
    """
    return Rectangle(height=h, width=w, color=c, fill_opacity=fo)
################################################################################
#                      FUNCTION TO CREATE ALL ARROW TYPES                      #
################################################################################
def getArrow(arrowType, cell, sWidth, arrowLen, initialMobj = None, bf = 0, rt = 0.1):
    """
    arrowType: what type of arrow it is: going into the main mobject ('input'), coming out of the main mobject ('output') or between two mobjects ('hidden')
    cell: the main mobject rectangle cell, or the mobject on the right
    sWidth: stroke width or thickness of the arrow
    arrowLen: length of the arrow
    initialMobj: mobject on the left from where the arrow starts to grow
    bf: buff
    rt: max_tip_length_to_length_ratio
    """
    arr = Arrow(stroke_width=sWidth, buff=bf, max_tip_length_to_length_ratio=rt)
    if arrowType == 'input':
        arr.add_updater(
            lambda x, y=cell: x.become(
                Arrow(start=y.get_bottom() + DOWN * arrowLen, end=y.get_bottom(), stroke_width=sWidth, buff=bf, max_tip_length_to_length_ratio=rt)))
    elif arrowType == 'output':
        arr.add_updater(
            lambda x, y=cell: x.become(
                Arrow(start=y.get_top(), end=y.get_top() + UP * arrowLen, stroke_width=sWidth, buff=bf, max_tip_length_to_length_ratio=rt)))
    elif arrowType == 'hidden':
        arr.add_updater(
            lambda x, i=initialMobj, f=cell: x.become(
                Arrow(start=i.get_right(), end=f.get_left(), stroke_width=sWidth, buff=bf, max_tip_length_to_length_ratio=rt)))
    else:
        arr = None
    return arr
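
# ------------> Illustrative sketch (an assumption for demonstration, never called
# ------------> during rendering): because every arrow is rebuilt by its updater each
# ------------> frame, moving either endpoint mobject keeps the arrow attached.
def _demo_arrow_updater(scene):
    left_sq, right_sq = Square().to_edge(LEFT), Square().to_edge(RIGHT)
    hidden_arrow = getArrow(arrowType='hidden', cell=right_sq, sWidth=4, arrowLen=0, initialMobj=left_sq)
    scene.add(left_sq, right_sq, hidden_arrow)
    scene.play(left_sq.animate.shift(UP * 2))  # the arrow follows automatically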
################################################################################
#            FUNCTION TO BUILD MATRICES WITH DIMENSIONS AND CAPTIONS           #
################################################################################
def buildMatrices(Rs, Cs, roundDecimals, whereTo = None, referenceMobj = None, arrayToMatrix = None, func = None, clr = WHITE, cap = None, hgt = 6, wdt = 1.5, seed = 786):
    """
    Rs: number of rows of the matrix to build
    Cs: number of cols of the matrix to build
    roundDecimals: number of places to round off the values to
    whereTo: where exactly to place the matrix near the reference mobject (DOWN, UP, LEFT or RIGHT)
    referenceMobj: the mobject w.r.t which the matrix is to be placed on screen, with an updater added to follow it
    arrayToMatrix: the array whose matrix mobject is to be built
    func: if not an array (if arrayToMatrix is None), the torch function with which the array is to be generated to build the matrix mobject
    hgt: height of the matrix
    wdt: width of the matrix
    seed: the value to maintain stability of randomness
    cap: caption or name of the matrix
    clr: color of the matrix
    """
    if arrayToMatrix is None and func is None:
        print("No input to process and build a matrix upon")
        return None
    # ------------> Set a random seed for weights initialization and build a matrix
    t.manual_seed(seed=seed)
    ini = t.round(func((Rs, Cs)), decimals=roundDecimals) if (arrayToMatrix is None) else t.round(arrayToMatrix, decimals=roundDecimals)
    iniMat = Matrix(ini.numpy())
    iniMat.height = hgt
    iniMat.width = wdt * Cs
    iniMat.color = clr
    # ------------> Locate the matrix somewhere w.r.t the reference mobject
    if referenceMobj is None:
        pass
    else:
        iniMat.next_to(referenceMobj, whereTo)
        iniMat.add_updater(lambda x, y=referenceMobj: x.next_to(y, whereTo))
    # ------------> Create the MASTER MATRIX, the VGroup, and add the main matrix along with dimension and caption as text
    iniMatrix = VGroup()
    iniMatrix.add(iniMat)
    dimText = MathTex('(%d, %d)' % (Rs, Cs), font_size=40).move_to(iniMat.get_critical_point(DR) + DOWN / 2)
    dimText.add_updater(lambda x, y=iniMat: x.move_to(y.get_critical_point(DR) + DOWN / 2))
    iniMatrix.add(dimText)
    if cap is None:
        pass
    else:
        capText = MathTex('%s' % (cap), font_size=64).move_to(iniMat.get_critical_point(DOWN) + DOWN * 1.25)
        capText.add_updater(lambda x, y=iniMat: x.move_to(y.get_critical_point(DOWN) + DOWN * 1.25))
        iniMatrix.add(capText)
    return iniMatrix, ini
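
# ------------> Illustrative usage (a minimal sketch, never called during rendering;
# ------------> _demo_build_matrices is hypothetical): a 4x1 zero vector mobject with
# ------------> its dimension label and caption comes back alongside the raw tensor.
def _demo_build_matrices():
    vec_mobj, vec_tensor = buildMatrices(Rs=4, Cs=1, roundDecimals=2, func=t.zeros, cap='h_0')
    print(vec_tensor.shape)  # torch.Size([4, 1])
    print(len(vec_mobj))     # 3: the matrix, the dimension text, the caption text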
################################################################################
#                 FUNCTION TO BUILD DENSE LAYER DEMO FOR OUTPUT                #
################################################################################
def ellipsisDenseLayer(n, r, sW, c=GOLD):
    """
    n: number of mobjects to be added in ellipsis formation
    r: radius of the neurons of the dense layer
    sW: stroke width of the arrow transitions
    c: color of the neurons
    """
    tipLengthRatio = 0.05
    eDL = VGroup()
    eDL.add(*[Circle(radius=r, color=c) for i in range(n - 1)])
    eDL.add(*[Dot(color=c) for i in range(3)])
    eDL.add(Circle(radius=r, color=c))
    eDL.arrange(RIGHT, buff=0.15)
    softmaxProbs = Text("Softmax Probabilities Vector", slant=OBLIQUE, font_size=48, color=DARK_BROWN).next_to(eDL, UP * 10)
    ar = VGroup()
    arrToDenseStartPosition = eDL.get_bottom() + DOWN * 4
    denseToTextEndPosition = softmaxProbs.get_bottom()
    ar.add(*[Arrow(start=arrToDenseStartPosition, end=eDL[i].get_bottom(), stroke_width=sW, max_tip_length_to_length_ratio=tipLengthRatio) for i in range(n - 1)])
    ar.add(Arrow(start=arrToDenseStartPosition, end=eDL[-1].get_bottom(), stroke_width=sW, max_tip_length_to_length_ratio=tipLengthRatio))
    ar.add(*[Arrow(start=eDL[i].get_top(), end=denseToTextEndPosition, stroke_width=sW, max_tip_length_to_length_ratio=tipLengthRatio) for i in range(n - 1)])
    ar.add(Arrow(start=eDL[-1].get_top(), end=denseToTextEndPosition, stroke_width=sW, max_tip_length_to_length_ratio=tipLengthRatio))
    eDL.add(Text("Classifier", slant=OBLIQUE, font_size=48, color=DARK_BLUE).next_to(eDL[-1], RIGHT * 2))
    eDL += softmaxProbs
    eDL += ar
    return eDL
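
# ------------> The math this mock classifier depicts (a minimal sketch under the
# ------------> assumption of an output weight matrix W_out; hypothetical helper,
# ------------> never called during rendering):
def _demo_dense_softmax(W_out, s_tilde):
    # logits = W_out @ s_tilde, then softmax over the vocabulary axis
    logits = W_out @ s_tilde.reshape(-1, 1)
    return t.softmax(logits, dim=0)  # probability of each vocabulary entry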
##############################################################################################################################################################
#                                       |-----|       |\    |    |\    |                                                                                     #
#                                       |     |       | \   |    | \   |                                                                                     #
#                                       |-----|       |  \  |    |  \  |                                                                                     #
#                                       |\            |   \ |    |   \ |                                                                                     #
#                                       | \           |    \|    |    \|                                                                                     #
##############################################################################################################################################################
class Figure8(MovingCameraScene):
    def construct(self):
        ################################################################################
        #                    INTRO ANIMATION SAME ACROSS ALL VIDEOS                    #
        ################################################################################
        def myAnimation(wordsForIntro: list):
            fontHeight = config.frame_height // 8
            fontColor = WHITE
            timePerChar = 0.1
            C = MathTex(r"\mathbb{C}", color=fontColor).scale(config.frame_height // 3)
            self.play(Broadcast(C), run_time=1)
            self.add(C)
            # Building text objects of individual characters.
            wordsMath = VGroup()
            for word in wordsForIntro:
                charTex = VGroup()
                for i, ch in enumerate(word):
                    chTex = MathTex(r"\mathbb{%s}" % ch, color=fontColor).scale(fontHeight)
                    if i != 0:
                        chTex.next_to(charTex[-1], RIGHT, buff=0.05).align_to(C, DOWN)
                    else:
                        chTex.next_to(C, RIGHT, buff=0.05).align_to(C, DOWN)
                    charTex.add(chTex)
                wordsMath.add(charTex)
            # Succession and AnimationGroup --- both are messed up --- HENCE THE INEFFICIENT ANIMATION
            for wInd in range(len(wordsMath)):
                for chInd in range(len(wordsMath[wInd])):
                    self.play(Write(wordsMath[wInd][chInd], run_time=timePerChar))
                self.wait(0.5)
                for chInd in reversed(range(len(wordsMath[wInd]))):
                    self.play(Unwrite(wordsMath[wInd][chInd], run_time=timePerChar))
            self.play(Circumscribe(C, color=MAROON_E, fade_out=False, time_width=2, shape=Circle, buff=1, stroke_width=config.frame_height // 3, run_time=1.5))
            self.play(ShrinkToCenter(C, run_time=0.25))
            self.wait(0.5)

        myAnimation(wordsForIntro=['OMPLEX', 'ONCEPTS', 'OMPREHENSIBLE'])
        ################################################################################
        #                ALTERNATIVE FUNCTION TO CREATE/UNCREATE OBJECT                #
        ################################################################################
        def writeUnwrite(mob, runtime):
            """
            mob: mobject you want to create and remove immediately
            runtime: duration of the effect
            """
            mob.save_state()
            self.play(Write(mob), run_time=runtime)
            self.wait(0.25)
            self.play(Unwrite(mob), run_time=runtime)
            mob.restore()
        ################################################################################
        #                    VECTOR CALCULATION AND MOBJECT BUILD                      #
        ################################################################################
        def calcAndBuild(op, m1, m2, rows, vecHeight, vecWidth, refMob, position, toIndicate, colr, shiftBy=[0, 0, 0], cp=None):
            """
            op: torch operation to perform (t.matmul, t.add, or a unary activation such as t.nn.Tanh())
            m1, m2: operand tensors (m2 is None for unary ops)
            rows: number of rows of the result vector mobject
            vecHeight, vecWidth: dimensions of the result vector mobject
            refMob, position, shiftBy: where to place the result w.r.t the reference mobject
            toIndicate: mobjects of interest to indicate and transform into the result
            colr: color of the result vector
            cp: caption of the result vector
            """
            if op == t.matmul:
                v = op(m1, m2)
                initial = VGroup(*toIndicate).copy()
            elif op == t.add:
                v = op(m1, m2)
                initial = VGroup(*toIndicate)
            elif m2 is None:
                v = op(m1)
                initial = VGroup(*toIndicate)
            V, _ = buildMatrices(Rs=rows, Cs=1, roundDecimals=2, whereTo=None, referenceMobj=None, arrayToMatrix=v.reshape(-1, 1), clr=colr, hgt=vecHeight, wdt=vecWidth, cap=cp)
            V.move_to(refMob, position).shift(shiftBy)
            ag = AnimationGroup(*[Indicate(o) for o in toIndicate], lag_ratio=0)
            self.play(ag, run_time=0.75)
            self.play(FadeTransformPieces(initial, V), run_time=1.25)
            return v, V
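
        # ------------> The three ops calcAndBuild animates compose the vanilla RNN
        # ------------> update h_t = tanh(W_in @ x_t + W_hid @ h_{t-1}); a minimal
        # ------------> reference sketch (hypothetical helper, never called):
        def _rnn_cell_reference(x_t, h_prev, W_in, W_hid):
            # x_t: (input_dim, 1), h_prev: (hidden_dim, 1) -> h_t: (hidden_dim, 1)
            return t.tanh(W_in @ x_t + W_hid @ h_prev)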
        ################################################################################
        #                       FUNCTION TO PARENTHESISE MATRICES                      #
        ################################################################################
        def parenthesizeMatrix(matrix, encapsulateName: str, referenceMobject: Mobject = None, position = None, bf=0.5, initial_scale_factor: float = 1.5):
            """
            matrix: VGroup that contains the matrix, its dimension and its caption
            encapsulateName: name of the function in which the matrix needs to be presented, like sine, cosine, softmax, etc.
            referenceMobject, position: where to move the parenthesized matrix once built
            bf: buff between the matrix and the parentheses
            """
            parenMat = VGroup()
            parens = MathTex("(", ")")
            parens.scale(initial_scale_factor)
            parens.stretch_to_fit_height(matrix.height)
            mt = matrix[0]
            l_paren, r_paren = parens.split()
            l_paren.next_to(mt, LEFT, buff=bf)
            r_paren.next_to(mt, RIGHT, buff=bf)
            tx = MathTex(encapsulateName)
            tx.scale(initial_scale_factor)
            tx.next_to(l_paren, LEFT, buff=bf)
            parenMat += matrix
            parenMat.add(*[l_paren, r_paren, tx])
            self.play(FadeTransformPieces(matrix, parenMat), run_time=0.75)
            if referenceMobject is not None:
                self.play(parenMat.animate.move_to(referenceMobject, position))
            return parenMat
        ################################################################################
        #                    FUNCTION TO ANIMATE MATRIX MULTIPLICATION                 #
        ################################################################################
        def matrixMultiplicationAnimation(matrix_1: Matrix, matrix_2: Matrix, result: Matrix):
            matrices = VGroup(*[matrix_1, matrix_2, result])
            self.play(self.camera.frame.animate.move_to(matrices).set(width=matrices.width * 5, height=matrices.height * 7))
            M1_rows = matrix_1.get_rows()
            M2_cols = matrix_2.get_columns()
            M3_rows = result.get_rows()
            for rw_index in range(0, len(M1_rows)):
                for cl_index in range(0, len(M2_cols)):
                    ag_final = AnimationGroup(*[Indicate(M1_rows[rw_index]),
                                                Indicate(M2_cols[cl_index]),
                                                FadeTransformPieces(mobject=VGroup(*[M1_rows[rw_index], M2_cols[cl_index]]).copy(), target_mobject=M3_rows[rw_index][cl_index])
                                                ], lag_ratio=0)
                    self.play(ag_final, run_time=0.6)
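
        # ------------> A numeric instance of the row-by-column products the loop above
        # ------------> animates (a sketch only, never called during rendering):
        def _matmul_by_hand_reference():
            M1 = t.tensor([[1., 2.], [3., 4.]])
            M2 = t.tensor([[5.], [6.]])
            # entry (r, c) is the dot product of row r of M1 with column c of M2:
            # (0, 0): 1*5 + 2*6 = 17 ; (1, 0): 3*5 + 4*6 = 39
            assert t.equal(M1 @ M2, t.tensor([[17.], [39.]]))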
        ################################################################################
        #                    FUNCTION TO ANIMATE ATTENTION CALCULATION                 #
        ################################################################################
        def payAttention(eHVForCalc, EHV, curHidStateForCalc, curHidState, contextWeightsForCalc, contextWeightsVec, referenceMobjectFinal, finalCap, colr):
            """
            ################################################################################################
            #                    THIS FUNCTION IMPLEMENTS LUONG'S ATTENTION                                #
            #          FIRST CALCULATE CURRENT HIDDEN STATE OF DECODER                                     #
            #          USE CURRENT HIDDEN STATE TO CALCULATE INNER PRODUCTS WITH ENCODER HIDDEN STATES     #
            #          POST INNER PRODUCTS YOU GET ALPHA WEIGHTS                                           #
            #          SOFTMAX THE ALPHA WEIGHTS TO BUILD ATTENTION SCORES                                 #
            ################################################################################################
            eHVForCalc: encoder hidden states for calculation
            EHV: Encoder Hidden Vectors (mobjects)
            curHidStateForCalc: current hidden state for calculation
            curHidState: Current Hidden State (mobject)
            contextWeightsForCalc: the weights associated with the contextualized vector (combination of the current context vector with the current decoder state)
            contextWeightsVec: mobject form of contextWeightsForCalc
            referenceMobjectFinal: the mobject with reference to which the final output is to be placed
            finalCap: caption of the final attentional hidden state
            colr: color of the intermediate and final vectors
            """
            toEraseMobjectsList = VGroup()
            equation = VGroup()
            finalTransformMobjects = VGroup()
            # Creating a copy of Encoder Hidden Vectors, animate & bring it upwards to transform into a unified matrix
            EHVCopy = EHV.copy().clear_updaters().next_to(EHV, UP * 15)
            self.play(FadeTransformPieces(EHV.copy(), EHVCopy), run_time=0.5)
            finalTransformMobjects.add(EHVCopy)
            # self.camera.frame.save_state()
            self.play(self.camera.frame.animate.move_to(EHVCopy).set(width=EHVCopy.width * 7, height=EHVCopy.height * 10))
            # Construct and transform individual vectors to a unified matrix for the matrix multiplication animation
            encodedMatrixVecForCalc = t.transpose(input=t.hstack(eHVForCalc), dim0=0, dim1=1)
            encodedMatrixVec, _ = buildMatrices(Rs=encodedMatrixVecForCalc.shape[0], Cs=encodedMatrixVecForCalc.shape[1], roundDecimals=2, whereTo=UP * 15, referenceMobj=EHVCopy, arrayToMatrix=encodedMatrixVecForCalc, func=None, clr=DARK_BLUE, cap="Encoded \: States \: Matrix")
            self.play(TransformMatchingShapes(EHVCopy.copy(), encodedMatrixVec), run_time=2)
            equation.add(encodedMatrixVec)
            # Create and write the dot for the matrix multiplication animation
            dot = Dot(radius=0.15).next_to(encodedMatrixVec, RIGHT * 5)
            self.play(Write(dot), run_time=0.25)
            self.wait(0.25)
            equation.add(dot)
            # Make a copy of the current hidden state next to its output arrow that can be used to concatenate into the contextualized vector later on
            cVec1 = curHidState.copy().clear_updaters().next_to(referenceMobjectFinal, RIGHT)
            self.play(self.camera.frame.animate.move_to(cVec1).set(width=cVec1.width * 5, height=cVec1.height * 7))
            self.play(FadeTransformPieces(curHidState.copy(), cVec1), run_time=1)
            self.wait(0.25)
            self.play(Restore(self.camera.frame))
            toEraseMobjectsList.add(cVec1)
            self.camera.frame.save_state()
            # Bring the cVec1 copy closer, build the transposed vector of the current hidden state for the align function (scores calculation) and transform
            cVec1Copy = cVec1.copy().clear_updaters().next_to(dot, RIGHT * 5)
            self.play(FadeTransformPieces(cVec1.copy(), cVec1Copy), run_time=0.5)
            equation.add(cVec1Copy)
            # Write the equals sign
            equals = MathTex('=', font_size=192).next_to(cVec1Copy, RIGHT * 5)
            self.play(Write(equals), run_time=0.25)
            equation.add(equals)
            # Calculate the result, write the brackets of the result matrix and call the animation function
            r = encodedMatrixVecForCalc @ curHidStateForCalc.reshape(-1, 1)
            R, _ = buildMatrices(Rs=encodedMatrixVecForCalc.shape[0], Cs=1, roundDecimals=2, arrayToMatrix=r.reshape(-1, 1), cap="Alignment \: Scores", hgt=encodedMatrixVec.height, clr=colr)
            R.next_to(equals, RIGHT * 5)
            self.play(Write(VGroup(*[R[0].get_brackets(), R[1], R[2]])), run_time=0.25)
            matrixMultiplicationAnimation(matrix_1=encodedMatrixVec[0], matrix_2=cVec1Copy[0], result=R[0])
            toEraseMobjectsList.add(R)
            # Remove the LHS of the matrix multiplication performed
            self.play(Unwrite(equation), run_time=0.25)
            self.play(R.animate.next_to(EHVCopy, UP * 15))
            self.remove(equation)
            self.wait(0.5)
            # Parenthesize to show softmax
            self.play(self.camera.frame.animate.move_to(EHVCopy).set(width=EHVCopy.width * 7, height=EHVCopy.height * 10))
            pMat = parenthesizeMatrix(matrix=R, encapsulateName='Softmax', bf=0.75)
            self.wait(0.5)
            toEraseMobjectsList.add(pMat)
            # Convert to softmax values
            rSoftmax = t.softmax(r.reshape(-1, 1), dim=0)
            rSoftmaxVec, _ = buildMatrices(Rs=encodedMatrixVecForCalc.shape[0], Cs=1, roundDecimals=2, arrayToMatrix=rSoftmax.reshape(-1, 1), hgt=R.height, cap="Attention \: Scores", clr=colr)
            rSoftmaxVec.next_to(EHVCopy, UP * 15).clear_updaters()
            self.play(FadeTransformPieces(pMat, rSoftmaxVec), run_time=0.5)
            self.wait(0.5)
            toEraseMobjectsList.add(rSoftmaxVec)
            # Distribute attention values to their corresponding encoder hidden states
            agAttentionScores = VGroup(*[rSoftmaxVec[0].get_columns()[0][i].copy().next_to(EHVCopy[i][0], LEFT * 4).scale(1.5) for i in range(len(EHVCopy))])
            plusSigns = VGroup(*[MathTex("+", font_size=192).next_to(EHVCopy[i][0], RIGHT * 12) for i in range(len(EHVCopy) - 1)])
            self.play(AnimationGroup(*[FadeTransformPieces(rSoftmaxVec, agAttentionScores), Write(plusSigns)], lag_ratio=0))
            self.wait(0.5)
            finalTransformMobjects.add(agAttentionScores)
            finalTransformMobjects.add(plusSigns)
            # Finally transform all trivial components and build the context vector
            cV2 = rSoftmax.reshape(1, -1) @ encodedMatrixVecForCalc
            cVec2, _ = buildMatrices(Rs=encodedMatrixVecForCalc.shape[1], Cs=1, roundDecimals=2, arrayToMatrix=cV2.reshape(-1, 1), hgt=rSoftmaxVec.height, clr=colr, cap="Context \: Vector")
            cVec2.next_to(finalTransformMobjects, UP * 15)
            self.play(FadeTransformPieces(finalTransformMobjects, cVec2))
            self.wait(0.5)
            toEraseMobjectsList.add(cVec2)
            # Bring the decoder hidden state closer to the context vector
            self.play(self.camera.frame.animate.move_to(cVec1).set(width=cVec1.width * 3, height=cVec1.height * 5))
            self.wait(0.5)
            self.play(Restore(self.camera.frame))
            self.play(cVec1.animate.next_to(cVec2, DOWN * 2))
            self.camera.frame.save_state()
            self.play(self.camera.frame.animate.move_to(EHVCopy).set(width=EHVCopy.width * 7, height=EHVCopy.height * 10))
            self.wait(0.5)
            # Convert to the contextualized vector
            contextualized = t.concat((cV2.reshape(-1, 1), curHidStateForCalc.reshape(-1, 1)), dim=0)
            contextualizedVec, _ = buildMatrices(Rs=contextualized.reshape(-1, 1).shape[0], Cs=1, roundDecimals=2, arrayToMatrix=contextualized.reshape(-1, 1), clr=colr, cap="Contextualized")
            contextualizedVec.next_to(cVec2, RIGHT).shift(DOWN * 3)
            self.play(FadeTransformPieces(VGroup(*[cVec1, cVec2]), contextualizedVec))
            self.wait(0.4)
            toEraseMobjectsList.add(contextualizedVec)
            # Build the attentional hidden state using the context weights
            contextWeightsVec.save_state()
            self.play(contextWeightsVec.animate.next_to(contextualizedVec, LEFT * 3))
            self.wait(0.5)
            attentionalHidState = contextWeightsForCalc @ contextualized.reshape(-1, 1)
            attentionalHidStateVec, _ = buildMatrices(Rs=attentionalHidState.reshape(-1, 1).shape[0], Cs=1, roundDecimals=2, arrayToMatrix=attentionalHidState.reshape(-1, 1), clr=colr, cap=finalCap)
            attentionalHidStateVec.next_to(contextualizedVec, RIGHT * 3)
            self.play(FadeTransformPieces(VGroup(*[contextWeightsVec, contextualizedVec]).copy(), attentionalHidStateVec), run_time=1)
            self.wait(0.5)
            self.play(Restore(self.camera.frame), run_time=0.5)
            self.play(FadeOut(contextualizedVec), run_time=0.5)
            self.play(contextWeightsVec.animate.restore(), run_time=0.5)
            # Remove mobjects not on screen anymore
            self.remove(*[toEraseMobjectsList, equation, finalTransformMobjects])
            # Put the decoder attentional hidden state near the arrow and return variables
            self.play(attentionalHidStateVec.animate.next_to(referenceMobjectFinal, RIGHT))
            self.wait(0.5)
            return rSoftmax, attentionalHidState, attentionalHidStateVec
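
        # ------------> The attention math payAttention animates, stripped of all
        # ------------> mobjects (a minimal sketch assuming Luong's global dot score;
        # ------------> the helper is hypothetical and never called during rendering):
        def _luong_dot_attention_reference(enc_states, dec_state, W_c):
            # enc_states: (seq_len, hidden_dim), dec_state: (hidden_dim, 1), W_c: (hidden_dim, 2*hidden_dim)
            scores = enc_states @ dec_state                              # alignment scores, (seq_len, 1)
            attn = t.softmax(scores, dim=0)                              # attention scores
            context = (attn.reshape(1, -1) @ enc_states).reshape(-1, 1)  # context vector, (hidden_dim, 1)
            contextualized = t.cat((context, dec_state), dim=0)          # (2*hidden_dim, 1)
            return W_c @ contextualized                                  # attentional hidden state s~_t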
        #******************************************************************************************************************************************************************
        #                                                        MAIN LOGIC STARTS HERE!!!!!!!!                                                                           #
        #******************************************************************************************************************************************************************
        # ---------------> TITLE OF VIDEO
        title = Title('Vanilla RNN with Attention', match_underline_width_to_text=True, underline_buff=MED_LARGE_BUFF, font_size=300).shift(DOWN * 7)
        subText = MathTex(r"Luong's \: Global \: attention \: using \: Dot \: Scores (h^{T}_{t},\:h_{s}) \: and \: tying \: weights", font_size=150).next_to(title, DOWN * 2)
        # ---------------> VARIABLES BELOW ARE TENSORS
        embeddingDimensions = 3
        data = [['Attention', 'is', 'not', 'enough'], ['Towards', 'Data', 'Science'], ['It', 'was', 'a', 'really', 'good', 'article']]
        sentences_Embedded = construct_Word_Embeddings(data, embeddingDimensions)
        startEmbedding = construct_Word_Embeddings([['<start>']], embeddingDimensions)
        # ------------> IMPORTANT GROUPS
        WEIGHTS = VGroup()
        Cells = VGroup()
        InputArrows = VGroup()
        OutputArrows = VGroup()
        HiddenArrows = VGroup()
        ENCInputVectors = VGroup()
        HiddenVectors = VGroup()
        DECInputVectors = VGroup()
        DECOutputVectors = VGroup()
        denseLayer = ellipsisDenseLayer(n=5, r=0.2, sW=2)
        # ------------> IMPORTANT LISTS FOR INTERMEDIATE COMPUTATIONS
        encInputVectors = sentences_Embedded[0]
        hiddenVectors = []
        decInputVectors = startEmbedding[0]
        # ------------> VARIABLES DEFINITION
        seqLen, inputDimension = encInputVectors.shape[0], encInputVectors.shape[1]
        hiddenDimension = 4
        cellHeight = 6
        cellWidth = 6
        inputWords = ['Attention', 'is', 'not', 'enough']
        outputWords = ['<start>', 'dikkat', 'yegerli', 'yetil', '<end>']
        totalLen = len(inputWords) + len(outputWords)
        # --------------> REFERENCE LEDGER
        ledger = []
        ledger.append([getCell(h=cellHeight / 4, w=cellWidth / 4, c=GREEN_D), MathTex('Decoder \: Cell', font_size=96)])
        ledger.append([getCell(h=cellHeight / 4, w=cellWidth / 4, c=BLUE_C), MathTex('Encoder \: Cell', font_size=96)])
        ledger.append([MathTex("h_{i}", font_size=96), MathTex('Encoder \: Hidden \: States', font_size=84)])
        ledger.append([MathTex("s_{i}", font_size=96), MathTex('Decoder \: Hidden \: States', font_size=84)])
        ledger.append([MathTex("\\tilde{s}_{i}", font_size=96), MathTex('Attentional \: States \: of \: Decoder', font_size=84)])
        ledger.append([Line(start=LEFT, end=RIGHT * 2, stroke_width=15, color=RED), MathTex('Weights \: And \: Parameters', font_size=84)])
        ledger.append([Line(start=LEFT, end=RIGHT * 2, stroke_width=15, color=YELLOW), MathTex('Encoder-Decoder \: Hidden \: States', font_size=84)])
        ledger.append([Line(start=LEFT, end=RIGHT * 2, stroke_width=15, color=GREEN), MathTex('Encoder-Decoder \: Input \: Vectors', font_size=84)])
        ledger.append([Line(start=LEFT, end=RIGHT * 2, stroke_width=15, color=PURPLE), MathTex('Output \: Vectors', font_size=84)])
        ledger.append([Line(start=LEFT, end=RIGHT * 2, stroke_width=15, color=MAROON_A), MathTex('Attentional \: Vectors', font_size=84)])
        LedgerTable = MobjectTable(ledger).to_edge(RIGHT * 10 + DOWN * 7)
        referenceTable = Text("Shape and Color reference table", font_size=96, slant=OBLIQUE).next_to(LedgerTable, UP * 2)
        # ------------> iHS : INITIAL INPUT HIDDEN STATE
        iHSMatrix, iniHS = buildMatrices(Rs=hiddenDimension, Cs=1, roundDecimals=2, func=t.zeros, hgt=cellHeight - 1, clr=YELLOW, cap='Initial \: \n Hidden \: \n State')
        HiddenVectors.add(iHSMatrix)
        hiddenVectors.append(iniHS)
        # ------------> iWMatrix : INPUT WEIGHTS MATRIX, hWMatrix : HIDDEN WEIGHTS MATRIX, oWMatrix : OUTPUT WEIGHTS MATRIX, cWMatrix : CONTEXT WEIGHTS MATRIX
        uniform = t.distributions.uniform.Uniform(low=t.tensor(-1.0), high=t.tensor(1.0))
        iWMatrix, iniIWM = buildMatrices(Rs=hiddenDimension, Cs=inputDimension, roundDecimals=2, func=uniform.sample, clr=RED, cap=' Input \: Weight \: Matrix ', hgt=5)
        hWMatrix, iniHWM = buildMatrices(Rs=hiddenDimension, Cs=hiddenDimension, roundDecimals=2, func=uniform.sample, clr=RED, cap=' Hidden \: Weight \: Matrix ', hgt=5)
        oWMatrix, iniOWM = buildMatrices(Rs=inputDimension, Cs=hiddenDimension, roundDecimals=2, func=uniform.sample, clr=RED, cap=' Output \: Weight \: Matrix ', hgt=5)
        cWMatrix, iniCWM = buildMatrices(Rs=hiddenDimension, Cs=hiddenDimension * 2, roundDecimals=2, func=uniform.sample, clr=RED, cap=' Context \: Weight \: Matrix ', hgt=5)
        WEIGHTS.add(*[iWMatrix, hWMatrix, oWMatrix, cWMatrix]).arrange(RIGHT, buff=3).to_edge(DOWN * 14)
        # ------------> ENCODER-DECODER CELLS
        for i in range(0, totalLen - 1):
            colour = GREEN_C if i >= seqLen else BLUE_C
            Cells.add(getCell(h=cellHeight, w=cellWidth, c=colour))
        # ------------> ARROWS AND VECTORS
        for i in range(0, totalLen - 1):
            mbj = Cells[i]
            initial_mbj = iHSMatrix if i == 0 else Cells[i - 1]
            InputArrows.add(getArrow(arrowType='input', cell=mbj, sWidth=7, arrowLen=7))
            if i < len(encInputVectors):
                m, _ = buildMatrices(Rs=inputDimension, Cs=1, roundDecimals=2, whereTo=DOWN, referenceMobj=InputArrows[-1], arrayToMatrix=encInputVectors[i].reshape(-1, 1), clr=GREEN, hgt=7, wdt=2, cap="%s" % (inputWords[i]))
                ENCInputVectors.add(m)
            OutputArrows.add(getArrow(arrowType='output', cell=mbj, sWidth=7, arrowLen=10))
            HiddenArrows.add(getArrow(arrowType='hidden', cell=mbj, initialMobj=initial_mbj, sWidth=7, arrowLen=0))
        # ------------> ANIMATIONS AND ARRANGEMENTS
        HiddenVectors[0].to_edge(LEFT * 2)
        Cells.arrange(RIGHT, buff=10)
        self.play(DrawBorderThenFill(title), run_time=2)
        self.play(Write(subText), run_time=2)
        self.play(Write(LedgerTable), Write(referenceTable), run_time=1)
        self.play(Write(Cells), run_time=3)
        self.play(Write(HiddenVectors[0]), run_time=1)
        # self.play(Cells.animate.shift(DOWN*2), HiddenVectors[0].animate.shift(DOWN*2))
        self.play(Write(WEIGHTS), run_time=2)
        WEIGHTS.save_state()
        # Focusing camera on the ledger reference table
        self.camera.frame.save_state()
        self.play(self.camera.frame.animate.move_to(LedgerTable).set(width=LedgerTable.width * 2, height=LedgerTable.height * 4))
        self.wait()
        self.play(Restore(self.camera.frame))
        # Focusing camera on WEIGHTS
        self.camera.frame.save_state()
        self.play(self.camera.frame.animate.move_to(WEIGHTS).set(width=WEIGHTS.width * 5, height=WEIGHTS.height * 6))
        self.wait()
        self.play(Restore(self.camera.frame))
        for i in range(0, totalLen - 1):
            print("ITERATION %d" % i)
            _cell = Cells[i]
            _hA = HiddenArrows[i]            # _hA : _hiddenArrow
            _hV_For_Calc = hiddenVectors[i]  # _hV_For_Calc : hiddenVector_For_Calculation
            _hV = HiddenVectors[i]           # _hV : _hiddenVector
            _iA = InputArrows[i]             # _iA : _inputArrow
            _oA = OutputArrows[i]            # _oA : _outputArrow
            if i < len(inputWords):
                hidText = 'h_{%d}' % i
                _iV_For_Calc = encInputVectors[i]  # _iV_For_Calc : _inputVector_For_Calculation
                _iV = ENCInputVectors[i]           # _iV : _inputVector
            else:
                ind = i - len(inputWords)
                hidText = 's_{%d}' % ind
                _iV_For_Calc = decInputVectors[-1]  # _iV_For_Calc : _inputVector_For_Calculation
                _iV, _ = buildMatrices(Rs=inputDimension, Cs=1, roundDecimals=2, whereTo=DOWN, referenceMobj=InputArrows[i], arrayToMatrix=_iV_For_Calc.reshape(-1, 1), clr=GREEN, hgt=7, wdt=2, cap="%s" % (outputWords[ind]))  # _iV : _inputVector
                DECInputVectors.add(_iV)
            # Write the input arrow of the cell
            self.play(Write(_iA), run_time=0.5)
            # Show indication of the previous output going into the next input
            if i > len(inputWords) and i != len(Cells):
                _tempArr = Arrow(start=DECOutputVectors[-1].get_critical_point(DR), end=InputArrows[i].get_bottom(), stroke_width=7, max_tip_length_to_length_ratio=0.05)
                writeUnwrite(mob=_tempArr, runtime=0.5)
            # Write the current cell's input embedding vector
            self.play(Write(_iV), run_time=0.5)
            # Write the hidden arrow of the cell
            self.play(Write(_hA), run_time=0.5)
            # Bring the input weights closer to the current input embedding
            self.play(iWMatrix.animate.next_to(_iV, LEFT * 3), run_time=1)
            # Bring the hidden weights closer to the current hidden state
            self.play(hWMatrix.animate.next_to(_hA, UP * 2), run_time=1)
            # Focusing camera on CELLS --------- ZOOMING IN
            self.camera.frame.save_state()
            self.play(self.camera.frame.animate.move_to(_cell).set(width=_cell.width * 1.5, height=_cell.height * 7))
            # Calculate current input weights multiplied by the input vector and current hidden weights multiplied by the previous hidden vector, THEN indicate and transform
            _inV_inW, _inV_inW_Vec = calcAndBuild(op=t.matmul, m1=iniIWM, m2=_iV_For_Calc.reshape(-1, 1), rows=hiddenDimension, vecHeight=cellHeight - 1, vecWidth=1.5, refMob=_cell, position=RIGHT, shiftBy=LEFT / 1.75, toIndicate=[iWMatrix, _iV], colr=YELLOW)
            _hV_hW, _hV_hW_Vec = calcAndBuild(op=t.matmul, m1=iniHWM, m2=_hV_For_Calc.reshape(-1, 1), rows=hiddenDimension, vecHeight=cellHeight - 1, vecWidth=1.5, refMob=_cell, position=LEFT, shiftBy=RIGHT / 1.75, toIndicate=[hWMatrix, _hV], colr=YELLOW)
            # Create a plus sign to show the addition of the intermediate hidden states coming in from the input weights and hidden weights
            _plusSign = MathTex("+", font_size=96).move_to(_cell, RIGHT + LEFT)
            self.play(Write(_plusSign), run_time=0.5)
            # Calculate & build the final hidden state by adding the intermediate hidden states
            _hState, _hState_Vec = calcAndBuild(op=t.add, m1=_inV_inW, m2=_hV_hW, rows=hiddenDimension, vecHeight=cellHeight - 1, vecWidth=1.5, refMob=_cell, position=RIGHT + LEFT, toIndicate=[_inV_inW_Vec, _hV_hW_Vec, _plusSign], colr=YELLOW)
            _hState_Vec.move_to(_cell, RIGHT + LEFT)
            # Build the tanh activation mobject of the final hidden state vector
            _paren_hState_Vec = parenthesizeMatrix(matrix=_hState_Vec, encapsulateName='Tan\:h', referenceMobject=_cell, position=RIGHT + LEFT, bf=0.5)
            _tan_hState, _tan_hState_Vec = calcAndBuild(op=t.nn.Tanh(), m1=_hState, m2=None, rows=hiddenDimension, vecHeight=cellHeight - 1, vecWidth=1.5, refMob=_cell, position=RIGHT + LEFT, toIndicate=[_paren_hState_Vec], cp=hidText, colr=YELLOW)
            _tan_hState_Vec.add_updater(lambda x, y=_cell: x.move_to(y, RIGHT + LEFT))
            # Remove temporary objects that are no longer on the screen
            self.remove(_inV_inW_Vec, _plusSign, _hV_hW_Vec)
            # Append the necessary temporary variables of this iteration to global variables that can be used in the next iteration
            hiddenVectors.append(_tan_hState)
            HiddenVectors.add(_tan_hState_Vec)
            # Focusing camera on CELLS --------- ZOOMING OUT
            self.play(Restore(self.camera.frame))
            if i < len(inputWords):
                continue
            # Write the output arrow of the cell
            self.play(Write(_oA), run_time=0.5)
            # Pay attention to the past words properly
            self.camera.frame.save_state()
            _, _cV_CW_forCalc, _cV_CW_Vec = payAttention(eHVForCalc=hiddenVectors[1:len(inputWords) + 1], EHV=HiddenVectors[1:len(inputWords) + 1], curHidStateForCalc=_tan_hState, curHidState=_tan_hState_Vec, contextWeightsForCalc=iniCWM, contextWeightsVec=cWMatrix, referenceMobjectFinal=_oA, finalCap="\\tilde{s}_{%d}" % ind, colr=MAROON_A)
            # Bring the output weights closer to the output arrow
            self.play(oWMatrix.animate.next_to(_oA, LEFT * 2 + UP / 1.25), run_time=1)
            # Focusing camera on CELLS --------- ZOOMING IN on the output part / dense layers of the decoder
            self.camera.frame.save_state()
            self.play(self.camera.frame.animate.move_to(_oA).set(width=_cell.width * 1.5, height=_oA.height * 5))
            # Calculate the output, write it after the output arrow, shift it down to the next input arrow and add an updater to it
            _oV_oW, _oV_oW_Vec = calcAndBuild(op=t.matmul, m1=iniOWM, m2=_cV_CW_forCalc.reshape(-1, 1), rows=inputDimension, vecHeight=7, vecWidth=2, refMob=_oA, position=UP, toIndicate=[oWMatrix, _cV_CW_Vec], shiftBy=UP * 4, colr=PURPLE)
            _oV_oW_Vec.add_updater(lambda x, y=_oA: x.move_to(y, UP).shift(UP * 4))
            decInputVectors = t.cat((decInputVectors, _oV_oW.reshape(1, -1)))
            DECOutputVectors.add(_oV_oW_Vec)
            # Demonstration of the dense layer converting the output vector to softmax probabilities of vocabulary length
            denseLayer.next_to(_oV_oW_Vec, UP)
            writeUnwrite(mob=denseLayer, runtime=1)
            # Focusing camera on CELLS --------- ZOOMING OUT
            self.play(Restore(self.camera.frame))
            print("ITERATION %d COMPLETE" % i)
        self.play(WEIGHTS.animate.restore())
        self.play(Cells.animate.shift(UP * 2), HiddenVectors[0].animate.shift(UP * 2))
        self.wait()