Created
April 11, 2022 02:24
-
-
Save robintux/9fa996617882695d9bdaad5bb204a4ba to your computer and use it in GitHub Desktop.
visualizar_clasificador
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def visualizar_clasificador(clasificador, X, y): | |
#definimos los máximos valores de X e y para la malla | |
min_x, max_x = X[:, 0].min() - 1.0, X[:, 0].max() + 1.0 | |
min_y, max_y = X[:, 1].min() - 1.0, X[:, 1].max() + 1.0 | |
#definimos el paso de la malla | |
mesh_step_size = 0.01 | |
#definimos la malla para x e y | |
x_vals, y_vals = np.mgrid[min_x:max_x:mesh_step_size, min_y:max_y:mesh_step_size] | |
#corremos el clasificador sobre la malla | |
resultados = clasificador.predict(np.c_[x_vals.ravel(), y_vals.ravel()]) | |
#reordenamos la salida | |
print(resultados) | |
resultados = resultados.reshape(x_vals.shape) | |
#creamos la figura | |
plt.figure() | |
#elegimos los colores | |
plt.pcolormesh(x_vals,y_vals,resultados,cmap=plt.cm.PiYG) | |
#ubicamos los puntos | |
plt.scatter(X[:,0],X[:,1],c=y,s=75,edgecolors='black',linewidth=1,cmap=plt.cm.PiYG) | |
#especificamos los límites de la gráfica | |
plt.xlim(x_vals.min(), x_vals.max()) | |
plt.ylim(y_vals.min(), y_vals.max()) | |
#especificamos los puntos que se visualizarán sobre los ejes | |
plt.xticks((np.arange(int(X[:, 0].min() - 1), int(X[:, 0].max() + 1), 1.0))) | |
plt.yticks((np.arange(int(X[:, 1].min() - 1), int(X[:, 1].max() + 1), 1.0))) | |
#Graficamos | |
plt.show() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment