Skip to content

Instantly share code, notes, and snippets.

@arif9799
Last active August 8, 2023 02:17
Show Gist options
  • Save arif9799/d0cabd2bcf080a9218a7becc34d14777 to your computer and use it in GitHub Desktop.
Save arif9799/d0cabd2bcf080a9218a7becc34d14777 to your computer and use it in GitHub Desktop.
Attention is not enough! Recurrent Neural Nets Worked-out Example
from manim import *
from manim.utils.unit import Percent, Pixels
from colour import Color
import gensim
from gensim.models import Word2Vec
import numpy as np
import torch as t
# Scene geometry in Manim units.
config.frame_width = 146
config.frame_height = 85.5

# Output resolution switch: 3840x2160 when HIGH is set, 1920x1080 otherwise.
HIGH = True
if HIGH:
    print("Its high")
    config.pixel_height, config.pixel_width = 2160, 3840
else:
    print("Its low")
    config.pixel_height, config.pixel_width = 1080, 1920
#************************************************************************************************************************************************************************************
# STRUCTURE OF THIS CODE FILE
#---- def construct_Word_Embeddings() : Function to create Word Embeddings or Vectors for animations using gensim
#---- def getCell() : function to build rectangle shapes that portrays the encoder decoder cells
#---- def getArrow() : function that builds arrows to and from Mobjects indicating flow
#---- def buildMatrices() : function that takes in appropriate input (tensors or functions to generate tensors) and returns a full packaged matrix with dimension and caption as a mobject
#---- def ellipsisDenseLayer() : static function that builds a Dense Layer Classifier components VGroup of neurons and weights that can be simply written on top of a vector to indicate classification
#---- class Figure7(MovingCameraScene) : the main block of the RNN animation code
#---- def construct(): Function that Scene class calls automatically to render
#---- def writeUnwrite(): Simulates the effect of show creation and then Fade Out
#---- def calcAndBuild(): Function that does matrix multiplication calculation of weights, indicates the mobjects of interest and transforms them or their copies into results
#---- def parenthesizeMatrix(): wraps a matrix mobject in parentheses preceded by a function name, to show it as the argument of tanh, softmax or any other function
#************************************************************************************************************************************************************************************
################################################################################
# Creating Word Vectors #
################################################################################
def construct_Word_Embeddings(sentences, dim):
    """Train a Word2Vec model on ``sentences`` and return per-sentence embeddings.

    sentences: list of tokenised sentences (each a list of str). It is used both
        as the Word2Vec training corpus and as the text to embed.
    dim: embedding size, passed to Word2Vec as ``vector_size``.

    Returns a list with one torch tensor per sentence, of shape
    (len(sentence), dim), holding each word's vector rounded to 2 decimals.
    The vectors are fetched with ``get_vector(word, True)``; the second
    argument is gensim's ``norm`` flag.
    """
    w2v = Word2Vec(sentences, min_count = 1, vector_size = dim, window = 3)

    def embed(sentence):
        # One (1, dim) row per token, stacked into a (len(sentence), dim) tensor.
        rows = [t.from_numpy(w2v.wv.get_vector(token, True).reshape(1, -1)) for token in sentence]
        return t.round(t.cat(rows, dim=0), decimals=2)

    return [embed(sentence) for sentence in sentences]
################################################################################
# FUNCTION TO CREATE ENCODER CELLS #
################################################################################
def getCell(w, h, c, fo = 0.1):
    """Return the rectangle used to draw one encoder/decoder cell.

    w: width of the rectangle
    h: height of the rectangle
    c: colour of the rectangle
    fo: fill opacity of the rectangle (default 0.1)
    """
    cell_style = {"height": h, "width": w, "color": c, "fill_opacity": fo}
    return Rectangle(**cell_style)
################################################################################
# FUNCTION TO CREATE ALL ARROW TYPES #
################################################################################
def getArrow(arrowType, cell, sWidth, arrowLen, initialMobj = None, bf = 0, rt = 0.1):
    """Build an arrow that keeps itself attached to one or two mobjects.

    arrowType: 'input'  - arrow of length arrowLen pointing up into the bottom of cell
               'output' - arrow of length arrowLen pointing up out of the top of cell
               'hidden' - arrow from the right edge of initialMobj to the left edge of cell
    cell: the main mobject (for 'hidden', the mobject on the right)
    sWidth: stroke width (thickness) of the arrow
    arrowLen: length of 'input' / 'output' arrows (unused for 'hidden')
    initialMobj: mobject on the left where a 'hidden' arrow starts; required for 'hidden'
    bf: buff passed to Arrow
    rt: max_tip_length_to_length_ratio passed to Arrow

    Returns an Arrow carrying an updater that rebuilds it from the current
    positions of the referenced mobjects on every update, or None for an
    unknown arrowType.

    Raises ValueError if arrowType is 'hidden' and initialMobj is None.
    """
    style = {"stroke_width": sWidth, "buff": bf, "max_tip_length_to_length_ratio": rt}

    if arrowType == 'input':
        def endpoints(c=cell):
            return c.get_bottom() + DOWN * arrowLen, c.get_bottom()
    elif arrowType == 'output':
        def endpoints(c=cell):
            return c.get_top(), c.get_top() + UP * arrowLen
    elif arrowType == 'hidden':
        if initialMobj is None:
            # Fail here instead of later inside the updater, mid-render, with an
            # AttributeError on None.
            raise ValueError("getArrow: a 'hidden' arrow requires initialMobj")
        def endpoints(i=initialMobj, f=cell):
            return i.get_right(), f.get_left()
    else:
        # Unknown type: nothing to draw, so don't build a throwaway Arrow.
        return None

    def rebuild(mob):
        # Recompute the endpoints from the live mobject positions so the arrow
        # follows its cells when they move.
        start, end = endpoints()
        mob.become(Arrow(start=start, end=end, **style))

    # Starts with Arrow's default geometry; the updater positions it.
    arr = Arrow(**style)
    arr.add_updater(rebuild)
    return arr
################################################################################
# FUNCTION TO BUILD MATRICES WITH DIMENSIONS AND CAPTIONS #
################################################################################
def buildMatrices(Rs, Cs, roundDecimals, whereTo = None, referenceMobj = None, arrayToMatrix = None, func = None, clr = WHITE, cap = None, hgt = 6, wdt = 1.5, seed = 786):
    """Build a matrix mobject with a "(Rs, Cs)" dimension label and an optional caption.

    Rs: number of rows (also shown in the dimension label)
    Cs: number of columns (shown in the dimension label and used to scale the width)
    roundDecimals: number of decimal places the values are rounded to
    whereTo: direction (DOWN, UP, LEFT, RIGHT, ...) in which the matrix is placed
        next to referenceMobj; required whenever referenceMobj is given
    referenceMobj: optional mobject the matrix is placed next to; an updater keeps
        the matrix there when referenceMobj moves
    arrayToMatrix: tensor whose values are displayed; takes precedence over func
    func: used only when arrayToMatrix is None; called as func((Rs, Cs)) to generate
        the values (e.g. t.zeros or a distribution's sample method)
    clr: colour of the matrix
    cap: optional caption (LaTeX) shown below the matrix
    hgt: height the matrix is first scaled to
    wdt: width per column; the matrix is finally scaled to wdt * Cs wide
    seed: torch seed applied right before func is called, so generated values are
        reproducible

    NOTE(review): Manim's ``height``/``width`` setters appear to rescale uniformly, so
    the final size is governed by ``wdt * Cs`` and ``hgt`` is effectively overridden.
    Confirm against the Manim version in use.

    Returns:
        (VGroup, tensor): the group [matrix, dimension label, (caption)] and the
        rounded tensor that is displayed. If neither arrayToMatrix nor func is
        given, a message is printed and (None, None) is returned.
    """
    if arrayToMatrix is None and func is None:
        print("No Input to process and build a Matrix upon")
        # Keep the return shape consistent: every caller unpacks two values.
        return None, None

    # ------------> Pick the values to display (generated ones are seeded)
    if arrayToMatrix is None:
        # Seed only when values are generated. Seeding for an explicit tensor
        # achieves nothing and silently resets the global torch RNG.
        t.manual_seed(seed)
        values = func((Rs, Cs))
    else:
        values = arrayToMatrix
    ini = t.round(values, decimals=roundDecimals)

    iniMat = Matrix(ini.numpy())
    iniMat.height = hgt
    iniMat.width = wdt * Cs
    iniMat.set_color(clr)

    # ------------> Locate the matrix relative to the reference mobject
    if referenceMobj is not None:
        iniMat.next_to(referenceMobj, whereTo)
        iniMat.add_updater(lambda x, y = referenceMobj: x.next_to(y, whereTo))

    # ------------> MASTER MATRIX: matrix + dimension label (+ caption), labels follow the matrix
    iniMatrix = VGroup(iniMat)
    dimText = MathTex('(%d, %d)'%(Rs, Cs), font_size = 40).move_to(iniMat.get_critical_point(DR)+DOWN/2)
    dimText.add_updater(lambda x, y = iniMat: x.move_to(y.get_critical_point(DR)+DOWN/2))
    iniMatrix.add(dimText)
    if cap is not None:
        capText = MathTex('%s'%(cap), font_size = 64).move_to(iniMat.get_critical_point(DOWN) + DOWN*1.25 )
        capText.add_updater(lambda x, y = iniMat: x.move_to(y.get_critical_point(DOWN) + DOWN*1.25 ))
        iniMatrix.add(capText)
    return iniMatrix, ini
################################################################################
# FUNCTION TO BUILD DENSE LAYER DEMO FOR OUTPUT #
################################################################################
def ellipsisDenseLayer(n, r, sW, c=GOLD):
    """Build a schematic dense (classifier) layer drawn as neurons with an ellipsis.

    n: number of neuron circles (n-1 before the three-dot ellipsis, 1 after it)
    r: radius of each neuron circle
    sW: stroke width of the connecting arrows
    c: colour of the neurons (default GOLD)

    Returns a VGroup whose members are, in order:
    - the n-1 leading circles, three dots and the trailing circle;
    - a "Classifier" label;
    - a "Softmax Probabilities Vector" label;
    - a VGroup of arrows: first one arrow from a point below into every circle,
      then one arrow from every circle up to the softmax label.
    """
    ratio = 0.05
    layer = VGroup(*[Circle(radius=r, color=c) for _ in range(n - 1)])
    layer.add(*[Dot(color=c) for _ in range(3)])
    layer.add(Circle(radius=r, color=c))
    layer.arrange(RIGHT, buff=0.15)

    probs_label = Text("Softmax Probabilities Vector", slant=OBLIQUE, font_size=48, color=DARK_BROWN).next_to(layer, UP*10)

    # The circles (dots excluded): the first n-1 members plus the last one.
    neurons = [layer[k] for k in range(n - 1)] + [layer[-1]]
    source = layer.get_bottom() + DOWN*4
    sink = probs_label.get_bottom()
    arrow_opts = {"stroke_width": sW, "max_tip_length_to_length_ratio": ratio}

    arrows = VGroup()
    arrows.add(*[Arrow(start=source, end=neuron.get_bottom(), **arrow_opts) for neuron in neurons])
    arrows.add(*[Arrow(start=neuron.get_top(), end=sink, **arrow_opts) for neuron in neurons])

    layer.add(Text("Classifier", slant=OBLIQUE, font_size=48, color=DARK_BLUE).next_to(layer[-1], RIGHT*2))
    layer.add(probs_label)
    layer.add(arrows)
    return layer
##############################################################################################################################################################
# |-----| |\ | |\ | #
# | | | \ | | \ | #
# |-----| | \ | | \ | #
# |\ | \ | | \ | #
# | \ | \| | \| #
##############################################################################################################################################################
class Figure7(MovingCameraScene):
    """Animated worked example of a vanilla encoder-decoder RNN without attention.

    One input, one hidden and one output weight matrix are shared ("tied") by
    every cell. Each cell computes tanh(W_in @ x + W_h @ h_prev). Decoder cells
    additionally compute an output W_out @ s that becomes the input of the
    next decoder cell.
    """

    def construct(self):
        """Build and play the whole animation; called by Manim when rendering."""
        ################################################################################
        #                 INTRO ANIMATION SAME ACROSS ALL VIDEOS                       #
        ################################################################################
        def myAnimation(wordsForIntro):
            """Play the intro around a large blackboard-bold C.

            wordsForIntro: sequence of str. Each word is spelled out to the right of
                the C one character at a time, held, then erased in reverse order.
            """
            fontHeight = config.frame_height//8
            fontColor = WHITE
            timePerChar = 0.1
            C = MathTex(r"\mathbb{C}", color = fontColor).scale(config.frame_height//3)
            self.play(Broadcast(C), run_time=1)
            self.add(C)
            # Build one VGroup of characters per word, laid out left to right just
            # after C and bottom-aligned with it.
            wordsMath = VGroup()
            for word in wordsForIntro:
                charTex = VGroup()
                for i, ch in enumerate(word):
                    # Raw string: "\m" is not a valid Python escape sequence.
                    chTex = MathTex(r"\mathbb{%s}" % ch, color = fontColor).scale(fontHeight)
                    anchor = charTex[-1] if i != 0 else C
                    chTex.next_to(anchor, RIGHT, buff=0.05).align_to(C, DOWN)
                    charTex.add(chTex)
                wordsMath.add(charTex)
            # Succession / AnimationGroup did not give the desired effect, so each
            # character gets its own play() call (slow, but predictable).
            for charTex in list(wordsMath.submobjects):
                for chTex in list(charTex.submobjects):
                    self.play(Write(chTex, run_time = timePerChar))
                self.wait(0.5)
                for chTex in reversed(list(charTex.submobjects)):
                    self.play(Unwrite(chTex, run_time = timePerChar))
            self.play(Circumscribe(C, color=MAROON_E, fade_out=False, time_width=2, shape=Circle, buff=1, stroke_width=config.frame_height//3, run_time=1.5))
            self.play(ShrinkToCenter(C, run_time=0.25))
            self.wait(0.5)

        myAnimation(wordsForIntro= ['OMPLEX', 'ONCEPTS', 'OMPREHENSIBLE'])

        ################################################################################
        #               ALTERNATIVE FUNCTION TO CREATE/UNCREATE OBJECT                 #
        ################################################################################
        def writeUnwrite(mob, runtime):
            """Write ``mob``, hold for 0.25 s, unwrite it, then restore its saved state.

            mob: Mobject to flash on screen
            runtime: run time of each of the Write and Unwrite animations
            """
            mob.save_state()
            self.play(Write(mob), run_time=runtime)
            self.wait(0.25)
            self.play(Unwrite(mob), run_time=runtime)
            # Restore so the same mobject can be written again later; denseLayer is
            # flashed once per decoder step.
            mob.restore()

        ################################################################################
        #                   VECTOR CALCULATION AND MOBJECT BUILD                       #
        ################################################################################
        def calcAndBuild(op, m1, m2, rows, vecHeight, vecWidth, refMob, position, toIndicate, colr, shiftBy=(0, 0, 0), cp= None):
            """Compute a tensor op, render the result as a column vector and animate it.

            op: t.matmul or t.add (applied to m1 and m2), or any unary callable such
                as t.nn.Tanh() when m2 is None
            m1, m2: tensor operands (m2 must be None for a unary op)
            rows: number of rows of the result vector
            vecHeight, vecWidth: size passed to buildMatrices for the result
            refMob, position: the result is moved to refMob, aligned on position
            toIndicate: mobjects that are flashed and then transformed into the result
            colr: colour of the result vector
            shiftBy: extra offset applied after positioning
            cp: optional caption of the result vector

            Returns (v, V): the result tensor and its matrix VGroup.

            Raises ValueError when op is neither t.matmul, t.add nor unary (m2 None).
            """
            if op is t.matmul:
                v = op(m1, m2)
                # Transform copies so the weights and the input stay on screen.
                initial = VGroup(*toIndicate).copy()
            elif op is t.add:
                v = op(m1, m2)
                # The summands and the '+' sign are transformed themselves.
                initial = VGroup(*toIndicate)
            elif m2 is None:
                v = op(m1)
                initial = VGroup(*toIndicate)
            else:
                # Previously this fell through to an UnboundLocalError on `v`.
                raise ValueError("calcAndBuild: unsupported binary op %r" % (op,))
            V, _ = buildMatrices(Rs= rows, Cs= 1, roundDecimals=2, whereTo= None, referenceMobj= None, arrayToMatrix= v.reshape(-1,1), clr = colr, hgt= vecHeight, wdt= vecWidth, cap=cp)
            V.move_to(refMob, position).shift(shiftBy)
            ag = AnimationGroup(*[Indicate(o) for o in toIndicate], lag_ratio=0)
            self.play(ag, run_time=0.75)
            self.play(FadeTransformPieces(initial, V), run_time=1.25)
            return v, V

        ################################################################################
        #                   FUNCTION TO PARENTHESISE MATRICES                          #
        ################################################################################
        def parenthesizeMatrix(matrix, encapsulateName: str, referenceMobject: Mobject = None, position = None, bf=0.5, initial_scale_factor: float = 1.5):
            """Show a matrix VGroup as the argument of a named function, e.g. Tan h( ... ).

            matrix: VGroup built by buildMatrices (matrix, dimension label, optional caption)
            encapsulateName: LaTeX name of the function written left of the "("
            referenceMobject, position: if given, the result is then moved to
                referenceMobject, aligned on position
            bf: buffer between the matrix, the parentheses and the function name
            initial_scale_factor: scale applied to the parentheses and the name

            Returns the new VGroup (matrix + parentheses + name).
            """
            parenMat = VGroup()
            parens = MathTex("(", ")")
            parens.scale(initial_scale_factor)
            parens.stretch_to_fit_height(matrix.height)
            mt= matrix[0]
            l_paren, r_paren = parens.split()
            l_paren.next_to(mt, LEFT, buff=bf)
            r_paren.next_to(mt, RIGHT, buff=bf)
            tx = MathTex(encapsulateName)
            tx.scale(initial_scale_factor)
            tx.next_to(l_paren, LEFT, buff=bf)
            parenMat+=matrix
            parenMat.add(*[l_paren, r_paren, tx])
            self.play(FadeTransformPieces(matrix, parenMat), run_time=0.75)
            if referenceMobject is not None:
                self.play(parenMat.animate.move_to(referenceMobject, position))
            return parenMat

        #******************************************************************************************************************************************************************
        #                                                            MAIN LOGIC STARTS HERE!!!!!!!!                                                                       #
        #******************************************************************************************************************************************************************
        # ---------------> TITLE OF VIDEO
        title = Title('Vanilla RNN without Attention', match_underline_width_to_text=True, underline_buff= MED_LARGE_BUFF, font_size = 300).shift(DOWN*7)
        # Raw strings keep the LaTeX "\:" spacing command without relying on Python's
        # deprecated handling of unknown escape sequences.
        subText = MathTex(r'Using \: concept \: of \: tying \: weights', font_size = 150).next_to(title, DOWN*2)

        # ---------------> VARIABLES BELOW ARE TENSORS
        embeddingDimensions = 3
        data = [['Attention', 'is', 'not', 'enough'], ['Towards', 'Data', 'Science'], ['It', 'was', 'a', 'really', 'good', 'article']]
        sentences_Embedded = construct_Word_Embeddings(data, embeddingDimensions)
        # NOTE(review): '<start>' is embedded by a separate Word2Vec model trained on
        # that single token, so its vector is not in the same space as the ones above.
        startEmbedding = construct_Word_Embeddings([['<start>']], embeddingDimensions)

        # ------------> IMPORTANT GROUPS (mobjects)
        WEIGHTS = VGroup()
        Cells = VGroup()
        InputArrows = VGroup()
        OutputArrows = VGroup()
        HiddenArrows = VGroup()
        ENCInputVectors = VGroup()
        HiddenVectors = VGroup()
        DECInputVectors = VGroup()
        DECOutputVectors = VGroup()
        denseLayer = ellipsisDenseLayer(n=5, r=0.2, sW=2)

        # ------------> IMPORTANT LISTS FOR INTERMEDIATE COMPUTATIONS (tensors)
        encInputVectors = sentences_Embedded[0]  # embeddings of the first sentence
        hiddenVectors = []
        decInputVectors = startEmbedding[0]  # one more row is appended per decoder output

        # ------------> VARIABLES DEFINITION
        seqLen, inputDimension = encInputVectors.shape[0], encInputVectors.shape[1]
        hiddenDimension = 4
        cellHeight = 6
        cellWidth = 6
        inputWords = ['Attention', 'is', 'not', 'enough']
        outputWords = ['<start>', 'dikkat', 'yegerli', 'yetil', '<end>']
        # totalLen - 1 cells are drawn: one per input word plus one per output word
        # except '<end>', which is never fed back as an input.
        totalLen = len(inputWords) + len(outputWords)

        # --------------> REFERENCE LEDGER
        ledger = []
        ledger.append([getCell(h= cellHeight/4, w= cellWidth/4, c= GREEN_D), MathTex(r'Decoder \: Cell', font_size = 96)])
        ledger.append([getCell(h= cellHeight/4, w= cellWidth/4, c= BLUE_C), MathTex(r'Encoder \: Cell', font_size = 96)])
        ledger.append([MathTex("h_{i}", font_size=96), MathTex(r'Encoder \: Hidden \: States', font_size = 84)])
        ledger.append([MathTex("s_{i}", font_size=96), MathTex(r'Decoder \: Hidden \: States', font_size = 84)])
        ledger.append([Line(start=LEFT, end=RIGHT*2, stroke_width = 15, color = RED), MathTex(r'Weights \: And \: Parameters', font_size = 84)])
        ledger.append([Line(start=LEFT, end=RIGHT*2, stroke_width = 15, color = YELLOW), MathTex(r'Encoder-Decoder \: Hidden \: States', font_size = 84)])
        ledger.append([Line(start=LEFT, end=RIGHT*2, stroke_width = 15, color = GREEN), MathTex(r'Encoder-Decoder \: Input \: Vectors', font_size = 84)])
        ledger.append([Line(start=LEFT, end=RIGHT*2, stroke_width = 15, color = PURPLE), MathTex(r'Output \: Vectors', font_size = 84)])
        LedgerTable = MobjectTable(ledger).to_edge(RIGHT*10 + DOWN*7)
        referenceTable = Text("Shape and Color reference table", font_size=96, slant=OBLIQUE).next_to(LedgerTable, UP*2)

        # ------------> iHS : INITIAL INPUT HIDDEN STATE
        # "\\:" is TeX's "\:" spacing. NOTE(review): "\n" is a real newline, which TeX
        # treats as plain whitespace; use "\\\\" if a line break was intended.
        iHSMatrix, iniHS = buildMatrices(Rs= hiddenDimension, Cs= 1, roundDecimals= 2, func= t.zeros, hgt= cellHeight-1, clr= YELLOW, cap= 'Initial \\: \n Hidden \\: \n State')
        HiddenVectors.add(iHSMatrix)
        hiddenVectors.append(iniHS)

        # ------------> iWMatrix : INPUT WEIGHTS MATRIX, hWMatrix : HIDDEN WEIGHTS MATRIX, oWMatrix : OUTPUT WEIGHTS MATRIX
        # These three matrices are reused by every cell (tied weights).
        uniform = t.distributions.uniform.Uniform(low=t.tensor(-1.0),high=t.tensor(1.0))
        iWMatrix, iniIWM = buildMatrices(Rs = hiddenDimension, Cs = inputDimension, roundDecimals=2, func=uniform.sample, clr= RED, cap=r' Input \: Weight \: Matrix ', hgt= 5)
        hWMatrix, iniHWM = buildMatrices(Rs = hiddenDimension, Cs = hiddenDimension, roundDecimals=2, func=uniform.sample, clr= RED, cap=r' Hidden \: Weight \: Matrix ', hgt= 5)
        oWMatrix, iniOWM = buildMatrices(Rs = inputDimension, Cs = hiddenDimension, roundDecimals=2, func=uniform.sample, clr= RED, cap=r' Output \: Weight \: Matrix ', hgt= 5)
        WEIGHTS.add(*[iWMatrix, hWMatrix, oWMatrix]).arrange(RIGHT, buff=3).to_edge(DOWN*14)

        # ------------> ENCODER-DECODER CELLS (encoder cells blue, decoder cells green)
        for i in range(0, totalLen-1):
            colour = GREEN_C if i>=seqLen else BLUE_C
            Cells.add(getCell(h= cellHeight, w= cellWidth, c= colour))

        # ------------> ARROWS AND VECTORS
        for i in range(0, totalLen-1):
            # NOTE(review): Cells[i][0] presumably indexes the rectangle itself.
            mbj = Cells[i][0]
            # The first hidden arrow starts at the initial hidden state, later ones
            # at the previous cell.
            initial_mbj = iHSMatrix if i==0 else Cells[i-1][0]
            InputArrows.add( getArrow(arrowType= 'input', cell= mbj, sWidth= 7, arrowLen= 7) )
            if i < len(encInputVectors):
                m,_ = buildMatrices(Rs= inputDimension, Cs= 1, roundDecimals=2, whereTo= DOWN, referenceMobj= InputArrows[-1], arrayToMatrix= encInputVectors[i].reshape(-1,1), clr = GREEN, hgt= 7, wdt= 2, cap = "%s"%(inputWords[i]))
                ENCInputVectors.add(m)
            OutputArrows.add( getArrow(arrowType= 'output', cell= mbj, sWidth= 7, arrowLen= 10) )
            HiddenArrows.add( getArrow(arrowType= 'hidden', cell= mbj, initialMobj= initial_mbj, sWidth= 7, arrowLen= 0) )

        # ------------> ANIMATIONS AND ARRANGEMENTS
        HiddenVectors[0].to_edge(LEFT*2)
        Cells.arrange(RIGHT, buff=10)
        self.play(DrawBorderThenFill(title), run_time = 2)
        self.play(Write(subText), run_time = 2)
        self.play(Write(LedgerTable),Write(referenceTable), run_time=1)
        self.play(Write(Cells), run_time=3)
        self.play(Write(HiddenVectors[0]), run_time=1)
        self.play(Cells.animate.shift(DOWN*7), HiddenVectors[0].animate.shift(DOWN*7))
        self.play(Write(WEIGHTS), run_time=2)
        WEIGHTS.save_state()

        # Focusing Camera on Ledger reference table
        self.camera.frame.save_state()
        self.play(self.camera.frame.animate.move_to(LedgerTable).set(width=LedgerTable.width * 2, height=LedgerTable.height * 4))
        self.wait()
        self.play(Restore(self.camera.frame))

        # Focusing Camera on WEIGHTS
        self.camera.frame.save_state()
        self.play(self.camera.frame.animate.move_to(WEIGHTS).set(width=WEIGHTS.width * 3, height=WEIGHTS.height * 4))
        self.wait()
        self.play(Restore(self.camera.frame))

        for i in range(0, totalLen-1):
            print("ITERATION %d"%i)
            _cell = Cells[i]
            _hA = HiddenArrows[i]                      # _hA : _hiddenArrow
            _hV_For_Calc = hiddenVectors[i]            # previous hidden state (tensor)
            _hV = HiddenVectors[i]                     # previous hidden state (mobject)
            _iA = InputArrows[i]                       # _iA : _inputArrow
            _oA = OutputArrows[i]                      # _oA : _outputArrow

            if i < len(inputWords):
                # Encoder cell: input is the precomputed word embedding.
                hidText = 'h_{%d}'%i
                _iV_For_Calc = encInputVectors[i]      # _iV_For_Calc : _inputVector_For_Calculation
                _iV = ENCInputVectors[i]               # _iV : _inputVector
            else:
                # Decoder cell: input is '<start>' or the previous decoder output.
                ind = i - len(inputWords)
                hidText = 's_{%d}'%ind
                _iV_For_Calc = decInputVectors[-1]     # _iV_For_Calc : _inputVector_For_Calculation
                _iV, _ = buildMatrices(Rs= inputDimension, Cs= 1, roundDecimals=2, whereTo= DOWN, referenceMobj= InputArrows[i], arrayToMatrix= _iV_For_Calc.reshape(-1,1), clr = GREEN, hgt= 7, wdt= 2, cap = "%s"%(outputWords[ind]))  # _iV : _inputVector
                DECInputVectors.add(_iV)

            # Write Input Arrow of the cell
            self.play(Write(_iA), run_time = 0.5)
            # From the second decoder cell on, flash an arrow from the previous output
            # vector into this cell's input. (The former extra `i != len(Cells)` test
            # was always true inside this loop.)
            if i > len(inputWords):
                _tempArr = Arrow(start = DECOutputVectors[-1].get_critical_point(DR), end = InputArrows[i].get_bottom(), stroke_width=7, max_tip_length_to_length_ratio=0.05)
                writeUnwrite(mob = _tempArr, runtime=0.5)
            # Write current Cell Input Embedding Vector
            self.play(Write(_iV), run_time = 0.5)
            # Write Hidden Arrow of the cell
            self.play(Write(_hA), run_time = 0.5)
            # Bring the input weights closer to current input embedding
            self.play(iWMatrix.animate.next_to(_iV,LEFT*3), run_time=1)
            # Bring the hidden weights closer to current hidden state
            self.play(hWMatrix.animate.next_to(_hA,UP*2), run_time=1)

            # Focusing Camera on CELLS --------- ZOOMING IN
            self.camera.frame.save_state()
            self.play(self.camera.frame.animate.move_to(_cell).set(width=_cell.width * 1.5, height=_cell.height * 7))
            # W_in @ x and W_h @ h_prev, each flashed and transformed into a vector inside the cell
            _inV_inW, _inV_inW_Vec = calcAndBuild(op=t.matmul, m1= iniIWM, m2= _iV_For_Calc.reshape(-1,1), rows= hiddenDimension, vecHeight= cellHeight-1, vecWidth=1.5, refMob=_cell, position= RIGHT, shiftBy=LEFT/1.75, toIndicate = [iWMatrix, _iV], colr=YELLOW)
            _hV_hW, _hV_hW_Vec = calcAndBuild(op=t.matmul, m1= iniHWM, m2= _hV_For_Calc.reshape(-1,1), rows= hiddenDimension, vecHeight= cellHeight-1, vecWidth=1.5, refMob=_cell, position= LEFT, shiftBy=RIGHT/1.75, toIndicate = [hWMatrix, _hV], colr=YELLOW)
            # Plus sign between the two intermediate vectors (RIGHT + LEFT is the zero
            # vector, i.e. centred on the cell)
            _plusSign = MathTex("+",font_size = 96).move_to(_cell, RIGHT + LEFT)
            self.play(Write(_plusSign), run_time=0.5)
            # Sum of the two intermediate vectors
            _hState, _hState_Vec = calcAndBuild(op=t.add, m1= _inV_inW, m2= _hV_hW, rows= hiddenDimension, vecHeight= cellHeight-1, vecWidth=1.5, refMob=_cell, position= RIGHT+LEFT, toIndicate = [_inV_inW_Vec, _hV_hW_Vec, _plusSign], colr=YELLOW)
            _hState_Vec.move_to(_cell, RIGHT+LEFT)
            # Wrap the sum in tanh( ... ) and compute the new hidden state
            _paren_hState_Vec = parenthesizeMatrix(matrix=_hState_Vec, encapsulateName=r'Tan\:h', referenceMobject= _cell, position=RIGHT+LEFT, bf=0.5)
            _tan_hState, _tan_hState_Vec = calcAndBuild(op= t.nn.Tanh(), m1=_hState, m2= None, rows= hiddenDimension, vecHeight=cellHeight-1, vecWidth=1.5, refMob=_cell, position=RIGHT+LEFT, toIndicate= [_paren_hState_Vec], cp = hidText, colr=YELLOW)
            _tan_hState_Vec.add_updater(lambda x,y=_cell: x.move_to(y, RIGHT+LEFT))
            # Remove temporary objects that are no longer needed on screen
            self.remove(_inV_inW_Vec, _plusSign, _hV_hW_Vec)
            # The new hidden state is the previous hidden state of the next iteration
            hiddenVectors.append(_tan_hState)
            HiddenVectors.add(_tan_hState_Vec)
            # Focusing Camera on CELLS --------- ZOOMING OUT
            self.play(Restore(self.camera.frame))

            # Encoder cells stop here; only decoder cells produce an output.
            if i < len(inputWords):
                continue

            # Write Output Arrow of the cell
            self.play(Write(_oA), run_time = 0.5)
            # Bring Output weights closer to output arrow.
            # NOTE(review): this was `LEFT*2 + UP//1.25`. Floor division turned the UP
            # term into the zero vector, so the direction was exactly LEFT*2. It is
            # kept as is to preserve the layout; use UP/1.25 if a diagonal offset was
            # intended.
            self.play(oWMatrix.animate.next_to(_oA, LEFT*2), run_time=1)
            # Focusing Camera on CELLS --------- ZOOMING IN on OUTPUT PART OF DENSE LAYERS OF DECODER
            self.camera.frame.save_state()
            self.play(self.camera.frame.animate.move_to(_oA).set(width=_cell.width * 1.5, height=_oA.height * 5))
            # Output W_out @ s, placed above the output arrow and kept attached to it
            _oV_oW, _oV_oW_Vec = calcAndBuild(op= t.matmul, m1= iniOWM, m2= _tan_hState.reshape(-1,1), rows= inputDimension, vecHeight= 7, vecWidth=2, refMob=_oA, position= UP, toIndicate = [oWMatrix, _tan_hState_Vec], shiftBy=UP*4, colr=PURPLE)
            _oV_oW_Vec.add_updater(lambda x, y = _oA: x.move_to(y, UP).shift(UP*4))
            # This output is the input of the next decoder cell
            decInputVectors = t.cat((decInputVectors, _oV_oW.reshape(1,-1)))
            DECOutputVectors.add(_oV_oW_Vec)
            # Demonstration of a dense layer turning the output vector into softmax probabilities
            denseLayer.next_to(_oV_oW_Vec, UP)
            writeUnwrite(mob= denseLayer, runtime=1)
            # Focusing Camera on CELLS --------- ZOOMING OUT
            self.play(Restore(self.camera.frame))
            print("ITERATION %d COMPLETE"%i)

        self.play(WEIGHTS.animate.restore())
        self.play(Cells.animate.shift(UP*2), HiddenVectors[0].animate.shift(UP*2))
        self.wait()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment