Skip to content

Instantly share code, notes, and snippets.

@cfperez
Last active September 10, 2015 18:09
Show Gist options
  • Save cfperez/ff0cd71418ab5e3b4151 to your computer and use it in GitHub Desktop.
Save cfperez/ff0cd71418ab5e3b4151 to your computer and use it in GitHub Desktop.
radar_plot() wrapper function around matplotlib example
from __future__ import print_function
import matplotlib.pyplot as plt
from matplotlib.path import Path
from matplotlib.spines import Spine
from matplotlib.projections.polar import PolarAxes
from matplotlib.projections import register_projection
import numpy as np
def radar_plot(data, axes=None, colors=None, labels=None, ax=None, frame='circle', **kwargs):
    """Create a radar plot from a 2D array.

    Each row of ``data`` is drawn as one closed line; each column is a
    feature plotted along its own spoke.

    Parameters
    ----------
    data : 2D array-like or pandas.DataFrame, shape (samples, axes)
        Values to plot. When a DataFrame is given, its columns are used as
        the axis names and its index as the line labels.
    axes : sequence of str, optional
        Names of the spokes; if given, one entry per column of ``data``.
    colors : sequence, optional
        One matplotlib color spec per row of ``data``.
    labels : sequence, optional
        One legend label per row of ``data``.
    ax : axes with ``projection='radar'``, optional
        Existing radar axes to draw on. If omitted, a new figure and radar
        axes are created.
    frame : {'circle', 'polygon'}
        Shape of the frame surrounding the axes.
    fill : bool, optional (keyword)
        Fill the area inside each line. Defaults to True when ``alpha`` is
        passed.
    alpha : float, optional (keyword)
        Fill transparency between 0 and 1.
    **kwargs
        Other keyword arguments are forwarded to the axes' ``plotall``.

    Returns
    -------
    ax
        The radar axes the data was drawn on.

    Raises
    ------
    ValueError
        If ``data`` is empty, or if ``axes``, ``labels`` or ``colors`` do not
        match the shape of ``data``.
    """
    # Resolve list defaults here rather than in the signature so no list
    # object is shared between calls.
    axes = [] if axes is None else axes
    colors = [] if colors is None else colors
    labels = [] if labels is None else labels

    # pandas is optional. Only the import is guarded, so an ImportError raised
    # while plotting is not mistaken for "pandas is not installed".
    try:
        from pandas import DataFrame
    except ImportError:
        DataFrame = None
    if DataFrame is not None and isinstance(data, DataFrame):
        return radar_plot(data.values, axes=data.columns, labels=data.index,
                          colors=colors, ax=ax, frame=frame, **kwargs)

    num_rows = len(data)
    if num_rows == 0:
        raise ValueError('data must contain at least one row')
    num_features = max(map(len, data))

    # Use len() instead of truthiness: a pandas Index or numpy array (e.g.
    # from the DataFrame branch above) raises on bool().
    num_axes, num_colors, num_labels = len(axes), len(colors), len(labels)
    if num_axes > 0 and num_features != num_axes:
        raise ValueError('Number of named axes (%d) must match width of array (%d)' % (num_axes, num_features))
    if num_labels > 0 and num_labels != num_rows:
        raise ValueError('Number of labels (%d) must equal length (number of rows) in data (%d)' % (num_labels, num_rows))
    if num_colors > 0 and num_colors != num_rows:
        raise ValueError('Number of colors (%d) must match number of rows in data (%d)' % (num_colors, num_rows))

    # Registers the 'radar' projection; must run before add_subplot below.
    theta = radar_factory(num_features, frame=frame)
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='radar')
    ax.plotall(theta, data, labels=labels, axes=axes, colors=colors, **kwargs)
    # Return the axes so a caller that did not pass one can still reach it.
    return ax
def radar_factory(num_vars, frame='circle'):
    """Create a radar chart projection with `num_vars` axes and register it.

    Defines a ``RadarAxes`` subclass of ``PolarAxes`` whose projection name is
    ``'radar'`` and registers it, so that
    ``fig.add_subplot(..., projection='radar')`` can be used afterwards. Each
    call defines and registers a fresh class under the same name.

    Parameters
    ----------
    num_vars : int
        Number of variables (spokes) for the radar chart.
    frame : {'circle', 'polygon'}
        Shape of frame surrounding the axes.

    Returns
    -------
    theta : numpy.ndarray, shape (num_vars,)
        Spoke angles in radians. They are evenly spaced and rotated so the
        first spoke points straight up.

    Raises
    ------
    ValueError
        If `frame` is not 'circle' or 'polygon'.
    """
    # Evenly spaced spoke angles. The last one stops one step short of 2*pi
    # so it does not coincide with the first.
    theta = 2*np.pi * np.linspace(0, 1-1./num_vars, num_vars)
    # rotate theta such that the first axis is at the top
    theta += np.pi/2

    def draw_poly_patch(self):
        verts = unit_poly_verts(theta)
        return plt.Polygon(verts, closed=True, edgecolor='k')

    def draw_circle_patch(self):
        # circle of radius 0.5 centered on (0.5, 0.5)
        return plt.Circle((0.5, 0.5), 0.5)

    patch_dict = {'polygon': draw_poly_patch, 'circle': draw_circle_patch}
    if frame not in patch_dict:
        raise ValueError('unknown value for `frame`: %s' % frame)

    class RadarAxes(PolarAxes):

        name = 'radar'
        # use 1 line segment to connect specified points
        RESOLUTION = 1
        # frame-drawing function selected by `frame`; bound as a method
        draw_patch = patch_dict[frame]

        def fill(self, *args, **kwargs):
            """Override fill so that line is closed by default"""
            closed = kwargs.pop('closed', True)
            return super(RadarAxes, self).fill(closed=closed, *args, **kwargs)

        def plotall(self, theta, data, colors=None, axes=None, labels=None, *args, **kwargs):
            """Plot every row of `data` against the spoke angles `theta`.

            Parameters
            ----------
            theta : array-like
                Spoke angles in radians, as returned by radar_factory.
            data : sequence of sequences
                One row per line; each row holds one value per spoke.
            colors, labels : sequence, optional
                Per-row color and legend label. Missing entries are None.
            axes : sequence of str, optional
                Spoke names. If non-empty, they are set as theta-grid labels.
            fill : bool, optional (keyword)
                Fill each line's interior. Defaults to True when `alpha` is
                passed.
            alpha : float, optional (keyword)
                Transparency of the fill only; default 1.0.
            *args, **kwargs
                Forwarded to `plot`.
            """
            # zip_longest is Python 3; izip_longest is its Python 2 name.
            try:
                from itertools import zip_longest
            except ImportError:
                from itertools import izip_longest as zip_longest
            colors = [] if colors is None else colors
            axes = [] if axes is None else axes
            labels = [] if labels is None else labels
            # `'alpha' in kwargs` replaces the Python-2-only dict.has_key.
            fill = kwargs.pop('fill', 'alpha' in kwargs)
            alpha = kwargs.pop('alpha', 1.0)
            # len() instead of truthiness so numpy/pandas sequences work.
            if len(colors) > 0 and len(colors) != len(data):
                raise ValueError('color must be same length as data')
            for x, color, label in zip_longest(data, colors, labels):
                lines = self.plot(theta, x, label=label, color=color, *args, **kwargs)
                if fill:
                    # Reuse the line's auto-assigned color when none was given.
                    c = color if color is not None else lines[-1].get_color()
                    self.fill(theta, x, facecolor=c, alpha=alpha)
            if len(axes) > 0:
                self.set_varlabels(axes)

        def plot(self, *args, **kwargs):
            """Override plot so that line is closed by default"""
            lines = super(RadarAxes, self).plot(*args, **kwargs)
            for line in lines:
                self._close_line(line)
            return lines

        def _close_line(self, line):
            """Append the first point to the end of `line` if not already closed."""
            x, y = line.get_data()
            # FIXME: markers at x[0], y[0] get doubled-up
            # Guard on length: an empty line has no first point to repeat.
            if len(x) > 0 and x[0] != x[-1]:
                x = np.concatenate((x, [x[0]]))
                y = np.concatenate((y, [y[0]]))
                line.set_data(x, y)

        def set_varlabels(self, labels):
            """Label each spoke (theta grid line) with the given names."""
            self.set_thetagrids(theta * 180/np.pi, labels)

        def _gen_axes_patch(self):
            return self.draw_patch()

        def _gen_axes_spines(self):
            if frame == 'circle':
                return PolarAxes._gen_axes_spines(self)
            # The following is a hack to get the spines (i.e. the axes frame)
            # to draw correctly for a polygon frame.

            # spine_type must be 'left', 'right', 'top', 'bottom', or `circle`.
            spine_type = 'circle'
            verts = unit_poly_verts(theta)
            # close off polygon by repeating first vertex
            verts.append(verts[0])
            path = Path(verts)

            spine = Spine(self, spine_type, path)
            spine.set_transform(self.transAxes)
            return {'polar': spine}

    register_projection(RadarAxes)
    return theta
def unit_poly_verts(theta):
    """Return polygon vertices, in axes coordinates, for the radar frame.

    One vertex is placed at each angle in `theta` (radians) on the circle of
    radius 0.5 centered at (0.5, 0.5). This is the circle inscribed in the
    unit axes square, so the polygon is inscribed in that circle.

    Returns
    -------
    list of (x, y) tuples
        A list rather than an array, because callers append the first vertex
        to close the path.
    """
    center_x, center_y = 0.5, 0.5
    radius = 0.5
    return [(radius*np.cos(t) + center_x, radius*np.sin(t) + center_y) for t in theta]
if __name__ == '__main__':
    print("Radar plot test")
    demo_data = [[0.87, 0.01, 0.08, 0.00, 0.00, 0.04, 0.00, 0.00, 0.01],
                 [0.09, 0.95, 0.02, 0.03, 0.00, 0.01, 0.13, 0.06, 0.00],
                 [0.01, 0.02, 0.71, 0.24, 0.13, 0.16, 0.00, 0.50, 0.00],
                 [0.01, 0.03, 0.00, 0.28, 0.24, 0.23, 0.00, 0.44, 0.88],
                 [0.02, 0.00, 0.18, 0.45, 0.64, 0.55, 0.86, 0.00, 0.16]]
    demo_axes = ['Sulfate', 'Nitrate', 'EC', 'OC1', 'OC2', 'OC3', 'OP', 'CO', 'O3']

    # The 'radar' projection only exists after radar_factory() has run, so
    # register it before asking matplotlib for a radar subplot.
    radar_factory(len(demo_axes))
    fig = plt.figure()
    ax = fig.add_subplot(221, projection='radar')
    radar_plot(demo_data, axes=demo_axes, fill=True, alpha=.25, ax=ax)
    # Without show() nothing is displayed when run as a non-interactive script.
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment