Last active
May 4, 2020 16:24
-
-
Save emanuelgsouza/c220622d3eebbcf9e65ddfcbc66f4a6d to your computer and use it in GitHub Desktop.
Código de implementação de matriz de confusão para classificação binária
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
# 1 para grávida, 0 para não grávida | |
valores_reais = [1, 0, 1, 0, 0, 0, 1, 0, 1, 0] | |
valores_preditos = [1, 0, 0, 1, 0, 0, 1, 1, 1, 0] | |
def get_confusion_matrix(reais, preditos, labels): | |
""" | |
Uma função que retorna a matriz de confusão para uma classificação binária | |
Args: | |
reais (list): lista de valores reais | |
preditos (list): lista de valores preditos pelo modelos | |
labels (list): lista de labels a serem avaliados. | |
É importante que ela esteja presente, pois usaremos ela para entender | |
quem é a classe positiva e quem é a classe negativa | |
Returns: | |
Um numpy.array, no formato: | |
numpy.array([ | |
[ tp, fp ], | |
[ fn, tn ] | |
]) | |
""" | |
# não implementado | |
if len(labels) > 2: | |
return None | |
if len(reais) != len(preditos): | |
return None | |
# considerando a primeira classe como a positiva, e a segunda a negativa | |
true_class = labels[0] | |
negative_class = labels[1] | |
# valores preditos corretamente | |
tp = 0 | |
tn = 0 | |
# valores preditos incorretamente | |
fp = 0 | |
fn = 0 | |
for (indice, v_real) in enumerate(reais): | |
v_predito = preditos[indice] | |
# se trata de um valor real da classe positiva | |
if v_real == true_class: | |
tp += 1 if v_predito == v_real else 0 | |
fp += 1 if v_predito != v_real else 0 | |
else: | |
tn += 1 if v_predito == v_real else 0 | |
fn += 1 if v_predito != v_real else 0 | |
return np.array([ | |
# valores da classe positiva | |
[ tp, fp ], | |
# valores da classe negativa | |
[ fn, tn ] | |
]) | |
get_confusion_matrix(reais=valores_reais, preditos=valores_preditos, labels=[1,0]) | |
# array([[3, 1], [2, 4]]) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment