#!/usr/bin/env python | |
import svmpy | |
import logging | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib.cm as cm | |
import itertools | |
import argh | |
def example(num_samples=10, num_features=2, grid_size=20, filename="svm.pdf"): | |
samples = np.matrix(np.random.normal(size=num_samples * num_features) | |
.reshape(num_samples, num_features)) | |
labels = 2 * (samples.sum(axis=1) > 0) - 1.0 | |
trainer = svmpy.SVMTrainer(svmpy.Kernel.linear(), 0.1) | |
predictor = trainer.train(samples, labels) | |
plot(predictor, samples, labels, grid_size, filename) | |
def plot(predictor, X, y, grid_size, filename): | |
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 | |
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 | |
xx, yy = np.meshgrid(np.linspace(x_min, x_max, grid_size), | |
np.linspace(y_min, y_max, grid_size), | |
indexing='ij') | |
flatten = lambda m: np.array(m).reshape(-1,) | |
result = [] | |
for (i, j) in itertools.product(range(grid_size), range(grid_size)): | |
point = np.array([xx[i, j], yy[i, j]]).reshape(1, 2) | |
result.append(predictor.predict(point)) | |
Z = np.array(result).reshape(xx.shape) | |
plt.contourf(xx, yy, Z, | |
cmap=cm.Paired, | |
levels=[-0.001, 0.001], | |
extend='both', | |
alpha=0.8) | |
plt.scatter(flatten(X[:, 0]), flatten(X[:, 1]), | |
c=flatten(y), cmap=cm.Paired) | |
plt.xlim(x_min, x_max) | |
plt.ylim(y_min, y_max) | |
plt.savefig(filename) | |
if __name__ == "__main__": | |
logging.basicConfig(level=logging.ERROR) | |
argh.dispatch_command(example) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment