Skip to content

Instantly share code, notes, and snippets.

@jcboyd
Created April 29, 2024 04:12
Show Gist options
  • Save jcboyd/3eba5c5b4daecaa576118398e273d56f to your computer and use it in GitHub Desktop.
Save jcboyd/3eba5c5b4daecaa576118398e273d56f to your computer and use it in GitHub Desktop.
SVM margin visualisation
import matplotlib.pyplot as plt
import numpy as np
from numpy.random import multivariate_normal
# Seed NumPy's global generator so the sampled points, and therefore both
# saved figures, are identical on every run.
np.random.seed(100)

# Number of points sampled for each of the two classes.
num_samples = 20

# One unit-covariance Gaussian blob per class, stacked row-wise: the first
# num_samples rows are centred on (2, 2), the remaining rows on (-2, -2).
X = np.concatenate([
    multivariate_normal(
        mean=np.array(centre),
        cov=1.0 * np.array([[1, 0], [0, 1]]),
        size=num_samples,
    )
    for centre in ([2, 2], [-2, -2])
])

# Labels aligned with the rows of X: +1 for the first blob, -1 for the second.
y = np.array([1] * num_samples + [-1] * num_samples)

# Axis limits: the per-coordinate range of the data, widened by `delta` on
# each side so no point sits on the frame.
delta = 0.25
mins = X.min(axis=0)
maxs = X.max(axis=0)
xlims, ylims = ([lo - delta, hi + delta] for lo, hi in zip(mins, maxs))
def plot_data(ax, X, xlims, ylims):
    """Label and scale ``ax``, then scatter the two halves of ``X`` on it.

    The first ``X.shape[0] // 2`` rows of ``X`` are drawn in orange and the
    remaining rows in blue; column 0 is used as the horizontal coordinate
    and column 1 as the vertical one.

    Args:
        ax: Axes-like object to draw on. Only ``set_xlabel``, ``set_ylabel``,
            ``set_aspect``, ``set_xlim``, ``set_ylim`` and ``scatter`` are
            called on it.
        X: 2-D array whose rows are points.
        xlims: Limits passed unchanged to ``ax.set_xlim``.
        ylims: Limits passed unchanged to ``ax.set_ylim``.

    Returns:
        None. The function only draws on ``ax``.
    """
    half = X.shape[0] // 2

    ax.set_xlabel('x1')
    ax.set_ylabel('x2')
    ax.set_aspect('equal', adjustable='box')
    ax.set_xlim(xlims)
    ax.set_ylim(ylims)

    # Scatter on the axes that was passed in. Going through the pyplot
    # state machine (plt.scatter) would draw on whichever axes happens to be
    # current, which is only the same object by coincidence of call order.
    ax.scatter(X[:half, 0], X[:half, 1], color='orange', alpha=0.8)
    ax.scatter(X[half:, 0], X[half:, 1], color='blue', alpha=0.8)
def plot_projection(ax, b, m, points_coords, xlims, ylims):
    """Draw the line ``y = m * x + b``, point-to-line projections and a normal arrow.

    The line is drawn once, dashed red, between ``xlims[0]`` and
    ``xlims[-1]``. For every point in ``points_coords`` a dashed black
    segment joins the point to its orthogonal projection onto that line.
    Finally an arrow is drawn from the midpoint of the plotted line along the
    unit normal pointing towards the *last* point, scaled by
    ``|(-1 - b) / (unit_normal . last_point)|``, and labelled with that scale.

    NOTE(review): the indentation of the original listing was lost; this
    assumes the arrow is computed once after the loop (from the last point)
    rather than once per point -- confirm against the intended figure.

    Args:
        ax: Axes-like object; only ``plot``, ``arrow`` and ``text`` are called.
        b: Intercept of the line.
        m: Slope of the line.
        points_coords: Iterable of 2-element points ``(x1, x2)``.
        xlims: Horizontal extent over which the line is drawn.
        ylims: Unused; kept so existing callers keep working.

    Returns:
        None. The function only draws on ``ax``.
    """
    # Two end points are enough to draw a straight line.
    xs = np.linspace(xlims[0], xlims[-1], 2)
    ys = m * xs + b

    # Direction of the line, and the point (0, b) on it used as the origin
    # for the projection.
    line_vec = np.array([xs[-1] - xs[0], ys[-1] - ys[0]])
    b_vec = np.array([0, b])

    # The line does not depend on the points, so draw it once rather than
    # once per point.
    ax.plot(xs, ys, 'red', linestyle='dashed')

    coords = point_vec = proj = None
    for coords in points_coords:
        coords = np.asarray(coords, dtype=float)
        point_vec = coords - b_vec
        # Orthogonal projection of the point onto the line direction.
        proj = (point_vec.dot(line_vec) / line_vec.dot(line_vec)) * line_vec
        proj_coords = b_vec + proj
        ax.plot([coords[0], proj_coords[0]],
                [coords[1], proj_coords[1]],
                linestyle='dashed', color='black', alpha=0.8)

    if point_vec is None:
        # No points were given: there is nothing to derive a normal from.
        return

    # Normal component of the last point relative to the line.
    w = point_vec - proj
    length = np.linalg.norm(w)
    if length == 0:
        # The last point lies on the line, so the normal direction is
        # undefined; skip the arrow instead of drawing NaNs.
        return
    w /= length

    denom = w.dot(coords)
    if denom == 0:
        # The scale below would be infinite; skip the arrow.
        return
    mag = np.abs((-1 - b) / denom)
    w *= mag

    # Anchor the arrow at the midpoint of the plotted line.
    mid_x, mid_y = np.mean(xs), np.mean(ys)
    ax.arrow(mid_x, mid_y, *w, width=0.05, linewidth=0.5, color='purple')
    ax.text(mid_x + w[0] + 0.1, mid_y + w[1], '$||w||_2 = %.02f$' % mag,
            ha='left', va='top')
# Render both illustrations. Each entry gives the output file, the intercept
# and slope of the candidate line, and the row indices in X of the orange and
# blue points that are projected onto it.
for out_path, intercept, slope, idx1, idx2 in (
    ('fig1.png', -0.4, -1, 19, 37),
    ('fig2.png', -0.1, -0.05, 18, 24),
):
    fig, ax = plt.subplots()
    plot_data(ax, X, xlims, ylims)
    plot_projection(ax, intercept, slope, [X[idx1], X[idx2]], xlims, ylims)
    fig.savefig(out_path, bbox_inches='tight')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment