Last active
December 10, 2020 19:13
-
-
Save flacle/a93ff64b80e85d3ee715d201f5f7b7b6 to your computer and use it in GitHub Desktop.
Manim Gradient Descent Intuition (in Papiamento)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Manim Gradient Descent Video Intuition
# Author: Francis Laclé
# Video: https://www.youtube.com/watch?v=1cCS6uK_NH8
# Github: https://github.com/flacle
# Date: 29 Oct, 2020
from manim import * | |
import math | |
class Intro(Scene):
    """Title card: writes 'Gradient Descent', holds it, then fades it out."""

    def construct(self):
        # Blue-to-green gradient title at twice the default size.
        title = PangoText('Gradient Descent', gradient=(BLUE, GREEN))
        title.scale(2)

        self.wait(1)
        self.add(title)
        self.play(Write(title))
        # Long hold while the title stays on screen.
        self.wait(11)
        self.play(FadeOut(title))
        self.wait(1)
class ThreeDSurface(ParametricSurface):
    """Parametric surface z = x**2 - y**2 over the square [-2, 2] x [-2, 2].

    Keyword arguments are forwarded to ``ParametricSurface``.  The values
    below are only defaults: anything the caller passes explicitly takes
    precedence.  (Previously the incoming ``kwargs`` were overwritten by the
    hard-coded dict, so caller options were silently ignored.)
    """

    def __init__(self, **kwargs):
        options = {
            "u_min": -2,
            "u_max": 2,
            "v_min": -2,
            "v_max": 2,
            "checkerboard_colors": [BLUE_D],
        }
        # Caller-supplied settings override the defaults above.
        options.update(kwargs)
        ParametricSurface.__init__(self, self.func, **options)

    def func(self, x, y):
        """Map surface parameters (x, y) to the 3D point (x, y, x**2 - y**2)."""
        return np.array([x, y, x**2 - y**2])
class ConYPakicoTraha(ThreeDScene):
    """Shows the 3D surface on 3D axes with a slowly rotating camera."""

    def construct(self):
        axes_3d = ThreeDAxes(animate=True)
        saddle = ThreeDSurface()

        self.set_camera_orientation(
            phi=75 * DEGREES,
            theta=30 * DEGREES,
            distance=30,
        )
        self.begin_ambient_camera_rotation(rate=0.1)
        self.wait(1)

        # Draw the axes first, then the surface on top of them.
        for mobject in (axes_3d, saddle):
            self.play(ShowCreation(mobject))
        self.wait(4)

        # Same two positional arguments as before (dividing by 1 was a no-op).
        first_angle = 0.4 * np.pi
        second_angle = -0.45 * np.pi
        self.move_camera(first_angle, second_angle)
        self.wait(4)

        self.stop_ambient_camera_rotation()
        # Remove in reverse order of creation.
        for mobject in (saddle, axes_3d):
            self.play(FadeOut(mobject))
class KostFunctie(Scene):
    """Writes J(...) and then morphs it into min J(...)."""

    def construct(self):
        cost = Tex(r'$J\left(\cdot\cdot\cdot\right)$')
        cost.scale(3)
        minimised_cost = Tex(r'$\min{J\left(\cdot\cdot\cdot\right)}$')
        minimised_cost.scale(3)

        self.wait(1)
        self.play(Write(cost))
        self.wait(4)
        # ReplacementTransform leaves `minimised_cost` as the on-screen object.
        self.play(ReplacementTransform(cost, minimised_cost))
        self.wait(6)
        self.play(FadeOut(minimised_cost))
class OnderzoekAruba(GraphScene):
    """Scatter plot with a straight trend line and per-point error segments.

    The axes are labelled "Aña" (x) and "Poblacion di Aruba" (y).  The scene
    fades in one yellow dot per data value, draws the red line
    ``y = 5/9 * x + 1``, then draws a green segment from the line to each dot
    before fading everything out.
    """

    CONFIG = {
        "y_axis_label": r"Poblacion di Aruba",
        "x_axis_label": "Aña",
        "y_max": 7,
        "y_min": 0,
        "y_tick_frequency": 1,
        "x_max": 9,
        "x_min": 0,
        "axes_color": BLUE,
    }

    @staticmethod
    def _trend(x):
        """Return the trend-line value at ``x`` (single source of truth for
        both the drawn line and the error segments)."""
        return (5 / 9 * x) + 1

    def construct(self):
        data = [1, 1.243735763, 1.673120729, 2.258542141, 2.940774487,
                3.641230068, 4.287015945, 4.891799544, 5.454441913, 6]
        self.setup_axes()
        line = self.get_graph(
            self._trend,
            color=RED,
            x_min=0,
            x_max=9,
            label="$J(x)$",
        )

        # One dot per data value; the list index is used as the x coordinate.
        dot_collection = VGroup()
        for time, dat in enumerate(data):
            dot = Dot(color=YELLOW).move_to(self.coords_to_point(time, dat))
            dot_collection.add(dot)
            self.play(FadeIn(dot), rate_func=rush_into)
        self.wait(1)

        self.play(ShowCreation(line), run_time=2)
        self.wait(1)

        # Segment from the trend line to each dot at the same x.
        error_collection = VGroup()
        for time in range(len(data)):
            error = Line(
                self.coords_to_point(time, self._trend(time)),
                dot_collection[time].get_center(),
                color=GREEN,
            )
            error_collection.add(error)
            self.play(ShowCreation(error), run_time=1)
        self.wait(2)

        self.play(
            FadeOut(error_collection),
            FadeOut(dot_collection),
            FadeOut(line),
            FadeOut(self.axes),
            FadeOut(self.x_axis_labels),
            FadeOut(self.y_axis_labels),
        )
        # The original followed this with an empty `self.play()`, which
        # animates nothing; it has been removed.

    def setup_axes(self):
        """Build the axes, attach custom tick labels, and animate them in.

        Stores the label groups on ``self.x_axis_labels`` and
        ``self.y_axis_labels`` so ``construct`` can fade them out later.
        """
        GraphScene.setup_axes(self)
        self.x_axis.label_direction = UP
        self.y_axis.label_direction = UP

        # Label text for x = 0..9 and y = 0..6; position in the list is the
        # axis coordinate the label is attached to.
        x_texts = ["'09", "'10", "'11", "'12", "'13",
                   "'14", "'15", "'16", "'17", "'18"]
        y_texts = ["100.000", "101.000", "102.000", "103.000",
                   "104.000", "105.000", "106.000"]

        self.x_axis_labels = VGroup()
        self.y_axis_labels = VGroup()

        for x_val, x_tex in enumerate(x_texts):
            tex = PangoText(x_tex).scale(0.6)
            # Place the label just below its tick on the x axis.
            tex.next_to(self.coords_to_point(x_val, 0), DOWN)
            self.x_axis_labels.add(tex)

        for y_val, y_tex in enumerate(y_texts):
            tex = PangoText(y_tex).scale(0.6)
            # Place the label just left of its tick on the y axis.
            tex.next_to(self.coords_to_point(0, y_val), LEFT)
            self.y_axis_labels.add(tex)

        self.play(
            Write(self.x_axis_labels),
            Write(self.x_axis),
            Write(self.y_axis_labels),
            Write(self.y_axis),
        )
class Hypothese(Scene):
    """Presents h_theta(x) = theta_0 + theta_1 x, highlights its parts,
    then substitutes x = 3 and replaces the left side with 9."""

    def construct(self):
        # Shared right-hand-side pieces; indices in the MathTex are
        #   0: left side, 1: "=", 2: theta_0, 3: "+", 4: theta_1, 5: x / 3
        rhs_head = ("=", "\\theta_0", "+", "{\\theta_1}")

        hypothesis = MathTex("h_\\theta\\left(x\\right)", *rhs_head, "x").scale(2)
        at_three = MathTex("h_\\theta\\left(3\\right)", *rhs_head, "3").scale(2)
        equals_nine = MathTex("9", *rhs_head, "3").scale(2)

        self.wait(1)
        self.play(Write(hypothesis))
        self.wait(6)

        box_theta1 = SurroundingRectangle(hypothesis[4], buff=.1)
        box_theta0 = SurroundingRectangle(hypothesis[2], buff=.1)
        box_lhs = SurroundingRectangle(hypothesis[0], buff=.1)

        # Highlight theta_1, slide to theta_0, then to the left-hand side.
        self.play(ShowCreation(box_theta1))
        self.wait(2)
        self.play(ReplacementTransform(box_theta1, box_theta0))
        self.wait(1)
        self.play(ReplacementTransform(box_theta0, box_lhs))
        self.wait(3)
        self.play(FadeOut(box_lhs))
        self.wait(1)

        self.play(ReplacementTransform(hypothesis, at_three))
        self.wait(3)
        self.play(ReplacementTransform(at_three, equals_nine))
        self.wait(1)
        self.play(FadeOut(equals_nine))
class CombinacionJmin(Scene):
    """Cycles through several (theta_0, theta_1) pairs inside
    ``min J(...) -> 9``, morphing each formula into the next."""

    def construct(self):
        # The original strings opened the argument list with \left( but closed
        # it with a bare ")", which is unbalanced LaTeX.  The delimiter is now
        # closed with \right), matching the balanced formulas used elsewhere
        # in this file.
        template = r'$\min{J\left(%s\right)}\to{9}$'
        argument_lists = [
            r'\theta_0, \theta_1',
            '0, 7',
            '2, 5',
            '-3, 1',
            '-2, 2',
            '-1, 3',
        ]
        formulas = [Tex(template % args).scale(2) for args in argument_lists]

        self.wait(1)
        self.play(Write(formulas[0]))
        self.wait(3)
        # Morph each formula into its successor, back to back.
        for current, following in zip(formulas, formulas[1:]):
            self.play(ReplacementTransform(current, following))
        self.wait(8)
        self.play(FadeOut(formulas[-1]))
class SomDifferencia(Scene):
    """Writes the squared-error cost formula and highlights, in turn, the
    summation, the hypothesis term and the y term."""

    def construct(self):
        # Piece indices:
        #   0: J(theta_0, theta_1)  1: =   2: 1/(2m)  3: sum  4: (
        #   5: h_theta(x^(i))       6: -   7: y^(i)   8: )^2
        pieces = (
            "J(\\theta_{0}, \\theta_{1})",
            "=",
            "\\frac{1}{2m}",
            "\\sum\\limits_{i=1}^m",
            "(",
            "h_{\\theta}(x^{(i)})",
            "-",
            "y^{(i)}",
            ")^2",
        )
        cost = MathTex(*pieces).scale(1)

        box_sum = SurroundingRectangle(cost[3], buff=.1)
        box_h = SurroundingRectangle(cost[5], buff=.1)
        box_y = SurroundingRectangle(cost[7], buff=.1)

        self.wait(1)
        self.play(Write(cost))
        self.wait(5)

        self.play(ShowCreation(box_sum))
        self.wait(1)
        self.play(ReplacementTransform(box_sum, box_h))
        self.play(ReplacementTransform(box_h, box_y))
        self.wait(2)
        self.play(FadeOut(box_y))

        self.wait(9)
        self.play(FadeOut(cost))
class GradientDescentDilanti(MovingCameraScene):
    """Shows the 'Gradient Descent' title with a camera zoom, then grows
    ``min J(theta_0, theta_1)`` up to six parameters and back to two."""

    @staticmethod
    def _min_cost(parameter_count):
        """Build ``min J(theta_0, ..., theta_{parameter_count - 1})``.

        The original strings opened with \\left( but closed with a bare ")",
        which is unbalanced LaTeX; the delimiter is now closed with \\right).
        The original's trailing ``.scale(1)`` is omitted because scaling by 1
        does not change the size.
        """
        thetas = ", ".join(r"\theta_%d" % i for i in range(parameter_count))
        return Tex(r'$\min{J\left(' + thetas + r'\right)}$')

    def construct(self):
        gd = PangoText('Gradient Descent', gradient=(BLUE, GREEN)).scale(2)
        self.wait(3)
        self.play(Write(gd))
        self.wait(1)
        # Resize the camera frame to 1.2x the title's width.
        self.play(self.camera_frame.set_width, gd.get_width() * 1.2)
        self.wait(2)
        self.play(FadeOut(gd))
        self.wait(1)

        # Two parameters, growing one at a time up to six.
        growing = [self._min_cost(count) for count in range(2, 7)]
        back_to_two = self._min_cost(2)

        self.play(Write(growing[0]))
        self.wait(1)
        for current, following in zip(growing, growing[1:]):
            self.play(ReplacementTransform(current, following))
        self.wait(1)
        self.play(ReplacementTransform(growing[-1], back_to_two))
        self.wait(3)
        self.play(FadeOut(back_to_two))
class KedaRipitiYGana(MovingCameraScene):
    """Writes the theta_j update rule with its 'update simultaneously' note,
    braces the right-hand side with a caption, then resizes the camera
    frame around the rule."""

    def construct(self):
        update_rule = Tex(
            r'$\theta_j := $',
            r'$\theta_j - \alpha \frac{\partial}{\partial\theta_j}J\left(\theta_0, \theta_1\right)$',
        ).scale(1)
        note = Tex(r'(update parew pa $j=0$ y $j=1$ !)').scale(0.75)
        note.move_to(2 * DOWN)

        self.wait(1)
        self.play(Write(update_rule), Write(note))

        # Brace sits above the second TeX part (everything after ":=").
        brace = Brace(update_rule[1], UP, buff=SMALL_BUFF)
        caption = brace.get_text("ripiti te ora e converge")
        self.play(GrowFromCenter(brace), FadeIn(caption))
        self.wait(9)

        self.play(FadeOut(caption), FadeOut(brace), FadeOut(note))
        # Resize the camera frame to 1.6x the rule's width.
        target_width = update_rule.get_width() * 1.6
        self.play(self.camera_frame.set_width, target_width)
        self.play(FadeOut(update_rule))
class CordaCalculus(GraphScene):
    """Shows dy/dx, then slides a dot with a red tangent segment along the
    parabola y = (x - 5)**2."""

    CONFIG = {
        "y_axis_label": r"$y$",
        "x_axis_label": r"$x$",
        "y_max": 10,
        "y_min": 0,
        "y_tick_frequency": 1,
        "x_max": 10,
        "x_min": 0,
        "axes_color": BLUE,
    }

    def construct(self):
        self.wait(1)
        derivative_symbol = Tex(r'$\frac{dy}{dx}$').scale(3)
        self.play(Write(derivative_symbol))
        self.wait(5)
        self.play(FadeOut(derivative_symbol))
        self.setup_axes(animate=True)

        def curve(x):
            return (x - 5) ** 2

        def slope(x):
            return 2 * (x - 5)

        parabola = self.get_graph(
            curve,
            x_min=2,
            x_max=8,
            color=YELLOW,
            stroke_opacity=0.5,
        )

        # The tracker holds the dot's current x coordinate.
        tracker = ValueTracker(0)

        def build_dot():
            x = tracker.get_value()
            return Dot(color=WHITE).move_to(self.coords_to_point(x, curve(x)))

        dot = always_redraw(build_dot)

        def build_tangent():
            x = tracker.get_value()
            tangent = TangentLine(dot, 1.0, length=2, stroke_opacity=1, color=RED)
            tangent.move_to(self.coords_to_point(x, curve(x)))
            # NOTE(review): empirical x offset kept from the original, whose
            # author suspected a rounding error while noting the derivative
            # itself is correct.  Root cause not established here.
            offset = match_interpolate(0.6, -0.6, 3, 7, x)
            tangent.rotate(math.atan2(-1, slope(x + offset)))
            return tangent

        tracker.set_value(3)
        tangent_line = always_redraw(build_tangent)

        self.play(ShowCreation(parabola), FadeIn(dot), FadeIn(tangent_line))
        self.wait(1)

        # (target x, rate function, run time, pause afterwards)
        moves = (
            (7, there_and_back, 4, 1),
            (5, slow_into, 4, 6),
            (3, slow_into, 1, 0),
            (5, slow_into, 6, 1),
        )
        for target, pacing, duration, pause in moves:
            self.play(tracker.set_value, target, rate_func=pacing, run_time=duration)
            if pause:
                self.wait(pause)

        self.play(
            FadeOut(parabola),
            FadeOut(dot),
            FadeOut(tangent_line),
            FadeOut(self.axes),
        )
class MinTekenMeiMei(MovingCameraScene):
    """Shows the update rule, boxes its right-hand side, then swaps it for
    the '> 0' and '< 0' derivative statements before restoring the rule."""

    def construct(self):
        rule_parts = (
            r'$\theta_j := $',
            r'$\theta_j - \alpha \frac{\partial}{\partial\theta_j}J\left(\theta_0, \theta_1\right)$',
        )
        rule = Tex(*rule_parts).scale(1)
        note = Tex(r'(update parew pa $j=0$ y $j=1$ !)').scale(0.75).move_to(2 * DOWN)
        brace = Brace(rule[1], UP, buff=SMALL_BUFF)
        box = SurroundingRectangle(rule[1], buff=.1)
        caption = brace.get_text("ripiti te ora e converge")

        self.wait(1)
        self.play(
            Write(rule),
            Write(note),
            GrowFromCenter(brace),
            FadeIn(caption),
        )
        self.play(FadeIn(box))
        self.wait(3)

        positive_case = Tex(
            r'$\frac{\partial}{\partial\theta_j}J\left(\theta_0, \theta_1\right) > 0 \to$ descent'
        ).scale(1)
        # Clear the decorations while the rule morphs into the first case.
        self.play(
            FadeOut(brace),
            FadeOut(caption),
            FadeOut(note),
            FadeOut(box),
            ReplacementTransform(rule, positive_case),
        )
        self.wait(4)

        negative_case = Tex(
            r'$\frac{\partial}{\partial\theta_j}J\left(\theta_0, \theta_1\right) < 0 \to$ ascent'
        ).scale(1)
        self.play(ReplacementTransform(positive_case, negative_case))
        self.wait(4)

        # A fresh copy of the rule to transform back into.
        rule_again = Tex(*rule_parts).scale(1)
        self.play(ReplacementTransform(negative_case, rule_again))
        self.wait(4)
        self.play(FadeOut(rule_again))
class Alpha(GraphScene):
    """Introduces alpha, then moves a dot and tangent segment along the
    parabola y = (x - 5)**2 while the displayed alpha value changes."""

    CONFIG = {
        "y_axis_label": r"$y$",
        "x_axis_label": r"$x$",
        "y_max": 10,
        "y_min": 0,
        "y_tick_frequency": 1,
        "x_max": 10,
        "x_min": 0,
        "axes_color": BLUE,
    }

    def construct(self):
        self.wait(1)
        alpha = Tex(r'$\alpha$').scale(3).shift(0)
        self.play(Write(alpha))
        self.play(ApplyMethod(alpha.shift, (UP + RIGHT) * PI))
        self.setup_axes(animate=True)

        def curve(x):
            return (x - 5) ** 2

        def slope(x):
            return 2 * (x - 5)

        parabola = self.get_graph(
            curve,
            x_min=2,
            x_max=8,
            color=YELLOW,
            stroke_opacity=0.5,
        )

        # The tracker holds the dot's current x coordinate.
        tracker = ValueTracker(0)

        def build_dot():
            x = tracker.get_value()
            return Dot(color=WHITE).move_to(self.coords_to_point(x, curve(x)))

        dot = always_redraw(build_dot)

        def build_tangent():
            x = tracker.get_value()
            tangent = TangentLine(dot, 1.0, length=2, stroke_opacity=1, color=RED)
            tangent.move_to(self.coords_to_point(x, curve(x)))
            # NOTE(review): empirical x offset kept from the original, whose
            # author suspected a rounding error while noting the derivative
            # itself is correct.  Root cause not established here.
            offset = match_interpolate(0.6, -0.6, 3, 7, x)
            tangent.rotate(math.atan2(-1, slope(x + offset)))
            return tangent

        tracker.set_value(3)
        tangent_line = always_redraw(build_tangent)

        self.play(
            ShowCreation(parabola),
            FadeIn(dot),
            FadeIn(tangent_line),
            ApplyMethod(alpha.scale, (1 / 2)),
        )

        # Each stage: the label text to morph into, followed by the dot
        # moves (target x, rate function, run time) played under that label.
        stages = (
            (r'$\alpha = 2.0$', ((6, there_and_back, 2),)),
            (r'$\alpha = 1.1$', ((4, there_and_back, 2), (5, slow_into, 12))),
            (r'$\alpha = 0.05$', ((4, there_and_back, 6),)),
            (r'$\alpha = 2.2$', ((3, rush_into, 2), (7, rush_into, 3))),
        )
        shown_label = alpha
        for label_text, moves in stages:
            next_label = Tex(label_text)
            self.play(ReplacementTransform(shown_label, next_label), run_time=0.5)
            shown_label = next_label
            for target, pacing, duration in moves:
                self.play(tracker.set_value, target, rate_func=pacing, run_time=duration)

        self.wait(3)
        self.play(
            FadeOut(parabola),
            FadeOut(dot),
            FadeOut(tangent_line),
            FadeOut(self.axes),
            FadeOut(shown_label),
        )
class SaddlePoint(ThreeDScene):
    """Long, slowly rotating view of the 3D surface on 3D axes.

    Ambient camera rotation is started here and is not stopped within this
    scene.
    """

    def construct(self):
        axes_3d = ThreeDAxes(animate=True)
        saddle = ThreeDSurface()

        self.set_camera_orientation(
            phi=75 * DEGREES,
            theta=30 * DEGREES,
            distance=30,
        )
        self.begin_ambient_camera_rotation(rate=0.1)
        self.wait(1)

        # Axes first, then the surface.
        for mobject in (axes_3d, saddle):
            self.play(ShowCreation(mobject))

        self.wait(50)

        # Remove in reverse order of creation.
        for mobject in (saddle, axes_3d):
            self.play(FadeOut(mobject))
class TipoDiGradientDescent(Scene):
    """Writes nine optimizer names one by one in a grid around 'SGD', holds,
    then fades them all out together."""

    def construct(self):
        names = (
            'SGD', 'RMSprop', 'Adam', 'Adadelta', 'Adagrad',
            'Adamax', 'Nadam', 'Ftrl', 'BGD',
        )
        labels = [Tex(name).scale(2).shift(0) for name in names]
        (sgd, rmsprop, adam, adadelta, adagrad,
         adamax, nadam, ftrl, bgd) = labels

        # (label, anchor, direction).  Order matters: Adadelta and Adagrad
        # must be placed before the corner labels that anchor to them.
        layout = (
            (rmsprop, sgd, DOWN * 1.5),
            (adam, sgd, UP * 1.5),
            (adadelta, sgd, LEFT * 1.5),
            (adagrad, sgd, RIGHT * 1.5),
            (adamax, adadelta, UP * 2),
            (nadam, adagrad, UP * 2),
            (ftrl, adagrad, DOWN * 2),
            (bgd, adadelta, DOWN * 2),
        )
        for label, anchor, direction in layout:
            label.next_to(anchor, direction)

        self.wait(1)
        for label in labels:
            self.play(Write(label))
        self.wait(12)

        # Single simultaneous fade, listed last-to-first as in the original.
        self.play(*[FadeOut(label) for label in reversed(labels)])
class Outro(Scene):
    """Closing card: writes 'Masha Danki!', holds briefly, then fades out."""

    def construct(self):
        # Blue-to-green gradient text at twice the default size.
        farewell = PangoText('Masha Danki!', gradient=(BLUE, GREEN))
        farewell.scale(2)

        self.wait(1)
        self.add(farewell)
        self.play(Write(farewell))
        self.wait(3)
        self.play(FadeOut(farewell))
        self.wait(1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment