Skip to content

Instantly share code, notes, and snippets.

@demacdolincoln
Created February 28, 2018 23:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save demacdolincoln/d09645d89b38e39e4670e53e8c79aa6f to your computer and use it in GitHub Desktop.
Save demacdolincoln/d09645d89b38e39e4670e53e8c79aa6f to your computer and use it in GitHub Desktop.
ex gráficos de classificação com o sklearn
# -*- coding: utf-8 -*-
"""
@author: lincoln

Plot the decision surface of a scikit-learn classifier trained on a 2-D toy
dataset, together with the training and test samples.

sources:
http://scikit-learn.org/stable/auto_examples/svm/plot_rbf_parameters.html
http://scikit-learn.org/stable/auto_examples/classification/plot_classifier_comparison.html

note: when displaying the figures in a Jupyter notebook or qtconsole, run
    from IPython.display import set_matplotlib_formats
    set_matplotlib_formats('png')
which renders the figures as PNG and makes display faster. (Newer IPython
releases deprecate that function in favour of
matplotlib_inline.backend_inline.set_matplotlib_formats.)
"""
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
# sklearn.cross_validation was removed in scikit-learn 0.20; the same
# function lives in sklearn.model_selection (available since 0.18).
from sklearn.model_selection import train_test_split as tts
from sklearn import svm, datasets
from sklearn.tree import DecisionTreeClassifier as dtc
from sklearn.neighbors import KNeighborsClassifier as knn
###############################################################################
# preparing the dataset #
###############################################################################
# Two informative features and no redundant ones, two clusters per class;
# the fixed seed makes the generated points reproducible.
X, y = datasets.make_classification(
    n_features=2,
    n_informative=2,
    n_redundant=0,
    n_clusters_per_class=2,
    random_state=0,
)
# alternative dataset:
#X, y = datasets.make_moons(random_state=2)
# "Shake things up": add uniform noise in [0, 1) to every coordinate, drawn
# from its own seeded generator so the jitter is reproducible too.
rng = np.random.RandomState(1)
ruido = rng.uniform(size=X.shape)
X += ruido
###############################################################################
# splitting the dataset into train and test sets #
###############################################################################
# 60% train / 40% test. random_state pins the shuffle so the accuracy printed
# in the plot title is reproducible, like every other random step in this
# script; 42 is the seed used by scikit-learn's plot_classifier_comparison
# example.
x_train, x_test, y_train, y_test = tts(X, y, test_size=0.4, random_state=42)
###############################################################################
# training #
###############################################################################
krnl = "rbf"  # SVM kernel; only used by the commented-out SVC option below
# Choose the classifier by (un)commenting one line. The SVC needs
# probability=True because the contour section calls predict_proba.
#clf = svm.SVC(probability=True, kernel=krnl, C=1.1,gamma=1.1)
#clf = knn()
# decision tree; fit() returns the estimator itself, so build + train chain
clf = dtc(random_state=1).fit(x_train, y_train)
# mean accuracy on the held-out test samples
score = clf.score(x_test, y_test)
###############################################################################
# preparing matplotlib #
###############################################################################
# colour maps: a diverging map for the decision surface, and two-colour maps
# for the samples (strong colours for training, pale ones for test)
cm = plt.cm.RdBu
cm_bright = ListedColormap(['red', 'blue'])
cm_bright2 = ListedColormap(['pink', 'cyan'])
ax = plt.subplot()
# plot limits: the per-column bounding box of X padded by 0.5 on each side
(x_min, y_min), (x_max, y_max) = X.min(axis=0) - 0.5, X.max(axis=0) + 0.5
###############################################################################
# plotting the contour #
###############################################################################
h = 0.2  # step size in the mesh
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
# every mesh node as one (x, y) row, the layout predict_proba expects
pontos = np.c_[xx.ravel(), yy.ravel()]
probs = clf.predict_proba(pontos)
# column 1 = probability of clf.classes_[1]; folded back onto the mesh shape
z = probs[:, 1].reshape(xx.shape)
ax.contourf(xx, yy, z, cmap=cm, alpha=0.8)
ax.set_xlim(xx.min(), xx.max())
ax.set_ylim(yy.min(), yy.max())
###############################################################################
# plotting the area #
###############################################################################
# Hard class predictions on a finer 200x200 grid covering the same extent as
# the contour mesh. The grid gets its own names: calling it `x, y` would
# overwrite the dataset's label vector `y`.
xx_area, yy_area = np.meshgrid(np.linspace(xx.min(), xx.max(), 200),
                               np.linspace(yy.min(), yy.max(), 200))
Z = clf.predict(np.c_[xx_area.ravel(), yy_area.ravel()])
Z = Z.reshape(xx_area.shape)
# reuse the shared surface colour map (cm is plt.cm.RdBu)
ax.pcolormesh(xx_area, yy_area, Z, cmap=cm, alpha=0.4)
ax.set_xlim(xx_area.min(), xx_area.max())
ax.set_ylim(yy_area.min(), yy_area.max())
###############################################################################
# plot the training samples
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train,
           cmap=cm_bright, label="treino")
# plot the test samples (pale colours, semi-transparent)
ax.scatter(x_test[:, 0], x_test[:, 1], c=y_test, cmap=cm_bright2,
           alpha=0.6, label="teste")
# Plot the support vectors only when the fitted estimator exposes them
# (e.g. svm.SVC). Other models, such as the decision tree or KNN, have no
# support_vectors_ attribute, so this step is skipped instead of crashing.
vetores = getattr(clf, "support_vectors_", None)
if vetores is not None:
    ax.scatter(vetores[:, 0], vetores[:, 1], marker='+', color='w',
               label="vetor de suporte")
###############################################################################
# Build the title from the fitted model, so it always names the classifier
# that was actually trained (a hard-coded "SVM" label goes stale as soon as
# another classifier is selected). clf.score() is mean accuracy, hence
# "acurácia" rather than "precisão" (precision is a different metric).
titulo = "{0} \n acurácia: {1:.2f}%".format(type(clf).__name__, score * 100)
if isinstance(clf, svm.SVC):
    # the kernel only exists for the SVM; read it from the model itself
    titulo += " | kernel: {0}".format(clf.kernel)
ax.set_title(titulo)
ax.legend(loc="lower right", framealpha=0.5, fontsize='small')
# hide tick marks: the feature axes have no meaningful units here
ax.set_xticks(())
ax.set_yticks(())
# Optional export: uncomment the line whose file name matches the classifier
# selected in the training section. These calls sit before plt.show() on
# purpose; once the window is closed the figure is typically released, so
# saving afterwards would likely produce an empty image.
#plt.savefig("KNN.png", transparent=True, format="png", dpi=300)
#plt.savefig("dtc.png", transparent=True, format="png", dpi=300)
#plt.savefig("SVM.png", transparent=True, format="png", dpi=300)
# display the figure
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment