"""Receiver operating characteristic (ROC) curve plotting with pyplot and a colorbar."""
import numbers
import six
import numpy
import matplotlib.collections
from matplotlib import pyplot
# The colored-line helpers below are adapted from an external example (source link not preserved).
def make_segments(x, y):
    """Create line segments from x and y coordinates for a LineCollection.

    Consecutive points are paired, so ``n`` points yield ``n - 1`` segments.

    Parameters
    ----------
    x, y : array-like
        Coordinates of the points; both must have the same shape.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(numlines, 2, 2)``: ``numlines`` segments, each made
        of 2 points (start, end), each point an ``(x, y)`` pair -- the format
        ``matplotlib.collections.LineCollection`` expects.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` do not have the same shape.
    """
    x = numpy.asarray(x)
    y = numpy.asarray(y)
    # Mismatched inputs would otherwise build a ragged array that either
    # fails with an obscure error or silently reshapes into wrong segments.
    if x.shape != y.shape:
        raise ValueError(
            'x and y must have the same shape, got %s and %s'
            % (x.shape, y.shape))
    # For 1-D inputs: (n, 1, 2) -- one (x, y) point per row, with a
    # singleton "points per line" axis to concatenate along.
    points = numpy.array([x, y]).T.reshape(-1, 1, 2)
    # Pair every point with its successor: (n - 1, 2, 2).
    return numpy.concatenate([points[:-1], points[1:]], axis=1)
_DEFAULT_NORM = object()  # sentinel: build a fresh Normalize(0.0, 1.0) per call


def colorline(x, y, z=None, axes=None,
              norm=_DEFAULT_NORM, linewidth=3, alpha=1.0,
              cmap='copper', **kwargs):
    """Plot a line through (x, y) whose segments are colored by ``z``.

    Parameters
    ----------
    x, y : array-like
        Point coordinates.
    z : array-like or real number, optional
        Values mapped to colors through ``norm`` and ``cmap``. Defaults to
        values equally spaced on [0, 1] along the line. A single number is
        wrapped into a one-element array.
    axes : matplotlib Axes, optional
        Axes to draw on; defaults to ``pyplot.gca()``.
    norm : matplotlib Normalize, optional
        Color normalization. By default a new ``Normalize(0.0, 1.0)`` is
        created on every call, so no instance is shared between plots.
    linewidth, alpha : float
        Line width and opacity of the collection.
    cmap : colormap or str
        Colormap used for the segments (default ``'copper'``).
    **kwargs
        Forwarded to ``matplotlib.collections.LineCollection``.

    Returns
    -------
    matplotlib.collections.LineCollection
        The collection, already added to ``axes`` (usable with
        ``Figure.colorbar``).
    """
    if norm is _DEFAULT_NORM:
        norm = pyplot.Normalize(0.0, 1.0)

    # Default colors equally spaced on [0, 1].
    if z is None:
        z = numpy.linspace(0.0, 1.0, len(x))
    # Special case: a single number becomes a one-element array.
    if isinstance(z, numbers.Real):
        z = numpy.array([z])
    z = numpy.asarray(z)

    segments = make_segments(x, y)
    lc = matplotlib.collections.LineCollection(
        segments, array=z, cmap=cmap, norm=norm,
        linewidth=linewidth, alpha=alpha, **kwargs)

    if axes is None:
        axes = pyplot.gca()
    # Without this the collection is created but never drawn.
    axes.add_collection(lc)
    return lc
def _annotate_thresholds(axes, tpr, fpr, thresholds, label_every,
                         label_kwargs):
    """Annotate every ``label_every``-th ROC point with its rounded threshold."""
    if label_every <= 0:
        raise ValueError(
            'label_every must be positive, got %r' % (label_every,))
    # Copy so the caller's dict is not modified when the default bbox is added.
    label_kwargs = dict(label_kwargs) if label_kwargs is not None else {}
    label_kwargs.setdefault('bbox', dict(
        boxstyle='round,pad=0.5', fc='yellow', alpha=0.5))
    for k, (x, y, threshold) in enumerate(zip(fpr, tpr, thresholds)):
        # Label points 0, label_every, 2 * label_every, ... only.
        if k % label_every != 0:
            continue
        text = str(numpy.round(threshold, 2))
        axes.annotate(text, (x, y), **label_kwargs)


def plot_roc(tpr, fpr, thresholds, subplots_kwargs=None,
             label_every=None, label_kwargs=None,
             fpr_label='False Positive Rate',
             tpr_label='True Positive Rate',
             title='Receiver operating characteristic',
             luck_label='Luck', **kwargs):
    """Plot a ROC curve colored by decision threshold, with a colorbar.

    Parameters
    ----------
    tpr, fpr : array-like
        True and false positive rates, one entry per threshold. Note the
        argument order: ``tpr`` comes first.
    thresholds : array-like
        Decision thresholds matching ``tpr``/``fpr``. They color the curve
        (via ``colorline`` with its default normalization) and provide the
        text of the optional point labels.
    subplots_kwargs : dict, optional
        Extra keyword arguments for ``pyplot.subplots``.
    label_every : int, optional
        If given, annotate every ``label_every``-th point (starting with the
        first) with its threshold rounded to two decimals.
    label_kwargs : dict, optional
        Extra keyword arguments for ``Axes.annotate``. A yellow rounded
        ``bbox`` is used unless one is supplied. The dict is not modified.
    fpr_label, tpr_label, title : str or None
        X-axis label, y-axis label and title; ``None`` skips setting them.
    luck_label : str or None
        Legend label of the dashed chance diagonal; ``None`` omits the line.
    **kwargs
        Forwarded to ``Axes.plot`` for the base curve. The line width
        defaults to 1 unless ``lw`` or ``linewidth`` is given.

    Returns
    -------
    tuple
        ``(figure, axes)`` created by ``pyplot.subplots``.

    Raises
    ------
    ValueError
        If ``label_every`` is not positive.
    """
    if subplots_kwargs is None:
        subplots_kwargs = {}
    figure, axes = pyplot.subplots(1, 1, **subplots_kwargs)

    # Only inject the default width when neither alias was given; passing
    # both 'lw' and 'linewidth' makes matplotlib raise.
    if 'lw' not in kwargs and 'linewidth' not in kwargs:
        kwargs['lw'] = 1
    axes.plot(fpr, tpr, **kwargs)

    if label_every is not None:
        _annotate_thresholds(axes, tpr, fpr, thresholds, label_every,
                             label_kwargs)

    if luck_label is not None:
        axes.plot((0, 1), (0, 1), '--', color='Gray', label=luck_label)

    lc = colorline(fpr, tpr, thresholds, axes=axes)
    figure.colorbar(lc, ax=axes, label='Threshold')

    axes.set_xlim([-0.05, 1.05])
    axes.set_ylim([-0.05, 1.05])
    if fpr_label is not None:
        axes.set_xlabel(fpr_label)
    if tpr_label is not None:
        axes.set_ylabel(tpr_label)
    if title is not None:
        axes.set_title(title)

    # Avoid matplotlib's "No artists with labels found" warning when nothing
    # on the axes carries a label.
    handles, _ = axes.get_legend_handles_labels()
    if handles:
        axes.legend(loc="lower right")
    return figure, axes
