# Last active: August 8, 2023 02:09
# Save arif9799/21a537f1a0e493b730c76ccc221b9b7e to your computer and use it in GitHub Desktop.
# Attention is not enough! Feed Forward Neural Nets & Recurrent Neural Nets Architecture Animation
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below.
# To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
from manim import *
from manim.utils.unit import Percent, Pixels
from colour import Color
# Scene geometry: a 24 x 13.5 unit frame (16:9 aspect ratio).
config.frame_width = 24
config.frame_height = 13.5

# Output resolution switch: True renders at 3840x2160, False at 1920x1080.
HIGH = True
if not HIGH:
    print("Its low")
    config.pixel_height, config.pixel_width = 1080, 1920
else:
    print("Its high")
    config.pixel_height, config.pixel_width = 2160, 3840
class Figure2(Scene):
    """Animate a feed-forward network, a single recurrent cell, and an RNN unrolled in time.

    ``construct`` plays four independent sections in order. Each section lives in
    its own method so that its local settings (cell sizes, stroke widths) stay in
    its own scope. The mobject updaters read those settings lazily, so when all
    sections shared one scope, re-assigning them for the unrolled RNN silently
    reshaped the still-visible single RNN cell.
    """

    def construct(self):
        """Play the intro, then the FFNN, single-RNN-cell and unrolled-RNN sections."""
        self._animate_intro(['OMPLEX', 'ONCEPTS', 'OMPREHENSIBLE'])
        self._animate_feed_forward_network()
        self._animate_recurrent_cell()
        self._animate_unrolled_rnn()

    ################################################################################
    #                    INTRO ANIMATION SAME ACROSS ALL VIDEOS                    #
    ################################################################################
    def _animate_intro(self, words_for_intro):
        """Play the series intro built around a large blackboard-bold "C".

        Each string in ``words_for_intro`` is typed to the right of the C one
        character at a time, held for half a second, then erased right-to-left.
        The C is finally circumscribed and shrunk away.

        Args:
            words_for_intro: list of strings spelled after the C, e.g. ``'OMPLEX'``.
        """
        font_scale = config.frame_height // 8
        font_color = WHITE
        time_per_char = 0.1

        big_c = MathTex(r"\mathbb{C}", color=font_color).scale(config.frame_height // 3)
        self.play(Broadcast(big_c), run_time=1)
        # Presumably Broadcast only animates copies of big_c, so add it explicitly
        # to keep it on screen for the rest of the intro.
        self.add(big_c)

        # One VGroup of glyphs per word. Each glyph sits right of its predecessor
        # (the first one right of the C) and is bottom-aligned with the C.
        words_math = VGroup()
        for word in words_for_intro:
            glyphs = VGroup()
            for i, ch in enumerate(word):
                # Raw string: "\m" is not a valid Python escape sequence.
                glyph = MathTex(r"\mathbb{%s}" % ch, color=font_color).scale(font_scale)
                anchor = glyphs.submobjects[-1] if i else big_c
                glyph.next_to(anchor, RIGHT, buff=0.05).align_to(big_c, DOWN)
                glyphs.add(glyph)
            words_math.add(glyphs)

        # Succession / AnimationGroup did not produce the intended timing, so each
        # glyph gets its own play() call. This is slower to render but correct.
        for glyphs in words_math.submobjects:
            for glyph in glyphs.submobjects:
                self.play(Write(glyph, run_time=time_per_char))
            self.wait(0.5)
            for glyph in reversed(glyphs.submobjects):
                self.play(Unwrite(glyph, run_time=time_per_char))

        self.play(Circumscribe(big_c, color=MAROON_E, fade_out=False, time_width=2, shape=Circle,
                               buff=1, stroke_width=config.frame_height // 3, run_time=1.5))
        self.play(ShrinkToCenter(big_c, run_time=0.25))
        self.wait(0.5)

    ################################################################################
    #                     FEED FORWARD NEURAL NETWORK (FFNN)                       #
    ################################################################################
    def _animate_feed_forward_network(self):
        """Draw a fully connected feed-forward network layer by layer, then caption it.

        The network has 5 input neurons, 4 hidden layers of 12 neurons and 2 output
        neurons. Every arrow carries an updater that rebuilds it from its
        neurons' current positions, so moving only the neurons drags the whole
        drawing along.
        """
        neuron_radius = 0.1
        input_neurons = 5
        hidden_neurons = 12
        output_neurons = 2
        hidden_layers = 4
        stroke_width = 1
        # layers[k] is a column of neurons and weights[k] holds the arrows leading
        # INTO layers[k]. weights has one extra trailing entry: the arrows leaving
        # the output layer.
        layers = []
        weights = []

        input_layer = VGroup(*[
            Circle(radius=neuron_radius, color=BLUE_C, fill_opacity=0)
            for _ in range(input_neurons)
        ])
        input_layer.arrange(DOWN, buff=0.1)
        layers.append(input_layer)
        # Half-unit arrows pointing into each input neuron from the left.
        weights.append(VGroup(*[
            always_redraw(
                lambda neuron=neuron: Arrow(
                    start=neuron.get_left() - [0.5, 0, 0],
                    end=neuron.get_left(),
                    buff=0,
                    stroke_width=stroke_width,
                    max_tip_length_to_length_ratio=0.1,
                )
            )
            for neuron in input_layer
        ]))

        def build_layer(n_neurons, previous_layer):
            """Return (layer, edges) for a new neuron column right of previous_layer.

            ``edges`` are tip-less arrows from every neuron of ``previous_layer`` to
            every neuron of the new layer, grouped by destination neuron.
            """
            layer = VGroup(*[
                Circle(radius=neuron_radius, color=BLUE_C, fill_opacity=0.2)
                for _ in range(n_neurons)
            ]).arrange(DOWN, buff=0.1).next_to(previous_layer, RIGHT * 4)

            def edge(src, dst):
                return Arrow(
                    start=src.get_right(),
                    end=dst.get_left(),
                    buff=0,
                    stroke_width=stroke_width,
                    max_tip_length_to_length_ratio=0,
                )

            # Each edge starts with the same stroke width its updater rebuilds it
            # with, so it does not change thickness once updating begins.
            edges = VGroup(*[
                edge(src, dst).add_updater(
                    lambda mob, src=src, dst=dst: mob.become(edge(src, dst))
                )
                for dst in layer
                for src in previous_layer
            ])
            return layer, edges

        for _ in range(hidden_layers):
            layer, edges = build_layer(hidden_neurons, layers[-1])
            layers.append(layer)
            weights.append(edges)
        output_layer, edges = build_layer(output_neurons, layers[-1])
        layers.append(output_layer)
        weights.append(edges)

        # Half-unit arrows leaving each output neuron to the right.
        weights.append(VGroup(*[
            always_redraw(
                lambda neuron=neuron: Arrow(
                    start=neuron.get_right(),
                    end=neuron.get_right() + [0.5, 0, 0],
                    buff=0,
                    stroke_width=stroke_width,
                    max_tip_length_to_length_ratio=0.1,
                )
            )
            for neuron in output_layer
        ]))

        for layer, edges in zip(layers, weights):
            self.play(DrawBorderThenFill(layer), run_time=1)
            self.play(Write(edges), run_time=1)
        self.play(Write(weights[-1]))

        all_neurons = VGroup(*[neuron for layer in layers for neuron in layer])
        all_weights = VGroup(*weights)
        self.play(ShowPassingFlash(all_weights.copy().set_color(BLUE_C), time_width=1))
        # Only the neurons are animated; the arrows follow through their updaters.
        self.play(all_neurons.animate.to_edge(LEFT * 6 + UP * 3))
        self.play(Write(Text("FEED FORWARD NEURAL NETWORK (FFNN)", font_size=24).next_to(all_neurons, DOWN * 2)))
        self.wait(0.5)

    ################################################################################
    #                    RECURRENT NEURAL NETWORK -- SINGLE CELL                   #
    ################################################################################
    def _animate_recurrent_cell(self):
        """Draw one recurrent cell with a feedback loop and step it through three inputs.

        Afterwards the cell is labelled generically (Input_i / Output_i), moved to
        the top-right area, shrunk and captioned.
        """
        cell_width = 1
        cell_height = 0.45
        stroke_width = 3
        cell = Rectangle(height=cell_height, width=cell_width, color=BLUE_C, fill_opacity=0.25)

        def arrow(start, end, tip_ratio):
            return Arrow(start=start, end=end, buff=0,
                         max_tip_length_to_length_ratio=tip_ratio, stroke_width=stroke_width)

        def tracking(make):
            """Return a placeholder Arrow that is rebuilt from make() on every frame."""
            return Arrow().add_updater(lambda mob: mob.become(make()))

        # Five chained segments trace the loop: leftwards out of the cell, down,
        # right beneath it, up, then leftwards into the cell's right side. Only
        # the last segment has a tip.
        loop_1 = tracking(lambda: arrow(cell.get_left(), cell.get_left() + [-cell_height, 0, 0], 0))
        loop_2 = tracking(lambda: arrow(loop_1.get_left(), loop_1.get_left() + [0, -cell_height - 0.25, 0], 0))
        loop_3 = tracking(lambda: arrow(loop_2.get_bottom(),
                                        loop_2.get_bottom() + [2 * cell_height + cell.length_over_dim(0), 0, 0], 0))
        loop_4 = tracking(lambda: arrow(loop_3.get_right(), loop_3.get_right() + [0, cell_height + 0.25, 0], 0))
        loop_5 = tracking(lambda: arrow(loop_4.get_top(), loop_4.get_top() + [-cell_height, 0, 0], 0.25))
        feedback_loop = VGroup(loop_1, loop_2, loop_3, loop_4, loop_5)

        input_arr = tracking(lambda: arrow(cell.get_bottom() + [0, -1.5, 0], cell.get_bottom(), 0.1))
        output_arr = tracking(lambda: arrow(cell.get_top(), cell.get_top() + [0, 1.5, 0], 0.1))

        # Per-time-step labels that stay attached to the input/output arrows.
        input_labels = VGroup(*[
            MathTex(tex, color=WHITE, font_size=32).next_to(input_arr, DOWN)
            .add_updater(lambda m: m.next_to(input_arr, DOWN))
            for tex in ['input_{t1}', 'input_{t2}', 'input_{t3}']
        ])
        output_labels = VGroup(*[
            MathTex(tex, color=WHITE, font_size=32).next_to(output_arr, UP)
            .add_updater(lambda m: m.next_to(output_arr, UP))
            for tex in ['output_{t1}', 'output_{t2}', 'output_{t3}']
        ])

        self.play(DrawBorderThenFill(cell), run_time=1)
        self.play(Write(feedback_loop), run_time=1)
        self.play(Write(input_arr), run_time=1)
        self.play(Write(output_arr), run_time=1)

        for in_label, out_label in zip(input_labels, output_labels):
            self.play(Write(in_label), run_time=0.5)
            # Input and previous state flow into the cell at the same time. The
            # BLACK flashes are presumably meant to read as a dark pulse running
            # along the white arrows.
            self.play(ShowPassingFlash(input_arr.copy().set_color(BLACK), time_width=0.5),
                      ShowPassingFlash(feedback_loop.copy().set_color(BLACK), time_width=0.5),
                      run_time=0.5)
            self.play(Indicate(cell), run_time=0.5)
            self.play(ShowPassingFlash(output_arr.copy().set_color(BLACK), time_width=0.5), run_time=0.5)
            self.play(Unwrite(in_label), Write(out_label), run_time=0.5)
            self.play(Unwrite(out_label), run_time=0.5)

        self.play(Write(MathTex('Input_{i}', color=WHITE, font_size=32).next_to(input_arr, DOWN)
                        .add_updater(lambda m: m.next_to(input_arr, DOWN))), run_time=0.5)
        self.play(Write(MathTex('Output_{i}', color=WHITE, font_size=32).next_to(output_arr, UP)
                        .add_updater(lambda m: m.next_to(output_arr, UP))), run_time=0.5)
        self.play(cell.animate.to_edge(RIGHT * 8 + UP * 7), run_time=1)
        self.play(cell.animate.scale(0.75), run_time=1)
        self.play(Write(Text("RECURRENT NEURAL NETWORKS (RNNs)", font_size=24).next_to(input_arr, DOWN * 3)))
        self.wait(0.5)

    ################################################################################
    #                 RECURRENT NEURAL NETWORK -- UNFOLDED IN TIME                 #
    ################################################################################
    def _animate_unrolled_rnn(self):
        """Draw an encoder-decoder RNN unrolled over time, then caption it and hold.

        Four blue encoder cells read the input words. Four green decoder cells
        each receive the previous output word from below and emit the next word
        above. Consecutive cells are linked by hidden-state arrows.
        """
        cell_width = 0.9
        cell_height = 0.4
        stroke_width = 2
        input_words = ['Attention', 'is', 'not', 'enough']
        # NOTE(review): 'yegerli'/'yetil' look like typos of the Turkish
        # 'yeterli'/'değil' — confirm with the author before changing.
        output_words = ['<start>', 'dikkat', 'yegerli', 'yetil', '<end>']
        input_hidden_states = ['h_0', 'h_1', 'h_2', 'h_3']
        output_hidden_states = ['s_0', 's_1', 's_2', 's_3']

        cells = VGroup()
        # Every arrow and label, drawn together by one Create() after the cells.
        connections = VGroup()

        def arrow(start, end):
            return Arrow(start=start, end=end, buff=0,
                         max_tip_length_to_length_ratio=0.1, stroke_width=stroke_width)

        def tracking(make):
            """Return an Arrow that is rebuilt from make() on every frame."""
            return Arrow(stroke_width=stroke_width).add_updater(lambda mob: mob.become(make()))

        # Each factory binds its cell arguments per call, so every updater keeps
        # its own cells (no late-binding of loop variables).
        def initial_state_arrow(cell):
            return tracking(lambda: arrow(cell.get_left() + [-1, 0, 0], cell.get_left()))

        def link_arrow(prev_cell, cell):
            return tracking(lambda: arrow(prev_cell.get_right(), cell.get_left()))

        def input_arrow(cell):
            return tracking(lambda: arrow(cell.get_bottom() + [0, -1, 0], cell.get_bottom()))

        def output_arrow(cell):
            return tracking(lambda: arrow(cell.get_top(), cell.get_top() + [0, 1, 0]))

        def pinned(mob, anchor, direction):
            """Place mob next to anchor and keep it there on every frame."""
            return mob.next_to(anchor, direction).add_updater(lambda m: m.next_to(anchor, direction))

        def word_text(word):
            return Text(word, color=WHITE, font_size=32, slant=ITALIC)

        # Encoder: the first cell receives h_0 from the left; later cells are
        # linked to their predecessor.
        for step, (word, state) in enumerate(zip(input_words, input_hidden_states)):
            cell = Rectangle(height=cell_height, width=cell_width, color=BLUE_C, fill_opacity=0.25)
            cells.add(cell)
            if step == 0:
                hidden = initial_state_arrow(cell)
                state_label = pinned(MathTex(state, color=WHITE, font_size=32), hidden, LEFT / 4)
            else:
                hidden = link_arrow(cells[-2], cell)
                state_label = pinned(MathTex(state, color=WHITE, font_size=32), hidden, UP / 4)
            in_arrow = input_arrow(cell)
            in_word = pinned(word_text(word), in_arrow, DOWN / 4)
            out_arrow = output_arrow(cell)
            connections.add(hidden, state_label, in_arrow, in_word, out_arrow)

        # Decoder: cell k is fed output_words[k] and emits output_words[k + 1].
        for fed_word, emitted_word, state in zip(output_words, output_words[1:], output_hidden_states):
            cell = Rectangle(height=cell_height, width=cell_width, color=GREEN_C, fill_opacity=0.25)
            cells.add(cell)
            hidden = link_arrow(cells[-2], cell)
            state_label = pinned(MathTex(state, color=WHITE, font_size=32), hidden, UP / 4)
            in_arrow = input_arrow(cell)
            in_word = pinned(word_text(fed_word), in_arrow, DOWN / 4)
            out_arrow = output_arrow(cell)
            out_word = pinned(word_text(emitted_word), out_arrow, UP / 4)
            connections.add(hidden, state_label, in_arrow, in_word, out_arrow, out_word)

        cells.arrange(RIGHT, buff=1).shift(DOWN)
        self.play(Create(cells), run_time=2)
        self.play(cells.animate.shift(DOWN * 2))
        self.play(Create(connections), run_time=8)
        self.play(Write(Text("RECURRENT NEURAL NETWORKS (RNNs) Unfolded in time!", font_size=24)
                        .next_to(connections, DOWN * 2)))
        self.wait(10)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.