Skip to content

Instantly share code, notes, and snippets.

@NWChen
Created December 30, 2018 03:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save NWChen/b28d49aa5381006bcf49e936a8b99843 to your computer and use it in GitHub Desktop.
Save NWChen/b28d49aa5381006bcf49e936a8b99843 to your computer and use it in GitHub Desktop.
Python Rosenblatt Perceptron
import matplotlib.pyplot as plt
import numpy as np
# Training set: one 2-D sample per row; the trailing comment on each row
# shows the class label stored at the same position in Y.
X = np.array((
    (0.8, 0.4),  # class 0
    (0.3, 0.1),  # class 0
    (0.8, 0.8),  # class 1
    (0.4, 0.6),  # class 1
    (0.6, 0.8),  # class 1
    (0.4, 0.2),  # class 0
    (0.4, 0.5),  # class 1
))

# Binary class labels (0 or 1), aligned row-for-row with X.
Y = np.array((0, 0, 1, 1, 1, 0, 1))
def plot_points(X, Y, ax):
    """Draw every sample in ``X`` on ``ax``, styled by its label in ``Y``.

    Samples whose label is ``<= 0`` are drawn as a red ``_`` marker; all
    other samples are drawn as a blue ``+`` marker.

    Args:
        X: Sequence of samples; the first two entries of each sample are
            used as the horizontal and vertical coordinates.
        Y: Sequence of labels, indexed in step with ``X``.
        ax: Object exposing a matplotlib-style ``scatter`` method.
    """
    for idx, point in enumerate(X):
        # Choose the glyph and colour together so they can never disagree.
        if Y[idx] <= 0:
            glyph, colour = '_', 'r'
        else:
            glyph, colour = '+', 'b'
        ax.scatter(
            point[0],
            point[1],
            s=120,
            marker=glyph,
            linewidths=2,
            c=colour,
        )
# Create a square figure (figsize 6x6) holding a single axes, then draw the
# labelled training samples on that axes.
fig, ax = plt.subplots(figsize=(6, 6))
plot_points(X, Y, ax)
def update(X, Y, eta=0.1, epochs=100):
    """Train a perceptron with the Rosenblatt update rule.

    The weight vector starts at zero. For every sample the current
    prediction is ``1`` when ``w[:-1] . x + w[-1] >= 0`` and ``0`` otherwise;
    on a wrong prediction the weights move by ``eta * (y - pred) * x`` and
    the bias by ``eta * (y - pred)``.

    Args:
        X: Array-like of shape ``(n_samples, n_features)``.
        Y: Array-like of ``n_samples`` labels, each ``0`` or ``1``.
        eta: Learning rate (step size). Because the weights start at zero
            and the decision threshold is zero, ``eta`` only rescales the
            returned vector (up to floating-point rounding); it does not
            change which samples are misclassified along the way.
        epochs: Maximum number of full passes over the data.

    Returns:
        A 1-D array of length ``n_features + 1``: the feature weights
        followed by the bias term in the last position.
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y)

    # One weight per feature plus a trailing bias term.
    w = np.zeros(X.shape[1] + 1)

    for _ in range(epochs):
        mistakes = 0
        for x, y in zip(X, Y):
            pred = 1 if np.dot(w[:-1], x) + w[-1] >= 0.0 else 0
            error = y - pred
            if error:
                w[:-1] += eta * error * x
                w[-1] += eta * error
                mistakes += 1
        # A mistake-free pass leaves w untouched, so every later pass would
        # be identical; stopping here returns exactly the same weights.
        if not mistakes:
            break
    return w
def predict(w, x):
    """Classify ``x`` with perceptron weights ``w``.

    Uses the same ``>= 0`` decision rule that ``update`` applies while
    training, so a point lying exactly on the hyperplane is labelled ``1``
    both during training and at prediction time.

    Args:
        w: 1-D array-like of length ``n_features + 1``: the feature weights
            followed by the bias term in the last position.
        x: A single sample of length ``n_features``, or a 2-D array-like of
            shape ``(n_samples, n_features)`` to classify a batch.

    Returns:
        A NumPy integer array holding ``1`` where
        ``x . w[:-1] + w[-1] >= 0`` and ``0`` elsewhere: zero-dimensional
        for a single sample, one entry per row for a batch.
    """
    w = np.asarray(w, dtype=float)
    # Putting x first lets the same expression handle one sample or a batch.
    activation = np.dot(x, w[:-1]) + w[-1]
    return np.where(activation >= 0.0, 1, 0)
# Fit the perceptron on the training set.
w = update(X, Y)

# Classify a 10x10 lattice of points (0.0, 0.1, ..., 0.9 on each axis) and
# draw each one with the marker and colour of its predicted class.
for a in range(10):
    i = a / 10
    for b in range(10):
        j = b / 10
        p = predict(w, [i, j])
        if p <= 0:
            marker, colour = '_', 'r'
        else:
            marker, colour = '+', 'b'
        plt.scatter(i, j, s=120, marker=marker, linewidths=2, c=colour)
@NWChen
Copy link
Author

NWChen commented Dec 30, 2018

(X, y):

hyperplane:

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment