Last active: August 8, 2023 02:10
-
-
Save arif9799/5031fbde800b1d4f48f62d722a65b0ea to your computer and use it in GitHub Desktop.
Attention is not enough! RNNs to Transformers Transition Demo Animation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from manim import *
from manim.utils.unit import Percent, Pixels
from colour import Color
import gensim
from gensim.models import Word2Vec
import numpy as np
config.frame_width = 48 | |
config.frame_height = 27 | |
HIGH = True | |
if HIGH: | |
print("Its high") | |
config.pixel_height = 2160 | |
config.pixel_width = 3840 | |
else: | |
print("Its low") | |
config.pixel_height = 1080 | |
config.pixel_width = 1920 | |
def construct_Word_Embeddings(sentences, dim): | |
data = sentences | |
dimension = dim | |
model = Word2Vec(data, | |
min_count = 1, | |
vector_size = dimension, | |
window = 3) | |
results = [] | |
for seq in data: | |
sequence = np.empty((0,dimension)) | |
for word in seq: | |
word_emb = model.wv.get_vector(word, True).reshape(1,-1) | |
sequence = np.append(sequence, word_emb, axis = 0) | |
results.append(sequence) | |
return results | |
data = [['Attention', 'is', 'not', 'enough'], ['Towards', 'Data', 'Science'], ['It', 'was', 'a', 'really', 'good', 'article']] | |
sentences_Embedded = construct_Word_Embeddings(data, 6) | |
sentence_embed = np.round(sentences_Embedded[0],2) | |
class Figure4(Scene):
    """Morph a per-word RNN encoder chain into one wide Transformer encoder.

    Draws four "RNN Encoder" cells (one per word of "Attention is not
    enough"), each fed a 6-d word-embedding column vector and chained through
    hidden states h_0..h_4, then collapses the chain into a single
    Transformer encoder consuming the whole embedding matrix at once.

    Depends on the module-level `sentence_embed` matrix produced by
    `construct_Word_Embeddings`.

    NOTE(review): the source this was recovered from had its indentation
    stripped; block boundaries below were reconstructed from the code's
    evident semantics (e.g. exactly one h_0 arrow, one trailing arrow after
    the last cell) — confirm against a rendered run.
    """

    def construct(self):
        ################################################################################
        #                 INTRO ANIMATION SAME ACROSS ALL VIDEOS                       #
        ################################################################################
        def myAnimation(wordsForIntro: str):
            # Channel intro: a big blackboard-bold "C" is broadcast in, then
            # each word in `wordsForIntro` (e.g. 'OMPLEX' -> reads "COMPLEX")
            # is typed out character by character next to the C and erased.
            # NOTE(review): despite the `str` annotation this is called with a
            # list of strings (see the call below).
            fontHeight = config.frame_height//8
            fontColor = WHITE
            timePerChar = 0.1   # seconds per character for Write/Unwrite
            C = MathTex(r"\mathbb{C}", color = fontColor).scale(config.frame_height//3)
            self.play(Broadcast(C), run_time=1)
            self.add(C)
            # Building text objects of individual characters.
            wordsMath = VGroup()
            for word in wordsForIntro:
                charTex = VGroup()
                for i,ch in enumerate(word):
                    chTex = MathTex("\mathbb{%s}"%ch, color = fontColor).scale(fontHeight)
                    if i != 0:
                        # Chain each character to the right of the previous one.
                        chTex.next_to(charTex[-1], RIGHT, buff=0.05).align_to(C, DOWN)
                    else:
                        # First character of each word sits right of the big C.
                        chTex.next_to(C, RIGHT, buff=0.05).align_to(C, DOWN)
                    charTex.add(chTex)
                wordsMath.add(charTex)
            # Succession or AnimationGroup --- both are messed up ---- hence this
            # inefficient one-play-per-character loop.
            for wInd in range(len(wordsMath)):
                for chInd in range(len(wordsMath[wInd])):
                    self.play(Write(wordsMath[wInd][chInd], run_time = timePerChar))
                self.wait(0.5)
                for chInd in reversed(range(len(wordsMath[wInd]))):
                    self.play(Unwrite(wordsMath[wInd][chInd], run_time = timePerChar))
            self.play(Circumscribe(C, color=MAROON_E, fade_out=False, time_width=2, shape=Circle, buff=1, stroke_width=config.frame_height//3, run_time=1.5))
            self.play(ShrinkToCenter(C, run_time=0.25))
            self.wait(0.5)
        myAnimation(wordsForIntro= ['OMPLEX', 'ONCEPTS', 'OMPREHENSIBLE'])
        ################################################################################
        #                 MAIN CODE OF THE FIGURE STARTS HERE                          #
        #                 NO FUNCTION OR ANYTHING OF THAT SORT                         #
        ################################################################################
        # constructing Recurrent Neural Network Encoder
        enc_cell_width = 5
        enc_cell_height = 2.5
        strokeWidth = 5
        c_col = [GREEN_A,GREEN_E]   # colour gradient applied to matrix columns
        input_words = ['Attention', 'is', 'not', 'enough']
        input_hidden_states = ['h_0','h_1', 'h_2', 'h_3','h_4' ]
        dimensions = ['d_' + str(i) for i in range(sentence_embed.shape[1])]
        # Containers per mobject family; `Outright_Mobject` collects the pieces
        # that are Create()-d together in one long animation further down.
        Cells = VGroup()
        Cell_Captions = VGroup()
        Hidden_Arrows = VGroup()
        Input_Arrows = VGroup()
        Output_Arrows = VGroup()
        Hidden_States = VGroup()
        Words = VGroup()
        Outright_Mobject = VGroup()
        Final_Matrix = VGroup()
        # Constructing and adding each component in order of their creation to VGroup mobject.
        # ----------------> one column-vector Matrix per word of sentence_embed
        Word_Vectors = VGroup(
            *[ Matrix(word.reshape(-1,1), v_buff=0.75, h_buff=0.5).set_column_colors(c_col) for word in sentence_embed]
        )
        # One rectangular cell plus caption per input word; the caption's
        # updater keeps it glued to its own cell (bound via default arg).
        for w in input_words:
            # ---------------->
            Cells.add(Rectangle(height= enc_cell_height, width = enc_cell_width, color = BLUE_C, fill_opacity = 0.25))
            Cell_Captions.add(Text('RNN \n Encoder',color = WHITE, font_size = 32, slant=ITALIC).move_to(Cells[-1].get_center()).add_updater(lambda x, cell = Cells[-1]: x.move_to(cell.get_center())))
        # Initial hidden-state arrow (h_0) entering the first cell from the left.
        # ---------------->
        Hidden_Arrows.add(Arrow(stroke_width=strokeWidth).add_updater(
            lambda mob, enc = Cells[0]: mob.become(Arrow(start = enc.get_left() + [-3,0,0], end = enc.get_left(), buff=0, max_tip_length_to_length_ratio=0.1, stroke_width=strokeWidth))
        ))
        Outright_Mobject.add(Hidden_Arrows[-1])
        # ----------------> the h_0 label riding just left of that arrow
        Hidden_States.add(MathTex(input_hidden_states[0],color = WHITE, font_size = 64).next_to(Hidden_Arrows[-1], LEFT/4).add_updater(lambda x, y = Hidden_Arrows[-1]: x.next_to(y, LEFT/4)))
        Outright_Mobject.add(Hidden_States[-1])
        # Per-cell wiring: input arrow + embedding + word label below the cell,
        # output arrow + hidden state above it, and a cell-to-cell arrow
        # between consecutive cells.  All lambdas bind their targets through
        # default arguments, avoiding the late-binding closure pitfall.
        for i in range(len(Cells)):
            c = Cells[i]
            w = input_words[i]
            h = input_hidden_states[i+1]
            wv = Word_Vectors[i]
            # ----------------> arrow feeding the embedding into the cell bottom
            Input_Arrows.add(Arrow(stroke_width=strokeWidth).add_updater(
                lambda mob, enc = c: mob.become(Arrow(start = enc.get_bottom() + [0,-3,0], end = enc.get_bottom(), buff=0, max_tip_length_to_length_ratio=0.1, stroke_width=strokeWidth))
            ))
            Outright_Mobject.add(Input_Arrows[-1])
            # ----------------> keep the embedding column under its input arrow
            wv.add_updater(lambda x, y = Input_Arrows[-1]: x.next_to(y, DOWN/4))
            Outright_Mobject.add(wv)
            # ----------------> the literal word text below its embedding
            Words.add(Text(w,color = WHITE, font_size = 48, slant = ITALIC).next_to(wv, DOWN).add_updater(lambda x, y = wv: x.next_to(y, DOWN)))
            Outright_Mobject.add(Words[-1])
            # ----------------> output arrow leaving the top of the cell
            # NOTE(review): `c[-1]` indexes the Rectangle's submobjects instead
            # of using the cell itself (`c`, as the other updaters do) — looks
            # suspicious; confirm against a rendered run.
            Output_Arrows.add(Arrow(stroke_width=strokeWidth).add_updater(
                lambda mob, enc = c[-1]: mob.become(Arrow(start = enc.get_top(), end = enc.get_top() + [0,3,0], buff=0, max_tip_length_to_length_ratio=0.1, stroke_width=strokeWidth))
            ))
            Outright_Mobject.add(Output_Arrows[-1])
            # ----------------> hidden state h_{i+1} above the output arrow
            Hidden_States.add(MathTex(h,color = WHITE, font_size = 64).next_to(Output_Arrows[-1], UP/4).add_updater(lambda x, y = Output_Arrows[-1]: x.next_to(y, UP/4)))
            Outright_Mobject.add(Hidden_States[-1])
            if i < len(Cells)-1:
                nc = Cells[i+1]
                # ----------------> hidden-state arrow from this cell to the next
                Hidden_Arrows.add(Arrow(stroke_width=strokeWidth).add_updater(
                    lambda mob, enc = nc, prev_enc = c: mob.become(Arrow(start = prev_enc.get_right(), end = enc.get_left(), buff=0, max_tip_length_to_length_ratio=0.1, stroke_width=strokeWidth))
                ))
                Outright_Mobject.add(Hidden_Arrows[-1])
        # ----------------> final hidden-state arrow leaving the last cell
        Hidden_Arrows.add(Arrow(stroke_width=strokeWidth).add_updater(
            lambda mob, enc = Cells[-1]: mob.become(Arrow(start = enc.get_right(), end = enc.get_right() + [2,0,0], buff=0, max_tip_length_to_length_ratio=0.1, stroke_width=strokeWidth))
        ))
        Outright_Mobject.add(Hidden_Arrows[-1])
        # Switching to Transformer: one wide cell replaces the whole RNN chain.
        TransformerEncoderCell = Rectangle(height=enc_cell_height, width=enc_cell_width*len(Cells), color = YELLOW, fill_opacity = 0.3)
        # NOTE(review): `c.center()` (unlike `c.get_center()`) MOVES the cell to
        # the origin and returns the mobject itself, so this updater recentres
        # the cell as a side effect every frame — confirm this is intended.
        TransformerEncoderText = Text('Transformer Encoder.', font_size=96, slant=ITALIC).add_updater(lambda x, c = TransformerEncoderCell: x.move_to(c.center()))
        TransformerEncoderInputArrow = Arrow(stroke_width=strokeWidth).add_updater(lambda x, c = TransformerEncoderCell: x.become(Arrow(start = c.get_bottom()+ [0,-4,0], end = c.get_bottom(), buff=0, max_tip_length_to_length_ratio=0.1, stroke_width=strokeWidth)))
        TransformerEncoderOutputArrow = Arrow(stroke_width=strokeWidth).add_updater(lambda x, c = TransformerEncoderCell: x.become(Arrow(start = c.get_top(), end = c.get_top()+ [0,4,0], buff=0, max_tip_length_to_length_ratio=0.1, stroke_width=strokeWidth)))
        # Input matrix first shown as (dims x words) via the transpose, then
        # replaced by the (words x dims) layout below.
        TransformerInputMatrix = Matrix(sentence_embed.T, h_buff=2.5).set_column_colors(c_col, c_col, c_col, c_col).move_to(TransformerEncoderInputArrow.get_bottom() + DOWN*3).add_updater(lambda x, arw = TransformerEncoderInputArrow: x.move_to(arw.get_bottom()+ DOWN*3))
        TransformerInputMatrix_T = Matrix(sentence_embed, h_buff=2).set_column_colors(c_col, c_col, c_col, c_col, c_col, c_col).move_to(TransformerEncoderInputArrow.get_bottom() + DOWN*3).add_updater(lambda x, arw = TransformerEncoderInputArrow: x.move_to(arw.get_bottom()+ DOWN*3))
        TransformerEncoderOutputText = Text('Encoded Output', font_size=64, slant=ITALIC).add_updater(lambda x, c = TransformerEncoderOutputArrow: x.move_to(c.get_top() + UP))
        # ----------------> TRANSFORMER INPUT MATRIX ROW NAMES (one per word)
        TransformerInputMatrix_TRows = TransformerInputMatrix_T.get_rows()
        TransformerInputMatrix_TRowNames =[MathTex(input_words[i]).next_to(TransformerInputMatrix_TRows[i], LEFT*5).add_updater(lambda x, y = TransformerInputMatrix_TRows[i]: x.next_to(y, LEFT*5)) for i in range(sentence_embed.shape[0])]
        # ----------------> TRANSFORMER INPUT MATRIX COLUMN NAMES (one per dim)
        TransformerInputMatrix_TCols = TransformerInputMatrix_T.get_columns()
        TransformerInputMatrix_TColNames =[MathTex(dimensions[i]).next_to(TransformerInputMatrix_TCols[i], UP*3).add_updater(lambda x, y = TransformerInputMatrix_TCols[i]: x.next_to(y, UP*3)) for i in range(sentence_embed.shape[1])]
        # Laying Out All Animations
        # ----- Pull in all the RNN Encoder Components
        Cells.arrange(RIGHT, buff=3)
        self.play(Create(Cells), run_time=2, lag_ratio = 1)
        self.play(Write(Cell_Captions), run_time=1)
        self.play(Create(Outright_Mobject), run_time=10)
        # ----- Converting RNN Encoder Cells to Transformer Encoder Cell
        self.play(Cells.animate.arrange(RIGHT, buff=0), FadeOut(Cell_Captions))
        Cell_Captions.clear_updaters()
        self.remove(Cell_Captions)
        self.play(FadeTransform(Cells,TransformerEncoderCell))
        Cells.clear_updaters()
        self.remove(Cells)
        # Drop updaters before unwriting, otherwise the pieces keep snapping
        # back to the (now removed) RNN cells during the fade.
        Outright_Mobject.clear_updaters()
        self.play(
            Write(TransformerEncoderText),
            Write(TransformerEncoderInputArrow),
            Unwrite(Input_Arrows),
            Unwrite(Hidden_Arrows),
            Unwrite(Words),
            Write(TransformerEncoderOutputArrow),
            Unwrite(Output_Arrows),
            lag_ratio = 0.5)
        # Column vectors merge into the full input matrix, hidden states merge
        # into the "Encoded Output" label, then the matrix is transposed and
        # its rows/columns are labelled.
        self.play(FadeTransformPieces(Word_Vectors, TransformerInputMatrix))
        self.play(FadeTransformPieces(Hidden_States, TransformerEncoderOutputText))
        self.play(FadeTransformPieces(TransformerInputMatrix, TransformerInputMatrix_T))
        self.play(Write(VGroup(*TransformerInputMatrix_TRowNames)))
        self.play(Write(VGroup(*TransformerInputMatrix_TColNames)))
        self.wait(10)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment