# Manim Gradient Descent Video Intuition
# Author: Francis Laclé
# Video: https://www.youtube.com/watch?v=1cCS6uK_NH8
# Github: https://github.com/flacle
# Date: 29 Oct, 2020

from manim import *
import math

# ``np`` is used below; import it explicitly instead of relying on the
# star import above to leak it into this namespace.
import numpy as np


# LaTeX fragments shared by the scenes that show the update rule.
_PARTIAL_J = (
    r'\frac{\partial}{\partial\theta_j}J\left(\theta_0, \theta_1\right)'
)
_UPDATE_RULE_PARTS = (
    r'$\theta_j :=$',
    r'$\theta_j - \alpha ' + _PARTIAL_J + '$',
)
_SIMULTANEOUS_NOTE = r'(update parew pa $j=0$ y $j=1$ !)'
_REPEAT_NOTE = "ripiti te ora e converge"


def _min_j(arguments, suffix=''):
    """Return the LaTeX source for ``min J(<arguments>)`` plus *suffix*.

    The parenthesis opened with ``\\left(`` is closed with ``\\right)`` so
    the delimiters are balanced.
    """
    return r'$\min{J\left(' + arguments + r'\right)}' + suffix + '$'


def _build_parabola_with_tangent(scene):
    """Create the parabola, tracked dot and tangent line used by two scenes.

    *scene* must be a graph scene whose axes are already set up, because
    ``get_graph`` and ``coords_to_point`` are called on it.

    Returns a tuple ``(parabola, tracker, dot, tangent)``:

    * ``parabola`` -- graph of ``(x - 5) ** 2`` for x in [2, 8];
    * ``tracker``  -- ``ValueTracker`` holding the dot's x coordinate,
      left at 3 when this function returns;
    * ``dot``      -- white dot redrawn every frame at the tracked x;
    * ``tangent``  -- red line of length 2 redrawn every frame at the dot.
    """
    def curve(x):
        return (x - 5) ** 2

    def slope(x):
        return 2 * (x - 5)

    parabola = scene.get_graph(
        curve, x_min=2, x_max=8, color=YELLOW, stroke_opacity=0.5)
    tracker = ValueTracker(0)

    def make_dot():
        x = tracker.get_value()
        return Dot(color=WHITE).move_to(
            scene.coords_to_point(x, curve(x)))

    # Built while the tracker still holds 0, exactly as in the original
    # construction order; the redraw updater repositions it afterwards.
    dot = always_redraw(make_dot)

    def make_tangent():
        # The TangentLine only supplies a line of length 2; it is then
        # moved onto the curve and rotated to the wanted slope.
        tangent = TangentLine(
            dot, 1.0, length=2, stroke_opacity=1, color=RED)
        x = tracker.get_value()
        tangent.move_to(scene.coords_to_point(x, curve(x)))
        # Empirical correction kept from the original author ("seems to be
        # some rounding error? dx(x) is correct"): the slope is sampled at
        # x + inter, with inter interpolated from 0.6 at x=3 to -0.6 at x=7.
        # NOTE(review): this may be compensating for the x and y axes having
        # different on-screen unit lengths -- confirm before changing.
        inter = match_interpolate(0.6, -0.6, 3, 7, x)
        tangent.rotate(math.atan2(-1, slope(x + inter)))
        return tangent

    tracker.set_value(3)
    tangent = always_redraw(make_tangent)
    return parabola, tracker, dot, tangent


class Intro(Scene):
    """Title card: writes 'Gradient Descent', holds it, then fades it out."""

    def construct(self):
        title = PangoText('Gradient Descent', gradient=(BLUE, GREEN)).scale(2)
        self.wait(1)
        self.add(title)
        self.play(Write(title))
        self.wait(11)
        self.play(FadeOut(title))
        self.wait(1)


class ThreeDSurface(ParametricSurface):
    """Saddle surface z = x**2 - y**2 over u, v in [-2, 2]."""

    def __init__(self, **kwargs):
        """Build the surface.

        Keyword arguments are forwarded to ``ParametricSurface`` and take
        precedence over the defaults below.
        """
        surface_kwargs = {
            "u_min": -2,
            "u_max": 2,
            "v_min": -2,
            "v_max": 2,
            "checkerboard_colors": [BLUE_D],
        }
        surface_kwargs.update(kwargs)
        ParametricSurface.__init__(self, self.func, **surface_kwargs)

    def func(self, x, y):
        """Map the parameters (x, y) to the 3D point (x, y, x**2 - y**2)."""
        return np.array([x, y, x**2 - y**2])


class ConYPakicoTraha(ThreeDScene):
    """Show the saddle surface under a rotating camera, then reposition it."""

    def construct(self):
        axes = ThreeDAxes(animate=True)
        surface = ThreeDSurface()
        self.set_camera_orientation(
            phi=75 * DEGREES, theta=30 * DEGREES, distance=30)
        self.begin_ambient_camera_rotation(rate=0.1)
        self.wait(1)
        self.play(ShowCreation(axes))
        self.play(ShowCreation(surface))
        self.wait(4)
        self.move_camera(phi=0.4 * np.pi, theta=-0.45 * np.pi)
        self.wait(4)
        self.stop_ambient_camera_rotation()
        self.play(FadeOut(surface))
        self.play(FadeOut(axes))


class KostFunctie(Scene):
    """Write J(...) and transform it into min J(...)."""

    def construct(self):
        cost = Tex(r'$J\left(\cdot\cdot\cdot\right)$').scale(3)
        min_cost = Tex(r'$\min{J\left(\cdot\cdot\cdot\right)}$').scale(3)
        self.wait(1)
        self.play(Write(cost))
        self.wait(4)
        self.play(ReplacementTransform(cost, min_cost))
        self.wait(6)
        self.play(FadeOut(min_cost))


class OnderzoekAruba(GraphScene):
    """Plot ten data points, a straight line through them and the errors."""

    CONFIG = {
        "y_axis_label": r"Poblacion di Aruba",
        "x_axis_label": "Aña",
        "y_max": 7,
        "y_min": 0,
        "y_tick_frequency": 1,
        "x_max": 9,
        "x_min": 0,
        "axes_color": BLUE,
    }

    def construct(self):
        data = [
            1, 1.243735763, 1.673120729, 2.258542141, 2.940774487,
            3.641230068, 4.287015945, 4.891799544, 5.454441913, 6,
        ]

        def trend(x):
            # Straight line from (0, 1) to (9, 6), i.e. through the first
            # and last data points.
            return (5 / 9 * x) + 1

        self.setup_axes()
        line = self.get_graph(
            trend, color=RED, x_min=0, x_max=9, label="$J(x)$")

        # Fade the data points in one at a time.
        dot_collection = VGroup()
        for time, dat in enumerate(data):
            dot = Dot(color=YELLOW).move_to(self.coords_to_point(time, dat))
            dot_collection.add(dot)
            self.play(FadeIn(dot), rate_func=rush_into)
        self.wait(1)

        self.play(ShowCreation(line), run_time=2)
        self.wait(1)

        # Draw one vertical segment per point, from the line to the dot.
        error_collection = VGroup()
        for time, dot in enumerate(dot_collection):
            error = Line(
                self.coords_to_point(time, trend(time)),
                dot.get_center(),
                color=GREEN)
            error_collection.add(error)
            self.play(ShowCreation(error), run_time=1)
        self.wait(2)

        self.play(
            FadeOut(error_collection),
            FadeOut(dot_collection),
            FadeOut(line),
            FadeOut(self.axes),
            FadeOut(self.x_axis_labels),
            FadeOut(self.y_axis_labels))

    def setup_axes(self):
        """Set up the axes and animate custom text tick labels onto them.

        Stores the label groups on ``self.x_axis_labels`` and
        ``self.y_axis_labels`` so ``construct`` can fade them out.
        """
        GraphScene.setup_axes(self)
        self.x_axis.label_direction = UP
        self.y_axis.label_direction = UP

        # (position on the axis, text shown there)
        values_x = [(i, "'{:02d}".format(9 + i)) for i in range(10)]
        values_y = [(i, "{}.000".format(100 + i)) for i in range(7)]

        self.x_axis_labels = VGroup()
        self.y_axis_labels = VGroup()
        for x_val, x_tex in values_x:
            label = PangoText(x_tex).scale(0.6)
            label.next_to(self.coords_to_point(x_val, 0), DOWN)
            self.x_axis_labels.add(label)
        for y_val, y_tex in values_y:
            label = PangoText(y_tex).scale(0.6)
            label.next_to(self.coords_to_point(0, y_val), LEFT)
            self.y_axis_labels.add(label)

        self.play(
            Write(self.x_axis_labels),
            Write(self.x_axis),
            Write(self.y_axis_labels),
            Write(self.y_axis),
        )


class Hypothese(Scene):
    """Show h_theta(x) = theta_0 + theta_1 x, highlight parts, plug in x=3."""

    def construct(self):
        # Part index:  0                1    2           3    4             5
        h = MathTex(
            r"h_\theta\left(x\right)", "=", r"\theta_0", "+",
            r"{\theta_1}", "x").scale(2)
        h3 = MathTex(
            r"h_\theta\left(3\right)", "=", r"\theta_0", "+",
            r"{\theta_1}", "3").scale(2)
        h9 = MathTex(
            "9", "=", r"\theta_0", "+", r"{\theta_1}", "3").scale(2)

        self.wait(1)
        self.play(Write(h))
        self.wait(6)

        box_theta_1 = SurroundingRectangle(h[4], buff=.1)
        box_theta_0 = SurroundingRectangle(h[2], buff=.1)
        box_left_side = SurroundingRectangle(h[0], buff=.1)
        self.play(ShowCreation(box_theta_1))
        self.wait(2)
        self.play(ReplacementTransform(box_theta_1, box_theta_0))
        self.wait(1)
        self.play(ReplacementTransform(box_theta_0, box_left_side))
        self.wait(3)
        self.play(FadeOut(box_left_side))
        self.wait(1)

        self.play(ReplacementTransform(h, h3))
        self.wait(3)
        self.play(ReplacementTransform(h3, h9))
        self.wait(1)
        self.play(FadeOut(h9))


class CombinacionJmin(Scene):
    """Cycle min J(theta_0, theta_1) -> 9 through several number pairs."""

    def construct(self):
        argument_lists = [
            r'\theta_0, \theta_1', '0, 7', '2, 5', '-3, 1', '-2, 2', '-1, 3',
        ]
        steps = [
            Tex(_min_j(arguments, r'\to{9}')).scale(2)
            for arguments in argument_lists
        ]

        self.wait(1)
        self.play(Write(steps[0]))
        self.wait(3)
        for current, following in zip(steps, steps[1:]):
            self.play(ReplacementTransform(current, following))
        self.wait(8)
        self.play(FadeOut(steps[-1]))


class SomDifferencia(Scene):
    """Show the cost formula and box the sum, the h term and the y term."""

    def construct(self):
        # Parts are indexed 0..8 in the order given here.
        Jmin = MathTex(
            r"J(\theta_{0}, \theta_{1})",   # 0
            "=",                            # 1
            r"\frac{1}{2m}",                # 2
            r"\sum\limits_{i=1}^m",         # 3
            "(",                            # 4
            r"h_{\theta}(x^{(i)})",         # 5
            "-",                            # 6
            "y^{(i)}",                      # 7
            ")^2",                          # 8
        ).scale(1)
        box_sum = SurroundingRectangle(Jmin[3], buff=.1)
        box_h = SurroundingRectangle(Jmin[5], buff=.1)
        box_y = SurroundingRectangle(Jmin[7], buff=.1)

        self.wait(1)
        self.play(Write(Jmin))
        self.wait(5)
        self.play(ShowCreation(box_sum))
        self.wait(1)
        self.play(ReplacementTransform(box_sum, box_h))
        self.play(ReplacementTransform(box_h, box_y))
        self.wait(2)
        self.play(FadeOut(box_y))
        self.wait(9)
        self.play(FadeOut(Jmin))


class GradientDescentDilanti(MovingCameraScene):
    """Zoom on the title, then grow min J(...) from two to six thetas."""

    def construct(self):
        gd = PangoText('Gradient Descent', gradient=(BLUE, GREEN)).scale(2)
        self.wait(3)
        self.play(Write(gd))
        self.wait(1)
        self.play(self.camera_frame.set_width, gd.get_width() * 1.2)
        self.wait(2)
        self.play(FadeOut(gd))
        self.wait(1)

        def thetas(count):
            # "\theta_0, \theta_1, ..." with *count* entries.
            return ', '.join(r'\theta_%d' % i for i in range(count))

        # Two parameters up to six, then back to two.
        growing = [Tex(_min_j(thetas(n))).scale(1) for n in range(2, 7)]
        back_to_two = Tex(_min_j(thetas(2))).scale(1)

        self.play(Write(growing[0]))
        self.wait(1)
        for current, following in zip(growing, growing[1:]):
            self.play(ReplacementTransform(current, following))
        self.wait(1)
        self.play(ReplacementTransform(growing[-1], back_to_two))
        self.wait(3)
        self.play(FadeOut(back_to_two))


class KedaRipitiYGana(MovingCameraScene):
    """Write the theta_j update rule with a brace note, then zoom on it."""

    def construct(self):
        thetaJ = Tex(*_UPDATE_RULE_PARTS).scale(1)
        simul = Tex(_SIMULTANEOUS_NOTE).scale(0.75).move_to(2 * DOWN)
        self.wait(1)
        self.play(Write(thetaJ), Write(simul))

        # Brace over the right-hand side (second Tex part) of the rule.
        brace = Brace(thetaJ[1], UP, buff=SMALL_BUFF)
        brace_text = brace.get_text(_REPEAT_NOTE)
        self.play(GrowFromCenter(brace), FadeIn(brace_text))
        self.wait(9)
        self.play(FadeOut(brace_text), FadeOut(brace), FadeOut(simul))
        self.play(self.camera_frame.set_width, thetaJ.get_width() * 1.6)
        self.play(FadeOut(thetaJ))


class CordaCalculus(GraphScene):
    """Show dy/dx, then slide a tangent line along the parabola (x-5)**2."""

    CONFIG = {
        "y_axis_label": r"$y$",
        "x_axis_label": r"$x$",
        "y_max": 10,
        "y_min": 0,
        "y_tick_frequency": 1,
        "x_max": 10,
        "x_min": 0,
        "axes_color": BLUE,
    }

    def construct(self):
        self.wait(1)
        deriv = Tex(r'$\frac{dy}{dx}$').scale(3)
        self.play(Write(deriv))
        self.wait(5)
        self.play(FadeOut(deriv))

        self.setup_axes(animate=True)
        parabola, vt, md, line = _build_parabola_with_tangent(self)
        self.play(ShowCreation(parabola), FadeIn(md), FadeIn(line))
        self.wait(1)
        self.play(vt.set_value, 7, rate_func=there_and_back, run_time=4)
        self.wait(1)
        self.play(vt.set_value, 5, rate_func=slow_into, run_time=4)
        self.wait(6)
        self.play(vt.set_value, 3, rate_func=slow_into, run_time=1)
        self.play(vt.set_value, 5, rate_func=slow_into, run_time=6)
        self.wait(1)
        self.play(
            FadeOut(parabola), FadeOut(md), FadeOut(line),
            FadeOut(self.axes))


class MinTekenMeiMei(MovingCameraScene):
    """Transform the update rule into the two sign cases and back."""

    def construct(self):
        thetaJ = Tex(*_UPDATE_RULE_PARTS).scale(1)
        simul = Tex(_SIMULTANEOUS_NOTE).scale(0.75).move_to(2 * DOWN)
        brace = Brace(thetaJ[1], UP, buff=SMALL_BUFF)
        # Box around the right-hand side (second Tex part) of the rule.
        framebox = SurroundingRectangle(thetaJ[1], buff=.1)
        brace_text = brace.get_text(_REPEAT_NOTE)

        self.wait(1)
        self.play(
            Write(thetaJ),
            Write(simul),
            GrowFromCenter(brace),
            FadeIn(brace_text))
        self.play(FadeIn(framebox))
        self.wait(3)

        thetaPlus = Tex('$' + _PARTIAL_J + r' > 0 \to$ descent').scale(1)
        self.play(
            FadeOut(brace),
            FadeOut(brace_text),
            FadeOut(simul),
            FadeOut(framebox),
            ReplacementTransform(thetaJ, thetaPlus))
        self.wait(4)

        thetaMin = Tex('$' + _PARTIAL_J + r' < 0 \to$ ascent').scale(1)
        self.play(ReplacementTransform(thetaPlus, thetaMin))
        self.wait(4)

        thetaJ2 = Tex(*_UPDATE_RULE_PARTS).scale(1)
        self.play(ReplacementTransform(thetaMin, thetaJ2))
        self.wait(4)
        self.play(FadeOut(thetaJ2))


class Alpha(GraphScene):
    """Move the dot along the parabola while the alpha label changes value."""

    CONFIG = {
        "y_axis_label": r"$y$",
        "x_axis_label": r"$x$",
        "y_max": 10,
        "y_min": 0,
        "y_tick_frequency": 1,
        "x_max": 10,
        "x_min": 0,
        "axes_color": BLUE,
    }

    def construct(self):
        self.wait(1)
        alpha = Tex(r'$\alpha$').scale(3)
        self.play(Write(alpha))
        self.play(ApplyMethod(alpha.shift, (UP + RIGHT) * PI))

        self.setup_axes(animate=True)
        parabola, vt, md, line = _build_parabola_with_tangent(self)
        self.play(
            ShowCreation(parabola),
            FadeIn(md),
            FadeIn(line),
            ApplyMethod(alpha.scale, (1 / 2)))

        alpha2 = Tex(r'$\alpha = 2.0$')
        self.play(ReplacementTransform(alpha, alpha2), run_time=0.5)
        self.play(vt.set_value, 6, rate_func=there_and_back, run_time=2)

        alpha3 = Tex(r'$\alpha = 1.1$')
        self.play(ReplacementTransform(alpha2, alpha3), run_time=0.5)
        self.play(vt.set_value, 4, rate_func=there_and_back, run_time=2)
        self.play(vt.set_value, 5, rate_func=slow_into, run_time=12)

        alpha4 = Tex(r'$\alpha = 0.05$')
        self.play(ReplacementTransform(alpha3, alpha4), run_time=0.5)
        self.play(vt.set_value, 4, rate_func=there_and_back, run_time=6)

        alpha5 = Tex(r'$\alpha = 2.2$')
        self.play(ReplacementTransform(alpha4, alpha5), run_time=0.5)
        self.play(vt.set_value, 3, rate_func=rush_into, run_time=2)
        self.play(vt.set_value, 7, rate_func=rush_into, run_time=3)
        self.wait(3)

        self.play(
            FadeOut(parabola),
            FadeOut(md),
            FadeOut(line),
            FadeOut(self.axes),
            FadeOut(alpha5))


class SaddlePoint(ThreeDScene):
    """Show the saddle surface under a slowly rotating camera for 50 s."""

    def construct(self):
        axes = ThreeDAxes(animate=True)
        surface = ThreeDSurface()
        self.set_camera_orientation(
            phi=75 * DEGREES, theta=30 * DEGREES, distance=30)
        self.begin_ambient_camera_rotation(rate=0.1)
        self.wait(1)
        self.play(ShowCreation(axes))
        self.play(ShowCreation(surface))
        self.wait(50)
        self.play(FadeOut(surface))
        self.play(FadeOut(axes))


class TipoDiGradientDescent(Scene):
    """Write nine optimizer names arranged around a central 'SGD' label."""

    def construct(self):
        names = [
            'SGD', 'RMSprop', 'Adam', 'Adadelta', 'Adagrad',
            'Adamax', 'Nadam', 'Ftrl', 'BGD',
        ]
        labels = [Tex(name).scale(2) for name in names]
        (sgd, rmsprop, adam, adadelta, adagrad,
         adamax, nadam, ftrl, bgd) = labels

        # Cross around the centre label, then the four corners.
        rmsprop.next_to(sgd, DOWN * 1.5)
        adam.next_to(sgd, UP * 1.5)
        adadelta.next_to(sgd, LEFT * 1.5)
        adagrad.next_to(sgd, RIGHT * 1.5)
        adamax.next_to(adadelta, UP * 2)
        nadam.next_to(adagrad, UP * 2)
        ftrl.next_to(adagrad, DOWN * 2)
        bgd.next_to(adadelta, DOWN * 2)

        self.wait(1)
        for label in labels:
            self.play(Write(label))
        self.wait(12)
        self.play(*[FadeOut(label) for label in reversed(labels)])


class Outro(Scene):
    """Closing card: writes 'Masha Danki!' and fades it out."""

    def construct(self):
        closing = PangoText('Masha Danki!', gradient=(BLUE, GREEN)).scale(2)
        self.wait(1)
        self.add(closing)
        self.play(Write(closing))
        self.wait(3)
        self.play(FadeOut(closing))
        self.wait(1)