Skip to content

Instantly share code, notes, and snippets.

@lethalbrains
Forked from code-of-kpp/plot_roc.py
Created April 15, 2018 14:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lethalbrains/cf6a33cbe105306b0c9cfb9bbb42ad03 to your computer and use it in GitHub Desktop.
Save lethalbrains/cf6a33cbe105306b0c9cfb9bbb42ad03 to your computer and use it in GitHub Desktop.
Python pyplot receiver operating characteristic (ROC) curve with colorbar
import numbers
import six
import numpy
import matplotlib.collections
from matplotlib import pyplot
# using example from
# http://nbviewer.ipython.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb
def make_segments(x, y):
    '''
    Pair consecutive (x, y) points into line segments for a LineCollection.

    N input points produce N - 1 segments, one between each pair of
    neighbouring points.

    Returns:
        An array of shape (N - 1, 2, 2): for every segment, its start point
        followed by its end point, each stored as an (x, y) pair.
    '''
    # One row per point: column 0 is x, column 1 is y.
    xy = numpy.array([x, y]).T.reshape(-1, 2)
    starts = xy[:-1]
    ends = xy[1:]
    # Stack along a new middle axis so each entry is [start, end].
    return numpy.stack((starts, ends), axis=1)
# Sentinel meaning "argument not supplied". An explicit None is still passed
# through to LineCollection unchanged, exactly as before.
_COLORLINE_DEFAULT = object()


def colorline(x, y, z=None, axes=None,
              cmap=_COLORLINE_DEFAULT, norm=_COLORLINE_DEFAULT,
              linewidth=3, alpha=1.0,
              **kwargs):
    '''
    Plot a line through (x, y) whose segments are colored by the values in z.

    Parameters:
        x, y: point coordinates; consecutive points are joined, giving
            len(x) - 1 segments.
        z: color values, mapped through ``norm`` and ``cmap``. Defaults to
            len(x) values evenly spaced on [0, 1]. A single real number is
            wrapped into a one-element array.
        axes: target axes; defaults to ``pyplot.gca()``.
        cmap: colormap; defaults to a fresh 'coolwarm' colormap per call.
        norm: normalization; defaults to a fresh ``Normalize(0.0, 1.0)``
            per call.
        linewidth, alpha, **kwargs: forwarded to ``LineCollection``.

    Returns:
        The ``LineCollection`` added to ``axes``. It can be passed to
        ``figure.colorbar``.
    '''
    if cmap is _COLORLINE_DEFAULT:
        cmap = pyplot.get_cmap('coolwarm')
    if norm is _COLORLINE_DEFAULT:
        # Build the default per call. A Normalize created once in the
        # signature was shared by every collection, so changing one
        # collection's limits (e.g. via set_clim) silently changed the
        # others and every later call.
        norm = pyplot.Normalize(0.0, 1.0)

    # Default colors equally spaced on [0, 1].
    if z is None:
        z = numpy.linspace(0.0, 1.0, len(x))
    # Special case: a single number becomes a one-element array.
    if isinstance(z, numbers.Real):
        z = numpy.array([z])
    z = numpy.asarray(z)

    segments = make_segments(x, y)
    lc = matplotlib.collections.LineCollection(
        segments, array=z, cmap=cmap, norm=norm,
        linewidth=linewidth, alpha=alpha, **kwargs
    )
    if axes is None:
        axes = pyplot.gca()
    axes.add_collection(lc)
    axes.autoscale()
    return lc
def plot_roc(tpr, fpr, thresholds, subplots_kwargs=None,
             label_every=None, label_kwargs=None,
             fpr_label='False Positive Rate',
             tpr_label='True Positive Rate',
             luck_label='Luck',
             title='Receiver operating characteristic',
             **kwargs):
    '''
    Plot an ROC curve colored by decision threshold, with a colorbar.

    Note the argument order: ``tpr`` comes *before* ``fpr``.

    Parameters:
        tpr: true positive rates, plotted on the y axis.
        fpr: false positive rates, plotted on the x axis.
        thresholds: the threshold for each (fpr, tpr) point. It drives the
            curve's color (via ``colorline``, whose default norm spans
            0.0-1.0) and the optional point annotations.
        subplots_kwargs: extra keyword arguments for ``pyplot.subplots``.
        label_every: if given, annotate every ``label_every``-th point
            (indices 0, label_every, 2 * label_every, ...) with its
            threshold rounded to 2 decimals.
        label_kwargs: extra keyword arguments for ``axes.annotate``. A
            semi-transparent yellow rounded ``bbox`` is used unless one is
            supplied. The caller's dict is not modified.
        fpr_label, tpr_label: axis labels.
        luck_label: legend label for the dashed diagonal. ``None`` omits
            the diagonal.
        title: axes title.
        **kwargs: forwarded to ``axes.plot`` for the underlying curve. The
            line width defaults to 1 unless ``lw`` or ``linewidth`` is given.

    Returns:
        The ``(figure, axes)`` pair created by ``pyplot.subplots``.
    '''
    figure, axes = pyplot.subplots(1, 1, **(subplots_kwargs or {}))

    # Only default the width when the caller set neither alias. Adding 'lw'
    # next to an explicit 'linewidth' makes matplotlib reject the call.
    if 'lw' not in kwargs and 'linewidth' not in kwargs:
        kwargs['lw'] = 1
    axes.plot(fpr, tpr, **kwargs)

    if label_every is not None:
        # Copy so the default bbox is never written into the caller's dict.
        label_kwargs = {} if label_kwargs is None else dict(label_kwargs)
        label_kwargs.setdefault('bbox', dict(
            boxstyle='round,pad=0.5', fc='yellow', alpha=0.5,
        ))
        for k, (x, y, threshold) in enumerate(zip(fpr, tpr, thresholds)):
            if k % label_every != 0:
                continue
            axes.annotate(str(numpy.round(threshold, 2)), (x, y),
                          **label_kwargs)

    if luck_label is not None:
        axes.plot((0, 1), (0, 1), '--', color='Gray', label=luck_label)

    lc = colorline(fpr, tpr, thresholds, axes=axes)
    figure.colorbar(lc)

    axes.set_xlim([-0.05, 1.05])
    axes.set_ylim([-0.05, 1.05])
    axes.set_xlabel(fpr_label)
    axes.set_ylabel(tpr_label)
    axes.set_title(title)
    # Skip the legend when no artist carries a label (e.g. luck_label=None
    # and no label= in kwargs). matplotlib would otherwise warn that there
    # is nothing to show.
    handles, _ = axes.get_legend_handles_labels()
    if handles:
        axes.legend(loc="lower right")
    return figure, axes
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment