Created
February 2, 2018 14:38
-
-
Save doraneko94/c954167d06d9392cd1460fed32d93007 to your computer and use it in GitHub Desktop.
This code generates an ROC curve and a PR curve from simulated data, and prints the AUC of each curve.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math, random | |
from matplotlib import pyplot as plt | |
def normal_cdf(x, mu=0, sigma=1):
    """Return the cumulative probability P(X <= x) for X ~ Normal(mu, sigma).

    Args:
        x: Point at which the cumulative distribution is evaluated.
        mu: Mean of the distribution.
        sigma: Standard deviation of the distribution (must be non-zero).

    Returns:
        The cumulative probability as a float.
    """
    # Uses the identity Phi(z) = (1 + erf(z / sqrt(2))) / 2 with
    # z = (x - mu) / sigma.
    erf_argument = (x - mu) / math.sqrt(2) / sigma
    return (1 + math.erf(erf_argument)) / 2
def inverse_normal_cdf(p, mu=0, sigma=1, tolerance=0.00001):
    """Approximate the inverse of the normal CDF by bisection.

    Args:
        p: Target cumulative probability.
        mu: Mean of the distribution.
        sigma: Standard deviation of the distribution.
        tolerance: The search stops once the bracketing interval of
            standard-normal z values is no wider than this.

    Returns:
        A value v such that normal_cdf(v, mu, sigma) is approximately p.
        The search is confined to z in [-10, 10], so a p at or beyond the
        ends of the [0, 1] range yields a result near mu -/+ 10 * sigma.
    """
    # Solve for the standard normal, then rescale to the requested one.
    if mu != 0 or sigma != 1:
        return mu + sigma * inverse_normal_cdf(p, tolerance=tolerance)

    low_z, hi_z = -10.0, 10.0
    # Defined before the loop so that a tolerance wider than the initial
    # bracket returns the midpoint instead of raising UnboundLocalError.
    mid_z = (low_z + hi_z) / 2
    while hi_z - low_z > tolerance:
        mid_z = (low_z + hi_z) / 2
        mid_p = normal_cdf(mid_z)
        if mid_p < p:
            low_z = mid_z  # midpoint is too low: search the upper half
        elif mid_p > p:
            hi_z = mid_z  # midpoint is too high: search the lower half
        else:
            break  # exact hit
    return mid_z
# Simulation parameters.
N = 10000  # total number of samples
ip = 0.1  # prevalence: fraction of the samples that are positive
mu1 = 20  # mean of the positive samples
s1 = 5  # standard deviation of the positive samples
mu2 = 30  # mean of the negative samples
s2 = 5  # standard deviation of the negative samples

# Draw integer-valued samples by inverse-transform sampling; int() truncates
# each draw toward zero.  Positives are drawn before negatives so the random
# numbers are consumed in the same order as before.
posi = [int(inverse_normal_cdf(random.random(), mu1, s1))
        for _ in range(int(N * ip))]
nega = [int(inverse_normal_cdf(random.random(), mu2, s2))
        for _ in range(int(N * (1 - ip)))]

# Shared value axis for both histograms.  It starts at 0 unless a negative
# sample was drawn, in which case it starts at the smallest sample: using a
# negative sample directly as a list index would silently count it in a bin
# at the far end of the histogram.  The extra [0] also keeps min()/max()
# from failing when a class has no samples.  One empty bin is kept past the
# largest sample, as before.
lowest = min(posi + nega + [0])
highest = max(posi + nega + [0])
x = list(range(lowest, highest + 2))

# posis[k] / negas[k] hold how many positive / negative samples equal x[k].
posis = [0] * len(x)
for value in posi:
    posis[value - lowest] += 1
negas = [0] * len(x)
for value in nega:
    negas[value - lowest] += 1
# Histogram of the simulated samples on a shared axis: positive counts in
# red first, then negative counts in blue.
plt.title("Datas")
plt.xlabel("values")
plt.ylabel("numbers")
for counts, colour in ((posis, "r"), (negas, "b")):
    plt.bar(x, counts, color=colour)
plt.show()
# Sweep a decision threshold across the histogram bins.  At step i every
# sample in bins 0..i is called positive and every sample in a later bin is
# called negative, giving one point on the ROC curve and one on the PR curve.
rx, ry = [], []  # ROC curve: false positive rate, true positive rate
px, py = [], []  # PR curve: recall, precision
rauc = 0  # area under the ROC curve (trapezoidal rule)
pauc = 0  # area under the PR curve (trapezoidal rule)

total_posi = sum(posis)
total_nega = sum(negas)
TP = 0
FP = 0
for i, (posi_count, nega_count) in enumerate(zip(posis, negas)):
    # Running totals replace re-summing both histograms at every step.
    TP += posi_count
    FP += nega_count
    FN = total_posi - TP
    TN = total_nega - FP

    # Ratios with an empty denominator fall back to a fixed value rather
    # than raising ZeroDivisionError.
    Recall = TP / (TP + FN) if TP + FN != 0 else 1
    Precision = TP / (TP + FP) if TP + FP != 0 else 1
    TPR = Recall  # true positive rate and recall are the same quantity
    FPR = FP / (FP + TN) if FP + TN != 0 else 0

    rx.append(FPR)
    ry.append(TPR)
    px.append(Recall)
    py.append(Precision)

    if i >= 1:
        # Add the trapezoid between the previous point and this one.
        rauc += (ry[i] + ry[i - 1]) * (rx[i] - rx[i - 1]) / 2
        pauc += (py[i] + py[i - 1]) * (px[i] - px[i - 1]) / 2
# Show each curve in its own figure, then print the two areas.
for title, x_label, y_label, xs, ys in (
    ("ROC Curve", "False Positive Rate", "True Positive Rate", rx, ry),
    ("PR Curve", "Recall", "Precision", px, py),
):
    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.plot(xs, ys)
    plt.show()

for label, area in (("AUC-ROC: ", rauc), ("AUC-PR: ", pauc)):
    print(label + str(area))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment