Skip to content

Instantly share code, notes, and snippets.

@robintux
Created April 11, 2022 02:24
Show Gist options
  • Save robintux/9fa996617882695d9bdaad5bb204a4ba to your computer and use it in GitHub Desktop.
Save robintux/9fa996617882695d9bdaad5bb204a4ba to your computer and use it in GitHub Desktop.
visualizar_clasificador
def visualizar_clasificador(clasificador, X, y):
#definimos los máximos valores de X e y para la malla
min_x, max_x = X[:, 0].min() - 1.0, X[:, 0].max() + 1.0
min_y, max_y = X[:, 1].min() - 1.0, X[:, 1].max() + 1.0
#definimos el paso de la malla
mesh_step_size = 0.01
#definimos la malla para x e y
x_vals, y_vals = np.mgrid[min_x:max_x:mesh_step_size, min_y:max_y:mesh_step_size]
#corremos el clasificador sobre la malla
resultados = clasificador.predict(np.c_[x_vals.ravel(), y_vals.ravel()])
#reordenamos la salida
print(resultados)
resultados = resultados.reshape(x_vals.shape)
#creamos la figura
plt.figure()
#elegimos los colores
plt.pcolormesh(x_vals,y_vals,resultados,cmap=plt.cm.PiYG)
#ubicamos los puntos
plt.scatter(X[:,0],X[:,1],c=y,s=75,edgecolors='black',linewidth=1,cmap=plt.cm.PiYG)
#especificamos los límites de la gráfica
plt.xlim(x_vals.min(), x_vals.max())
plt.ylim(y_vals.min(), y_vals.max())
#especificamos los puntos que se visualizarán sobre los ejes
plt.xticks((np.arange(int(X[:, 0].min() - 1), int(X[:, 0].max() + 1), 1.0)))
plt.yticks((np.arange(int(X[:, 1].min() - 1), int(X[:, 1].max() + 1), 1.0)))
#Graficamos
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment