Skip to content

Instantly share code, notes, and snippets.

@ajtulloch
Created November 26, 2013 09:10
Show Gist options
  • Save ajtulloch/7655467 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
import svmpy
import logging
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import itertools
import argh
def example(num_samples=10, num_features=2, grid_size=20, filename="svm.pdf"):
    """Fit an SVM to randomly generated, labelled points and plot the result.

    Draws ``num_samples * num_features`` standard-normal values, arranges
    them as a ``num_samples`` x ``num_features`` matrix, and labels each row
    +1.0 when its entries sum to a positive number and -1.0 otherwise.  The
    predictor returned by ``svmpy.SVMTrainer.train`` is then passed to
    ``plot`` together with the data.

    Args:
        num_samples: Number of random points (matrix rows) to generate.
        num_features: Number of values per point (matrix columns).  ``plot``
            reads only columns 0 and 1 and queries the predictor with
            two-value points.
        grid_size: Grid points per axis, forwarded to ``plot``.
        filename: Output path, forwarded to ``plot``.
    """
    raw_values = np.random.normal(size=num_samples * num_features)
    samples = np.matrix(raw_values.reshape(num_samples, num_features))

    # Turn the boolean "row sum is positive" column into -1.0 / +1.0 labels.
    row_is_positive = samples.sum(axis=1) > 0
    labels = 2 * row_is_positive - 1.0

    # NOTE(review): 0.1 is presumably the trainer's regularisation constant;
    # confirm against svmpy.SVMTrainer.
    kernel = svmpy.Kernel.linear()
    trainer = svmpy.SVMTrainer(kernel, 0.1)
    predictor = trainer.train(samples, labels)

    plot(predictor, samples, labels, grid_size, filename)
def plot(predictor, X, y, grid_size, filename):
    """Save a picture of ``predictor``'s output over the area covered by ``X``.

    Columns 0 and 1 of ``X`` define the horizontal and vertical axes.  A
    ``grid_size`` x ``grid_size`` lattice is laid over their range, padded
    by 1 on every side, and ``predictor.predict`` is called once per lattice
    point with a ``(1, 2)`` array.  The collected values are drawn as filled
    contours with boundaries at -0.001 and 0.001, the samples are scattered
    on top coloured by ``y``, and the figure is written to ``filename``.

    Args:
        predictor: Object exposing ``predict(point)``.
        X: 2-D samples; only the first two columns are read.
        y: Per-sample values used to colour the scatter points.
        grid_size: Number of lattice points along each axis.
        filename: Path handed to ``plt.savefig``.
    """
    first_col, second_col = X[:, 0], X[:, 1]
    x_min, x_max = first_col.min() - 1, first_col.max() + 1
    y_min, y_max = second_col.min() - 1, second_col.max() + 1

    x_ticks = np.linspace(x_min, x_max, grid_size)
    y_ticks = np.linspace(y_min, y_max, grid_size)
    xx, yy = np.meshgrid(x_ticks, y_ticks, indexing='ij')

    def as_1d(values):
        # Copy into a plain ndarray and collapse it to one dimension.
        return np.array(values).reshape(-1)

    # Raveling both grids walks the lattice row by row, i.e. in the same
    # (i, j) order as a nested loop over the two index ranges.
    predictions = [
        predictor.predict(np.array([[grid_x, grid_y]]))
        for grid_x, grid_y in zip(xx.ravel(), yy.ravel())
    ]
    Z = np.array(predictions).reshape(xx.shape)

    plt.contourf(
        xx,
        yy,
        Z,
        cmap=cm.Paired,
        levels=[-0.001, 0.001],
        extend='both',
        alpha=0.8,
    )
    plt.scatter(as_1d(first_col), as_1d(second_col), c=as_1d(y), cmap=cm.Paired)
    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    plt.savefig(filename)
def _main():
    """Script entry point: set up logging, then dispatch the CLI to ``example``."""
    # Only messages at ERROR level or above are emitted.
    logging.basicConfig(level=logging.ERROR)
    argh.dispatch_command(example)


if __name__ == "__main__":
    _main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment