Skip to content

Instantly share code, notes, and snippets.

@nathanielbd
Created January 8, 2021 23:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nathanielbd/9f61a02fd76b244cb180db171998c814 to your computer and use it in GitHub Desktop.
Save nathanielbd/9f61a02fd76b244cb180db171998c814 to your computer and use it in GitHub Desktop.
Animate the perceptron algorithm with Manim
#!/usr/bin/env python
# See the result: https://nathanielbd.github.io/Perceptron.mp4
from manimlib.imports import *
class Perceptron(GraphScene):
    """Animate the perceptron learning algorithm on a 2-D data set.

    Every row of the CSV is drawn as a dot (BLUE where ``y == 1``, RED
    otherwise). The scene then sweeps over the rows, drawing each point as a
    vector from the origin. When a point is misclassified, the weight vector
    ``w`` (YELLOW) is updated by ``w += y_i * x_i`` and the decision boundary
    ``w . x = 0`` is redrawn. Sweeps repeat until one sweep makes no mistakes
    (or ``max_epochs`` sweeps have run).
    """

    CONFIG = {
        "x_min": -1.5,
        "x_max": 1.5,
        "y_min": -1.5,
        "y_max": 1.5,
        "x_tick_frequency": 0.1,
        "y_tick_frequency": 0.1,
        "graph_origin": ORIGIN,
        "x_axis_label": "$X_1$",
        "y_axis_label": "$X_2$",
        # CSV read with pandas; columns X1, X2 (features) and y (label) are used.
        "data_file": "../nn-from-scratch/data1.csv",
        # Only the first `max_points` rows are visited per sweep (keeps the
        # video short); convergence is judged on those rows only. Clamped to
        # the number of rows. Set to None to train on every row.
        "max_points": 10,
        # Cap on the number of sweeps. None means "until a sweep makes no
        # mistakes", which never ends if the visited points are not linearly
        # separable.
        "max_epochs": None,
    }

    def construct(self):
        """Plot the data set, then animate perceptron training step by step."""
        self.setup_axes(animate=True)
        import pandas as pd
        import numpy as np

        data = pd.read_csv(self.data_file)
        dots = VGroup(*[
            Dot(
                point=self.coords_to_point(datum['X1'], datum['X2']),
                radius=0.05,
                color=BLUE if datum['y'] == 1 else RED,
            )
            for _, datum in data.iterrows()
        ])
        self.play(Write(dots))

        # Feature matrix with one row per sample (columns X1, X2) and labels.
        X = np.vstack((np.array(data['X1']), np.array(data['X2']))).T
        y = np.array(data['y'])
        n = y.shape[0]
        # Never index past the end of the data set.
        num_points = n if self.max_points is None else min(n, self.max_points)

        # Scene position of the axes' (0, 0); equals ORIGIN for the default
        # graph_origin, and keeps the vectors anchored if it is moved.
        origin = self.coords_to_point(0, 0)

        steps = 0
        mistakes = 1  # non-zero so that the first sweep runs
        epochs = 0
        w = np.array([1.0, -1.0])  # initial weight vector

        div = self._decision_boundary(w)
        steps_label = TextMobject("step: 0").shift(UP * 2).shift(LEFT * 5)
        mistakes_label = TextMobject("mistakes: 0").shift(LEFT * 5).shift(UP)
        self.play(
            Write(div),
            Write(steps_label),
            Write(mistakes_label),
        )

        while mistakes > 0 and (self.max_epochs is None or epochs < self.max_epochs):
            epochs += 1
            mistakes = 0
            for i in range(num_points):
                steps += 1
                self.play(
                    Transform(
                        steps_label,
                        TextMobject(f"step: {steps}").shift(steps_label.get_center()),
                    ),
                    run_time=0.2,
                )

                # Mistake test. NOTE(review): the textbook perceptron uses
                # `<= 0` (points on the boundary count as mistakes); `< 0` is
                # kept because with {0, 1} labels `<= 0` would flag every
                # y == 0 point forever. The update below assumes labels are
                # +/-1 -- confirm the data's label encoding.
                misclassified = np.inner(w, X[i]) * y[i] < 0
                if misclassified:
                    mistakes += 1

                dp = Line(
                    origin,
                    self.coords_to_point(X[i][0], X[i][1]),
                    color=BLUE,
                )
                # The counter is refreshed after the test so it is never one
                # step behind (the original showed the pre-test count).
                self.play(
                    Transform(
                        mistakes_label,
                        TextMobject(f"mistakes: {mistakes}").shift(mistakes_label.get_center()),
                    ),
                    Write(dp),
                )

                if misclassified:
                    perp = Line(origin, self.coords_to_point(w[0], w[1]), color=YELLOW)
                    self.play(Write(perp))
                    # Slide the data vector so its tail sits at the tip of w.
                    self.play(ApplyMethod(
                        dp.shift,
                        self.coords_to_point(w[0], w[1]) - origin,
                    ))
                    w += y[i] * X[i]
                    self.play(Transform(
                        perp,
                        Line(origin, self.coords_to_point(w[0], w[1]), color=YELLOW),
                    ))
                    self.play(Transform(div, self._decision_boundary(w)))
                    self.play(FadeOutAndShiftDown(perp))

                self.play(FadeOutAndShiftDown(dp))

    def _decision_boundary(self, w):
        """Return a YELLOW mobject tracing the line w[0]*x1 + w[1]*x2 = 0.

        Normally drawn with ``get_graph`` as x2 = -(w[0] / w[1]) * x1. When
        w[1] is numerically zero that form divides by zero, so a vertical
        segment along x1 = 0 spanning the y-axis range is returned instead
        (the true boundary whenever w[0] != 0; w == 0 has no boundary).
        """
        import numpy as np

        if np.isclose(w[1], 0.0):
            return Line(
                self.coords_to_point(0, self.y_min),
                self.coords_to_point(0, self.y_max),
                color=YELLOW,
            )
        # Bind the slope now rather than closing over the mutable array `w`.
        slope = -w[0] / w[1]
        return self.get_graph(lambda x: slope * x, YELLOW)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment