Skip to content

Instantly share code, notes, and snippets.

@matsuken92
Last active May 2, 2024 03:08
Show Gist options
  • Star 11 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save matsuken92/05c2bc24c9edf83334b0 to your computer and use it in GitHub Desktop.
Save matsuken92/05c2bc24c9edf83334b0 to your computer and use it in GitHub Desktop.
ROC Curve Animation
%matplotlib inline
import sys
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st
from matplotlib import animation as ani
import sklearn.metrics as mt
def animate(nframe):
    """Draw frame ``nframe`` of the threshold-sweep ROC animation.

    Clears the current figure and redraws three stacked panels:

    1. the densities of N(25, 2) and N(30, 4), shaded from ``xmin`` up to
       the current threshold (the vertical black line);
    2. the two cumulative distribution functions, with their values at the
       threshold marked by dots;
    3. the ROC curve traced by (CDF2, CDF1) over all thresholds, with the
       point for the current threshold marked.

    Reads the module-level ``num_frame`` (total frame count) to size the
    threshold grid, and writes a progress percentage to stdout.

    Args:
        nframe: Frame index supplied by ``FuncAnimation``; the threshold
            sits at grid index ``2 * nframe``.

    Returns:
        The list of artists drawn for this frame, so the function also
        satisfies ``FuncAnimation(..., blit=True)``.
    """
    sys.stdout.write(str(int(float(nframe) / num_frame * 100)) + "%, ")
    plt.clf()

    # Range of thresholds shown on the x axis.
    xmin = 15
    xmax = 45
    # Number of grid points: two per frame.
    sx = num_frame * 2
    # Grid index of the current threshold.
    pos = nframe * 2
    xx = np.linspace(xmin, xmax, sx)

    # Densities and CDFs of the two distributions on the grid.
    x1 = st.norm.pdf(xx, loc=25, scale=2)
    x2 = st.norm.pdf(xx, loc=30, scale=4)
    cx1 = st.norm.cdf(xx, loc=25, scale=2)
    cx2 = st.norm.cdf(xx, loc=30, scale=4)

    # Slice end that includes the threshold sample itself; ``[0:pos]`` would
    # stop the shading one grid step short of the threshold line.
    upto = pos + 1
    artists = []

    # Panel 1: densities, shaded up to the threshold.
    plt.subplot(311)
    plt.title("Density curve. x=%d" % xx[pos])
    plt.xlim(xmin, xmax)
    plt.ylim(0, 0.22)
    artists += plt.plot(xx, x1, linewidth=2, zorder=200)
    artists += plt.plot(xx, x2, linewidth=2, zorder=200)
    artists += plt.plot([xx[pos], xx[pos]], [0, 1], "k", linewidth=2)
    artists.append(plt.fill_between(xx[:upto], x1[:upto],
                                    color="lightblue", zorder=10))
    artists.append(plt.fill_between(xx[:upto], x2[:upto],
                                    color="lightgreen", zorder=100))

    # Panel 2: CDFs with their values at the threshold.
    plt.subplot(312)
    plt.title("Cumulative curve. x=%d" % xx[pos])
    plt.xlim(xmin, xmax)
    plt.ylim(0, 1)
    artists += plt.plot(xx, cx1, linewidth=2)
    artists += plt.plot(xx, cx2, linewidth=2)
    artists += plt.plot([xx[pos], xx[pos]], [0, 1], "k", linewidth=2, zorder=50)
    artists.append(plt.scatter(xx[pos], cx1[pos], c="b", s=30, zorder=100))
    artists.append(plt.scatter(xx[pos], cx2[pos], c="g", s=30, zorder=100))

    # Panel 3: ROC curve with the operating point for this threshold.
    plt.subplot(313)
    plt.title("ROC Curve. 1-e1=%.3f, e2=%.3f" % (cx1[pos], cx2[pos]))
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    artists += plt.plot(cx2, cx1, linewidth=2)
    artists.append(plt.scatter(cx2[pos], cx1[pos], c="b", s=30, zorder=100))
    return artists
# Animation 1: sweep the decision threshold across two fixed distributions.
num_frame = 100
fig = plt.figure(figsize=(7, 15))
# animate() clears and redraws the whole figure on every frame, so there is
# nothing to blit. blit=False also avoids matplotlib's RuntimeError when a
# frame function does not return a sequence of artists.
anim = ani.FuncAnimation(fig, animate, frames=num_frame, blit=False)
anim.save('ROC_curve1.gif', writer='imagemagick', fps=5, dpi=64)
def animate(nframe):
    """Draw frame ``nframe`` of the shifting-mean ROC animation.

    Distribution 1 is N(10 + nframe, 2) and distribution 2 is N(30, 4).
    As ``nframe`` grows, the first density slides to the right across the
    second one. The top panel shows both densities. The bottom panel shows
    the ROC curve traced by (CDF2, CDF1) over the threshold grid.

    Like the other frame functions in this file, it writes a progress
    percentage based on the module-level ``num_frame`` to stdout.

    Args:
        nframe: Frame index supplied by ``FuncAnimation``.

    Returns:
        The list of artists drawn for this frame, so the function also
        satisfies ``FuncAnimation(..., blit=True)``.
    """
    sys.stdout.write(str(int(float(nframe) / num_frame * 100)) + "%, ")
    plt.clf()

    # Range of thresholds and the number of grid points.
    xmin = 10
    xmax = 45
    sx = 200
    xx = np.linspace(xmin, xmax, sx)

    # Distribution 1 moves one unit to the right per frame.
    mu1 = 10 + nframe
    mu2 = 30
    sd1 = 2
    sd2 = 4

    # Densities and CDFs of the two distributions on the grid.
    x1 = st.norm.pdf(xx, loc=mu1, scale=sd1)
    x2 = st.norm.pdf(xx, loc=mu2, scale=sd2)
    cx1 = st.norm.cdf(xx, loc=mu1, scale=sd1)
    cx2 = st.norm.cdf(xx, loc=mu2, scale=sd2)

    # The grid [xmin, xmax] does not cover distribution 1's tails when mu1
    # is near either end (at mu1 = 10, half its mass lies below xmin), so
    # the raw curve would start at (0, 0.5). Appending the limits at
    # thresholds -inf and +inf makes the curve always span (0,0)-(1,1).
    roc_x = np.concatenate(([0.0], cx2, [1.0]))
    roc_y = np.concatenate(([0.0], cx1, [1.0]))

    artists = []

    # Panel 1: the two densities.
    plt.subplot(211)
    plt.title("Density curve. mu1=%d" % mu1)
    plt.xlim(xmin, xmax)
    plt.ylim(0, 0.22)
    artists += plt.plot(xx, x1, linewidth=2, zorder=200)
    artists += plt.plot(xx, x2, linewidth=2, zorder=200)

    # Panel 2: ROC curve.
    plt.subplot(212)
    plt.title("ROC Curve. mu1=%d" % (mu1))
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    artists += plt.plot(roc_x, roc_y, linewidth=2)
    return artists
# Animation 2: slide the mean of distribution 1 across distribution 2.
num_frame = 35
fig = plt.figure(figsize=(7, 10))
# animate() clears and redraws the whole figure on every frame, so there is
# nothing to blit. blit=False also avoids matplotlib's RuntimeError when a
# frame function does not return a sequence of artists.
anim = ani.FuncAnimation(fig, animate, frames=num_frame, blit=False)
anim.save('ROC_curve2.gif', writer='imagemagick', fps=2, dpi=64)
def animate(nframe):
    """Draw frame ``nframe`` of the shrinking-spread ROC/AUC animation.

    Distribution 1 is N(20, sd1) and distribution 2 is N(30, 4). sd1 shrinks
    by 0.5 per frame and reaches 1.0 on the last frame, so the ROC curve
    bows further toward the top-left corner as the animation runs. The top
    panel shows both densities. The bottom panel shows the ROC curve traced
    by (CDF2, CDF1), and its title gives the trapezoidal area under it.

    Reads the module-level ``num_frame`` (total frame count) for the
    progress percentage written to stdout and for the sd1 schedule.

    Args:
        nframe: Frame index supplied by ``FuncAnimation``.

    Returns:
        The list of artists drawn for this frame, so the function also
        satisfies ``FuncAnimation(..., blit=True)``.
    """
    sys.stdout.write(str(int(float(nframe) / num_frame * 100)) + "%, ")
    plt.clf()

    # Range of thresholds and the number of grid points.
    xmin = 10
    xmax = 45
    sx = 200
    xx = np.linspace(xmin, xmax, sx)

    mu1 = 20
    mu2 = 30
    # Tied to num_frame so sd1 stays positive for any run length. With
    # num_frame == 10 this equals the original .5 * (11 - nframe).
    sd1 = .5 * (num_frame + 1 - nframe)
    sd2 = 4

    # Densities and CDFs of the two distributions on the grid.
    x1 = st.norm.pdf(xx, loc=mu1, scale=sd1)
    x2 = st.norm.pdf(xx, loc=mu2, scale=sd2)
    cx1 = st.norm.cdf(xx, loc=mu1, scale=sd1)
    cx2 = st.norm.cdf(xx, loc=mu2, scale=sd2)

    # With a wide sd1, part of distribution 1 lies below xmin (about 3.5%
    # at sd1 = 5.5), so the raw curve would not start at (0, 0). Appending
    # the limits at thresholds -inf and +inf closes the curve at (0,0) and
    # (1,1) for both the plot and the AUC.
    roc_x = np.concatenate(([0.0], cx2, [1.0]))
    roc_y = np.concatenate(([0.0], cx1, [1.0]))
    auc = mt.auc(roc_x, roc_y)

    artists = []

    # Panel 1: the two densities.
    plt.subplot(211)
    plt.title("Density curve. sd1=%.3f" % sd1)
    plt.xlim(xmin, xmax)
    plt.ylim(0, 0.22)
    artists += plt.plot(xx, x1, linewidth=2, zorder=200)
    artists += plt.plot(xx, x2, linewidth=2, zorder=200)

    # Panel 2: ROC curve annotated with its AUC.
    plt.subplot(212)
    plt.title("ROC Curve. auc=%f" % (auc))
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    artists += plt.plot(roc_x, roc_y, linewidth=2)
    return artists
# Animation 3: shrink the spread of distribution 1 and report the AUC.
num_frame = 10
fig = plt.figure(figsize=(7, 10))
# animate() clears and redraws the whole figure on every frame, so there is
# nothing to blit. blit=False also avoids matplotlib's RuntimeError when a
# frame function does not return a sequence of artists.
anim = ani.FuncAnimation(fig, animate, frames=num_frame, blit=False)
anim.save('ROC_curve_auc.gif', writer='imagemagick', fps=1, dpi=64)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment