Skip to content

Instantly share code, notes, and snippets.

@trappmartin
Created December 7, 2023 23:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save trappmartin/1ceffaef7091e9445cc35e61167c0db6 to your computer and use it in GitHub Desktop.
Save trappmartin/1ceffaef7091e9445cc35e61167c0db6 to your computer and use it in GitHub Desktop.
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
# Fit an RBF-kernel SVM on the two-moons toy dataset and draw its decision
# regions, shading every grid pixel by the predicted class probability.

# Scatter-point colours: class 0 -> red, class 1 -> blue, so the points match
# the 'Reds_r' / 'Blues_r' colormaps used for the two regions below.
cm_bright = ListedColormap(["#FF0000", "#0000FF"])

classifier = SVC(gamma=2, C=1, random_state=42, probability=True)
dataset = make_moons(noise=0.3, random_state=0)
X, y = dataset
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.4, random_state=42
)

# Plot window: the data's bounding box padded by 0.5 on every side.
x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5

# Standardise the features before the SVM; the pipeline applies the same
# scaling to everything passed to predict / predict_proba.
clf = make_pipeline(StandardScaler(), classifier)
clf.fit(X_train, y_train)

# Evaluate the classifier on a dense regular grid covering the plot window.
grid_resolution = 300
xx, yy = np.meshgrid(
    np.linspace(x_min, x_max, grid_resolution),
    np.linspace(y_min, y_max, grid_resolution),
)
X_grid = np.column_stack((xx.ravel(), yy.ravel()))
yhat = clf.predict(X_grid)
conf = clf.predict_proba(X_grid)

# Bring the flat per-point predictions back onto the 2-D grid. The masks must
# have the image's shape: a flat (n*n,) mask neither broadcasts against the
# (n, n) probability image nor is accepted by imshow as a per-pixel alpha.
yhat_grid = yhat.reshape(xx.shape)
is_class0 = yhat_grid == 0
is_class1 = yhat_grid == 1

# Inside each predicted region, colour by the probability of the *other*
# class: with the reversed ('_r') colormaps a low value is drawn dark, so
# confident predictions are saturated and the decision boundary fades out.
color1 = conf[:, 1].reshape(xx.shape) * is_class0
color2 = conf[:, 0].reshape(xx.shape) * is_class1

# Per-pixel alpha: each layer is visible only inside its own region.
alpha1 = 0.8 * is_class0
alpha2 = 0.8 * is_class1

fig, ax = plt.subplots()
ext = [x_min, x_max, y_min, y_max]
ax.imshow(color1, cmap="Reds_r", origin="lower", extent=ext, alpha=alpha1)
ax.imshow(color2, cmap="Blues_r", origin="lower", extent=ext, alpha=alpha2)

# Plot the training points
ax.scatter(
    X_train[:, 0],
    X_train[:, 1],
    c=y_train,
    cmap=cm_bright,
    edgecolors="k",
    alpha=0.5,
)
# Plot the testing points
ax.scatter(
    X_test[:, 0],
    X_test[:, 1],
    c=y_test,
    cmap=cm_bright,
    edgecolors="w",
)

ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
ax.set_xticks(())
ax.set_yticks(())
ax.axis("off")
fig.tight_layout()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment