Skip to content

Instantly share code, notes, and snippets.

@arif9799
Last active August 8, 2023 02:09
Show Gist options
  • Save arif9799/21a537f1a0e493b730c76ccc221b9b7e to your computer and use it in GitHub Desktop.
Save arif9799/21a537f1a0e493b730c76ccc221b9b7e to your computer and use it in GitHub Desktop.
Attention is not enough! Feed Forward Neural Nets & Recurrent Neural Nets Architecture Animation
from manim import *
from manim.utils.unit import Percent, Pixels
from colour import Color
# Scene size in manim units (16:9, matching the pixel resolutions below).
config.frame_width = 24
config.frame_height = 13.5

# Resolution switch: True renders 3840x2160, False renders 1920x1080.
HIGH = True
if not HIGH:
    print("Its low")
    config.pixel_height, config.pixel_width = 1080, 1920
else:
    print("Its high")
    config.pixel_height, config.pixel_width = 2160, 3840
class Figure2(Scene):
    """Figure 2: feed-forward vs. recurrent neural network architectures.

    The scene plays, in order: the shared channel intro, a fully connected
    feed-forward network (moved to the top-left), a single recurrent cell
    with its feedback loop (moved to the top-right), and an RNN unfolded
    over time.
    """

    def construct(self):
        """Play every section of the figure in order."""
        self._play_intro(words_for_intro=['OMPLEX', 'ONCEPTS', 'OMPREHENSIBLE'])
        self._play_feed_forward_network()
        self._play_recurrent_cell()
        self._play_unfolded_rnn()

    def _play_intro(self, words_for_intro: "list[str]"):
        """Play the channel intro (identical across videos).

        A large blackboard-bold "C" is broadcast. Each entry of
        *words_for_intro* is then typed out character by character to its
        right, so that "C" + entry reads as one word, and erased again.
        Finally the "C" is circumscribed and shrunk away.

        Args:
            words_for_intro: Suffixes to spell after the leading "C".
        """
        font_scale = config.frame_height // 8
        font_color = WHITE
        time_per_char = 0.1

        big_c = MathTex(r"\mathbb{C}", color=font_color).scale(config.frame_height // 3)
        self.play(Broadcast(big_c), run_time=1)
        # Broadcast animates copies, so the "C" itself must be added explicitly.
        self.add(big_c)

        # One VGroup of single-character MathTex per word, chained to the right
        # of the "C" and bottom-aligned with it.
        words_math = VGroup()
        for word in words_for_intro:
            chars = VGroup()
            for i, ch in enumerate(word):
                # Raw string: "\m" is an invalid escape sequence in a plain literal.
                ch_tex = MathTex(r"\mathbb{%s}" % ch, color=font_color).scale(font_scale)
                anchor = chars[-1] if i else big_c
                ch_tex.next_to(anchor, RIGHT, buff=0.05).align_to(big_c, DOWN)
                chars.add(ch_tex)
            words_math.add(chars)

        # One play() per character: Succession / AnimationGroup did not give the
        # intended result here, so the timing is driven manually (inefficient).
        for chars in words_math:
            for ch_tex in chars:
                self.play(Write(ch_tex, run_time=time_per_char))
            self.wait(0.5)
            for ch_tex in reversed(chars.submobjects):
                self.play(Unwrite(ch_tex, run_time=time_per_char))

        self.play(Circumscribe(big_c, color=MAROON_E, fade_out=False, time_width=2, shape=Circle, buff=1, stroke_width=config.frame_height // 3, run_time=1.5))
        self.play(ShrinkToCenter(big_c, run_time=0.25))
        self.wait(0.5)

    def _play_feed_forward_network(self):
        """Build and animate a fully connected feed-forward network, then move it top-left."""
        neuron_radius = 0.1
        input_neurons = 5
        hidden_neurons = 12
        output_neurons = 2
        hidden_layers = 4
        stroke_width = 1
        layers = []   # neuron VGroups, input layer first
        weights = []  # arrow VGroups: weights[k] enters layers[k]; weights[-1] leaves the output layer

        input_layer = VGroup(*[
            Circle(radius=neuron_radius, color=BLUE_C, fill_opacity=0)
            for _ in range(input_neurons)
        ]).arrange(DOWN, buff=0.1)
        layers.append(input_layer)

        # Stub arrows into each input neuron, redrawn every frame so they follow it.
        weights.append(VGroup(*[
            always_redraw(
                lambda neuron=neuron: Arrow(
                    start=neuron.get_left() - [0.5, 0, 0],
                    end=neuron.get_left(),
                    buff=0,
                    stroke_width=stroke_width,
                    max_tip_length_to_length_ratio=0.1,
                )
            )
            for neuron in input_layer
        ]))

        def dense_layer(n_neurons, previous_layer):
            """Return a layer placed right of *previous_layer* plus arrows connecting every neuron pair."""
            layer = VGroup(*[
                Circle(radius=neuron_radius, color=BLUE_C, fill_opacity=0.2)
                for _ in range(n_neurons)
            ]).arrange(DOWN, buff=0.1).next_to(previous_layer, RIGHT * 4)

            def connection(src, dst):
                # Tipless line from a neuron's right edge to the next neuron's left edge.
                return Arrow(start=src.get_right(), end=dst.get_left(), buff=0, stroke_width=stroke_width, max_tip_length_to_length_ratio=0)

            # The initial line and the updater share one stroke width, so the
            # lines keep their thickness when Write() releases the updaters.
            connections = VGroup(*[
                connection(src, dst).add_updater(lambda mob, src=src, dst=dst: mob.become(connection(src, dst)))
                for dst in layer
                for src in previous_layer
            ])
            return layer, connections

        for _ in range(hidden_layers):
            layer, connections = dense_layer(hidden_neurons, layers[-1])
            layers.append(layer)
            weights.append(connections)
        output_layer, connections = dense_layer(output_neurons, layers[-1])
        layers.append(output_layer)
        weights.append(connections)

        # Stub arrows leaving each output neuron.
        weights.append(VGroup(*[
            always_redraw(
                lambda neuron=neuron: Arrow(
                    start=neuron.get_right(),
                    end=neuron.get_right() + [0.5, 0, 0],
                    buff=0,
                    stroke_width=stroke_width,
                    max_tip_length_to_length_ratio=0.1,
                )
            )
            for neuron in output_layer
        ]))

        # zip stops after the output layer. Each layer is therefore drawn together
        # with the arrows entering it, and the output stubs are written afterwards.
        for layer, incoming in zip(layers, weights):
            self.play(DrawBorderThenFill(layer), run_time=1)
            self.play(Write(incoming), run_time=1)
        self.play(Write(weights[-1]))

        all_neurons = VGroup(*[neuron for layer in layers for neuron in layer])
        all_weights = VGroup(*weights)
        self.play(ShowPassingFlash(all_weights.copy().set_color(BLUE_C), time_width=1))
        # The arrows track the neurons through their updaters, so moving the neurons moves the whole network.
        self.play(all_neurons.animate.to_edge(LEFT * 6 + UP * 3))
        self.play(Write(Text("FEED FORWARD NEURAL NETWORK (FFNN)", font_size=24).next_to(all_neurons, DOWN * 2)))
        self.wait(0.5)

    def _play_recurrent_cell(self):
        """Animate one recurrent cell with its feedback loop, step three inputs through it, then move it top-right."""
        cell_width = 1
        cell_height = 0.45
        stroke_width = 3

        cell = Rectangle(height=cell_height, width=cell_width, color=BLUE_C, fill_opacity=0.25)

        def tracking(make_arrow):
            """Return an arrow rebuilt from *make_arrow* on every frame.

            The updater also runs once right away (call_updater=True).
            Write/Create suspend a mobject's updaters while drawing it. Without
            this first call, the arrow would be drawn as a default Arrow() and
            would only jump into place when the animation ends.
            """
            return Arrow().add_updater(lambda mob: mob.become(make_arrow()), call_updater=True)

        def segment(start, offset, tip_ratio=0):
            """Return an arrow from *start* to *start + offset* with this section's styling."""
            return Arrow(start=start, end=start + offset, buff=0, max_tip_length_to_length_ratio=tip_ratio, stroke_width=stroke_width)

        def label(tex, arrow, direction):
            """Return a MathTex kept next to *arrow* on the *direction* side."""
            return MathTex(tex, color=WHITE, font_size=32).next_to(arrow, direction).add_updater(lambda x: x.next_to(arrow, direction))

        # Feedback loop path: out of the cell's left side, down, right beneath
        # the cell, up, and back into the cell's right side. Only the final
        # segment gets a visible tip.
        loop_1 = tracking(lambda: segment(cell.get_left(), [-cell_height, 0, 0]))
        loop_2 = tracking(lambda: segment(loop_1.get_left(), [0, -cell_height - 0.25, 0]))
        loop_3 = tracking(lambda: segment(loop_2.get_bottom(), [2 * cell_height + cell.length_over_dim(0), 0, 0]))
        loop_4 = tracking(lambda: segment(loop_3.get_right(), [0, cell_height + 0.25, 0]))
        loop_5 = tracking(lambda: segment(loop_4.get_top(), [-cell_height, 0, 0], tip_ratio=0.25))
        feedback_loop = VGroup(loop_1, loop_2, loop_3, loop_4, loop_5)

        input_arr = tracking(lambda: Arrow(start=cell.get_bottom() + [0, -1.5, 0], end=cell.get_bottom(), buff=0, max_tip_length_to_length_ratio=0.1, stroke_width=stroke_width))
        output_arr = tracking(lambda: segment(cell.get_top(), [0, 1.5, 0], tip_ratio=0.1))

        inp_text = VGroup(*[label(t, input_arr, DOWN) for t in ['input_{t1}', 'input_{t2}', 'input_{t3}']])
        out_text = VGroup(*[label(t, output_arr, UP) for t in ['output_{t1}', 'output_{t2}', 'output_{t3}']])

        self.play(DrawBorderThenFill(cell), run_time=1)
        self.play(Write(feedback_loop), run_time=1)
        self.play(Write(input_arr), run_time=1)
        self.play(Write(output_arr), run_time=1)
        # One time step per input: flash the input path and loop, pulse the cell, flash the output.
        for i, o in zip(inp_text, out_text):
            self.play(Write(i), run_time=0.5)
            self.play(ShowPassingFlash(input_arr.copy().set_color(BLACK), time_width=0.5), ShowPassingFlash(feedback_loop.copy().set_color(BLACK), time_width=0.5), run_time=0.5)
            self.play(Indicate(cell), run_time=0.5)
            self.play(ShowPassingFlash(output_arr.copy().set_color(BLACK), time_width=0.5), run_time=0.5)
            self.play(Unwrite(i), Write(o), run_time=0.5)
            self.play(Unwrite(o), run_time=0.5)
        self.play(Write(label('Input_{i}', input_arr, DOWN)), run_time=0.5)
        self.play(Write(label('Output_{i}', output_arr, UP)), run_time=0.5)
        self.play(cell.animate.to_edge(RIGHT * 8 + UP * 7), run_time=1)
        self.play(cell.animate.scale(0.75), run_time=1)
        self.play(Write(Text("RECURRENT NEURAL NETWORKS (RNNs)", font_size=24).next_to(input_arr, DOWN * 3)))
        self.wait(0.5)

    def _play_unfolded_rnn(self):
        """Animate the RNN unfolded in time: blue cells read the input words, green cells read/emit output words."""
        cell_width = 0.9
        cell_height = 0.4
        stroke_width = 2
        input_words = ['Attention', 'is', 'not', 'enough']
        output_words = ['<start>', 'dikkat', 'yegerli', 'yetil', '<end>']
        input_hidden_states = ['h_0', 'h_1', 'h_2', 'h_3']
        output_hidden_states = ['s_0', 's_1', 's_2', 's_3']

        cells = VGroup()
        # Every arrow and label in drawing order. Each label comes after the
        # arrow it is attached to, which update() below relies on.
        annotations = VGroup()

        def arrow(start, end):
            """Return an arrow with this section's styling."""
            return Arrow(start=start, end=end, buff=0, max_tip_length_to_length_ratio=0.1, stroke_width=stroke_width)

        def tracking(make_arrow):
            """Return an arrow rebuilt from *make_arrow* on every frame."""
            return Arrow(stroke_width=stroke_width).add_updater(lambda mob: mob.become(make_arrow()))

        def attach(mob, anchor, direction):
            """Place *mob* next to *anchor* on the *direction* side and keep it there."""
            return mob.next_to(anchor, direction).add_updater(lambda x: x.next_to(anchor, direction))

        def make_word(text):
            return Text(text, color=WHITE, font_size=32, slant=ITALIC)

        def make_state(tex):
            return MathTex(tex, color=WHITE, font_size=32)

        def new_cell(color):
            """Append a cell and return it with its predecessor (None for the first cell)."""
            previous = cells.submobjects[-1] if cells.submobjects else None
            cell = Rectangle(height=cell_height, width=cell_width, color=color, fill_opacity=0.25)
            cells.add(cell)
            return cell, previous

        # Input cells: each reads one input word from below. Default arguments
        # bind the current cell, because closures over loop variables bind late.
        for w, h in zip(input_words, input_hidden_states):
            cell, previous = new_cell(BLUE_C)
            if previous is None:
                # The initial hidden state enters the first cell from the left.
                hidden = tracking(lambda c=cell: arrow(c.get_left() + [-1, 0, 0], c.get_left()))
                state_side = LEFT / 4
            else:
                hidden = tracking(lambda c=cell, p=previous: arrow(p.get_right(), c.get_left()))
                state_side = UP / 4
            hidden_label = attach(make_state(h), hidden, state_side)
            below = tracking(lambda c=cell: arrow(c.get_bottom() + [0, -1, 0], c.get_bottom()))
            below_label = attach(make_word(w), below, DOWN / 4)
            above = tracking(lambda c=cell: arrow(c.get_top(), c.get_top() + [0, 1, 0]))
            annotations.add(hidden, hidden_label, below, below_label, above)

        # Output cells: each reads the previous output word and emits the next one.
        for prior_w, post_w, s in zip(output_words, output_words[1:], output_hidden_states):
            cell, previous = new_cell(GREEN_C)
            hidden = tracking(lambda c=cell, p=previous: arrow(p.get_right(), c.get_left()))
            hidden_label = attach(make_state(s), hidden, UP / 4)
            below = tracking(lambda c=cell: arrow(c.get_bottom() + [0, -1, 0], c.get_bottom()))
            below_label = attach(make_word(prior_w), below, DOWN / 4)
            above = tracking(lambda c=cell: arrow(c.get_top(), c.get_top() + [0, 1, 0]))
            above_label = attach(make_word(post_w), above, UP / 4)
            annotations.add(hidden, hidden_label, below, below_label, above, above_label)

        cells.arrange(RIGHT, buff=1).shift(DOWN)
        self.play(Create(cells), run_time=2)
        self.play(cells.animate.shift(DOWN * 2))
        # None of the annotations has been updated yet, and Create suspends
        # updaters while it runs. Snap them onto the final cell layout first so
        # they are drawn in place.
        annotations.update()
        self.play(Create(annotations), run_time=8)
        self.play(Write(Text("RECURRENT NEURAL NETWORKS (RNNs) Unfolded in time!", font_size=24).next_to(annotations, DOWN * 2)))
        self.wait(10)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment