Skip to content

Instantly share code, notes, and snippets.

@arif9799
Last active August 8, 2023 02:10
Show Gist options
  • Save arif9799/5031fbde800b1d4f48f62d722a65b0ea to your computer and use it in GitHub Desktop.
Save arif9799/5031fbde800b1d4f48f62d722a65b0ea to your computer and use it in GitHub Desktop.
Attention is not enough! RNNs to Transformers Transition Demo Animation
from manim import *
from manim.utils.unit import Percent, Pixels
from colour import Color
import gensim
from gensim.models import Word2Vec
import numpy as np
# Logical frame size in scene units (48 x 27, i.e. a 16:9 canvas).
config.frame_width = 48
config.frame_height = 27

# Output resolution switch: 3840x2160 when HIGH is true, 1920x1080 otherwise.
HIGH = True
if HIGH:
    print("Its high")
    config.pixel_height, config.pixel_width = 2160, 3840
else:
    print("Its low")
    config.pixel_height, config.pixel_width = 1080, 1920
def construct_Word_Embeddings(sentences, dim):
    """Train a Word2Vec model on ``sentences`` and embed each sentence with it.

    Args:
        sentences: Collection of tokenised sentences (each a sequence of
            words). It is handed to ``Word2Vec`` and then iterated a second
            time here, so it must be re-iterable.
        dim: Size of each word vector (passed to ``Word2Vec`` as
            ``vector_size``).

    Returns:
        A list with one array per sentence. Each array has one row per word
        and ``dim`` columns; a sentence without words yields a ``(0, dim)``
        array.
    """
    model = Word2Vec(sentences, min_count=1, vector_size=dim, window=3)

    embedded = []
    for sentence in sentences:
        # NOTE(review): the positional ``True`` is presumably gensim's
        # ``norm`` flag (unit-length vectors) — confirm against the installed
        # gensim version.
        rows = [
            model.wv.get_vector(token, True).reshape(1, -1)
            for token in sentence
        ]
        # The leading empty (0, dim) block keeps the result two-dimensional
        # (and in NumPy's default float dtype) even when ``rows`` is empty.
        embedded.append(np.concatenate([np.empty((0, dim)), *rows], axis=0))
    return embedded
# Small corpus used to train the embeddings; only the first sentence is
# animated, the others just give Word2Vec more text to train on.
data = [
    ['Attention', 'is', 'not', 'enough'],
    ['Towards', 'Data', 'Science'],
    ['It', 'was', 'a', 'really', 'good', 'article'],
]
# One (words x 6) array per sentence.
sentences_Embedded = construct_Word_Embeddings(data, 6)
# Embedding of the first sentence, rounded to two decimals for display.
sentence_embed = sentences_Embedded[0].round(2)
class Figure4(Scene):
    """Intro animation followed by an RNN-encoder-to-Transformer-encoder morph.

    The scene draws one RNN encoder cell per input word (with hidden-state,
    input and output arrows), collapses the cells into a single Transformer
    encoder rectangle, and finally shows the word embeddings as one labelled
    input matrix.
    """

    def _play_intro(self, words_for_intro):
        """Play the intro: a large blackboard-bold C followed by each word.

        Args:
            words_for_intro: Sequence of strings. Every string is written
                character by character to the right of the C, held briefly,
                and then unwritten in reverse order.
        """
        font_height = config.frame_height // 8
        font_color = WHITE
        time_per_char = 0.1

        logo = MathTex(r"\mathbb{C}", color=font_color).scale(config.frame_height // 3)
        self.play(Broadcast(logo), run_time=1)
        self.add(logo)

        # Build one VGroup of single-character MathTex objects per word.
        words_math = VGroup()
        for word in words_for_intro:
            char_tex = VGroup()
            for i, ch in enumerate(word):
                # Raw string: "\m" is not a valid escape in a normal literal.
                ch_tex = MathTex(r"\mathbb{%s}" % ch, color=font_color).scale(font_height)
                # The first character sits next to the logo, the rest next to
                # the previously placed character.
                anchor = char_tex[-1] if i != 0 else logo
                ch_tex.next_to(anchor, RIGHT, buff=0.05).align_to(logo, DOWN)
                char_tex.add(ch_tex)
            words_math.add(char_tex)

        # Characters are played one at a time (Succession / AnimationGroup
        # did not give the desired result according to the original author).
        for word_tex in words_math:
            characters = list(word_tex)
            for ch_tex in characters:
                self.play(Write(ch_tex, run_time=time_per_char))
            self.wait(0.5)
            for ch_tex in reversed(characters):
                self.play(Unwrite(ch_tex, run_time=time_per_char))

        self.play(
            Circumscribe(
                logo,
                color=MAROON_E,
                fade_out=False,
                time_width=2,
                shape=Circle,
                buff=1,
                stroke_width=config.frame_height // 3,
                run_time=1.5,
            )
        )
        self.play(ShrinkToCenter(logo, run_time=0.25))
        self.wait(0.5)

    def construct(self):
        """Build all mobjects and play the full animation sequence."""
        ################################################################################
        #                    INTRO ANIMATION SAME ACROSS ALL VIDEOS                    #
        ################################################################################
        self._play_intro(['OMPLEX', 'ONCEPTS', 'OMPREHENSIBLE'])

        ################################################################################
        #                      MAIN CODE OF THE FIGURE STARTS HERE                     #
        ################################################################################
        enc_cell_width = 5
        enc_cell_height = 2.5
        stroke_width = 5
        c_col = [GREEN_A, GREEN_E]
        input_words = ['Attention', 'is', 'not', 'enough']
        input_hidden_states = ['h_0', 'h_1', 'h_2', 'h_3', 'h_4']
        dimensions = ['d_' + str(i) for i in range(sentence_embed.shape[1])]

        def tracked_arrow(endpoints):
            """Return an Arrow that re-creates itself from ``endpoints()``.

            ``endpoints`` is a zero-argument callable returning ``(start, end)``.
            The updater replaces the arrow with a freshly built one on every
            update, so the arrow follows whatever the callable reads.
            """
            def rebuild(mob):
                start, end = endpoints()
                mob.become(
                    Arrow(
                        start=start,
                        end=end,
                        buff=0,
                        max_tip_length_to_length_ratio=0.1,
                        stroke_width=stroke_width,
                    )
                )

            return Arrow(stroke_width=stroke_width).add_updater(rebuild)

        cells = VGroup()
        cell_captions = VGroup()
        hidden_arrows = VGroup()
        input_arrows = VGroup()
        output_arrows = VGroup()
        hidden_states = VGroup()
        words = VGroup()
        # Everything except cells and captions, in the order it should be drawn.
        outright = VGroup()

        # One column matrix per word of the embedded sentence.
        word_vectors = VGroup(
            *[
                Matrix(word.reshape(-1, 1), v_buff=0.75, h_buff=0.5).set_column_colors(c_col)
                for word in sentence_embed
            ]
        )

        # ----------------> one encoder cell (and caption) per input word
        for _ in input_words:
            cell = Rectangle(
                height=enc_cell_height,
                width=enc_cell_width,
                color=BLUE_C,
                fill_opacity=0.25,
            )
            cells.add(cell)
            cell_captions.add(
                Text('RNN \n Encoder', color=WHITE, font_size=32, slant=ITALIC)
                .move_to(cell.get_center())
                .add_updater(lambda x, cell=cell: x.move_to(cell.get_center()))
            )

        # ----------------> initial hidden state h_0 entering the first cell
        first_arrow = tracked_arrow(
            lambda enc=cells[0]: (enc.get_left() + 3 * LEFT, enc.get_left())
        )
        hidden_arrows.add(first_arrow)
        outright.add(first_arrow)

        first_state = (
            MathTex(input_hidden_states[0], color=WHITE, font_size=64)
            .next_to(first_arrow, LEFT / 4)
            .add_updater(lambda x, y=first_arrow: x.next_to(y, LEFT / 4))
        )
        hidden_states.add(first_state)
        outright.add(first_state)

        # ----------------> per-cell input, output and cell-to-cell arrows
        per_cell = zip(cells, input_words, input_hidden_states[1:], word_vectors)
        for i, (cell, word, state_name, wv) in enumerate(per_cell):
            in_arrow = tracked_arrow(
                lambda enc=cell: (enc.get_bottom() + 3 * DOWN, enc.get_bottom())
            )
            input_arrows.add(in_arrow)
            outright.add(in_arrow)

            wv.add_updater(lambda x, y=in_arrow: x.next_to(y, DOWN / 4))
            outright.add(wv)

            label = (
                Text(word, color=WHITE, font_size=48, slant=ITALIC)
                .next_to(wv, DOWN)
                .add_updater(lambda x, y=wv: x.next_to(y, DOWN))
            )
            words.add(label)
            outright.add(label)

            out_arrow = tracked_arrow(
                lambda enc=cell: (enc.get_top(), enc.get_top() + 3 * UP)
            )
            output_arrows.add(out_arrow)
            outright.add(out_arrow)

            state = (
                MathTex(state_name, color=WHITE, font_size=64)
                .next_to(out_arrow, UP / 4)
                .add_updater(lambda x, y=out_arrow: x.next_to(y, UP / 4))
            )
            hidden_states.add(state)
            outright.add(state)

            if i < len(cells) - 1:
                link = tracked_arrow(
                    lambda enc=cells[i + 1], prev_enc=cell: (prev_enc.get_right(), enc.get_left())
                )
                hidden_arrows.add(link)
                outright.add(link)

        # ----------------> arrow leaving the last cell
        last_arrow = tracked_arrow(
            lambda enc=cells[-1]: (enc.get_right(), enc.get_right() + 2 * RIGHT)
        )
        hidden_arrows.add(last_arrow)
        outright.add(last_arrow)

        # Switching to Transformer
        transformer_cell = Rectangle(
            height=enc_cell_height,
            width=enc_cell_width * len(cells),
            color=YELLOW,
            fill_opacity=0.3,
        )
        # get_center() only reads the position; center() would also move the
        # rectangle to the scene origin on every update.
        transformer_text = Text('Transformer Encoder.', font_size=96, slant=ITALIC).add_updater(
            lambda x, c=transformer_cell: x.move_to(c.get_center())
        )
        transformer_in_arrow = tracked_arrow(
            lambda c=transformer_cell: (c.get_bottom() + 4 * DOWN, c.get_bottom())
        )
        transformer_out_arrow = tracked_arrow(
            lambda c=transformer_cell: (c.get_top(), c.get_top() + 4 * UP)
        )

        n_words, n_dims = sentence_embed.shape
        # Words as columns (one gradient per word column).
        input_matrix = (
            Matrix(sentence_embed.T, h_buff=2.5)
            .set_column_colors(*([c_col] * n_words))
            .move_to(transformer_in_arrow.get_bottom() + DOWN * 3)
            .add_updater(lambda x, arw=transformer_in_arrow: x.move_to(arw.get_bottom() + DOWN * 3))
        )
        # Words as rows (one gradient per embedding dimension column).
        input_matrix_t = (
            Matrix(sentence_embed, h_buff=2)
            .set_column_colors(*([c_col] * n_dims))
            .move_to(transformer_in_arrow.get_bottom() + DOWN * 3)
            .add_updater(lambda x, arw=transformer_in_arrow: x.move_to(arw.get_bottom() + DOWN * 3))
        )
        output_text = Text('Encoded Output', font_size=64, slant=ITALIC).add_updater(
            lambda x, c=transformer_out_arrow: x.move_to(c.get_top() + UP)
        )

        # ----------------> TRANSFORMER INPUT MATRIX ROW NAMES
        row_names = [
            MathTex(word).next_to(row, LEFT * 5).add_updater(lambda x, y=row: x.next_to(y, LEFT * 5))
            for word, row in zip(input_words, input_matrix_t.get_rows())
        ]
        # ----------------> TRANSFORMER INPUT MATRIX COLUMN NAMES
        col_names = [
            MathTex(dim).next_to(col, UP * 3).add_updater(lambda x, y=col: x.next_to(y, UP * 3))
            for dim, col in zip(dimensions, input_matrix_t.get_columns())
        ]

        # Laying Out All Animations
        # ----- Pull in all the RNN Encoder Components
        cells.arrange(RIGHT, buff=3)
        # NOTE(review): lag_ratio is given to Scene.play rather than to the
        # animation itself — confirm this produces the intended stagger.
        self.play(Create(cells), run_time=2, lag_ratio=1)
        self.play(Write(cell_captions), run_time=1)
        self.play(Create(outright), run_time=10)

        # ----- Converting RNN Encoder Cells to Transformer Encoder Cell
        self.play(cells.animate.arrange(RIGHT, buff=0), FadeOut(cell_captions))
        cell_captions.clear_updaters()
        self.remove(cell_captions)

        self.play(FadeTransform(cells, transformer_cell))
        cells.clear_updaters()
        self.remove(cells)

        # Stop the RNN parts from tracking the (now removed) cells before they
        # are unwritten / transformed below.
        outright.clear_updaters()
        self.play(
            Write(transformer_text),
            Write(transformer_in_arrow),
            Unwrite(input_arrows),
            Unwrite(hidden_arrows),
            Unwrite(words),
            Write(transformer_out_arrow),
            Unwrite(output_arrows),
            lag_ratio=0.5,
        )

        self.play(FadeTransformPieces(word_vectors, input_matrix))
        self.play(FadeTransformPieces(hidden_states, output_text))
        self.play(FadeTransformPieces(input_matrix, input_matrix_t))
        self.play(Write(VGroup(*row_names)))
        self.play(Write(VGroup(*col_names)))
        self.wait(10)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment