Skip to content

Instantly share code, notes, and snippets.

@warmlogic
Forked from cfperez/radar_plot.py
Last active August 29, 2015 14:26
Show Gist options
  • Save warmlogic/edc3e4cbcec09a1f9a41 to your computer and use it in GitHub Desktop.
Save warmlogic/edc3e4cbcec09a1f9a41 to your computer and use it in GitHub Desktop.
radar_plot() wrapper function around matplotlib example
from __future__ import print_function
import matplotlib.pyplot as plt
from matplotlib.path import Path
from matplotlib.spines import Spine
from matplotlib.projections.polar import PolarAxes
from matplotlib.projections import register_projection
import numpy as np
import seaborn as sns
def radar_plot(data, feat_labels=None, line_labels=None, colors=None, ax=None, frame='circle', **kwargs):
    """Create a radar plot from a 2D array.

    Each row of ``data`` is drawn as one closed line; each column is one
    feature (one spoke of the radar).

    Parameters
    ----------
    data : 2D array-like, shape (samples, features)
        Values to plot.
    feat_labels : sequence of str, optional
        Names of the features. If non-empty, its length must equal the
        number of columns of ``data``.
    line_labels : sequence of str, optional
        Legend label for each line. Defaults to '1', '2', ... .
    colors : sequence of matplotlib color specs, optional
        One color per sample. Defaults to a seaborn palette sampled from the
        'nipy_spectral' colormap. If non-empty, its length must equal the
        number of rows of ``data``.
    ax : matplotlib Axes, optional
        Axes created with ``projection='radar'``. If omitted, a new figure
        and radar axes are created.
    frame : {'circle', 'polygon'}
        Shape of the frame surrounding the axes.
    **kwargs
        Forwarded to ``RadarAxes.plotall``: ``fill`` (bool, default True)
        fills each line and ``alpha`` (float 0-1, default 0.25) sets the fill
        transparency; any other keywords are passed on to ``plot``.

    Returns
    -------
    ax : the radar Axes that was drawn on.

    Raises
    ------
    ValueError
        If the number of feature labels or of colors does not match ``data``.
    """
    if feat_labels is None:
        feat_labels = []
    num_samples = len(data)
    num_features = len(data[0])
    if colors is None:
        # matplotlib removed the old 'spectral' colormap name; its
        # replacement is 'nipy_spectral'.
        colors = sns.color_palette('nipy_spectral', num_samples)
    num_feat_lab, num_colors = len(feat_labels), len(colors)
    if num_feat_lab > 0 and num_features != num_feat_lab:
        raise ValueError('Number of features labels %d must match data features dimension %d' % (num_feat_lab, num_features))
    if num_colors > 0 and num_samples != num_colors:
        raise ValueError('Number of colors %d must match data samples dimension %d' % (num_colors, num_samples))
    # Registers the 'radar' projection for this number of spokes / frame.
    theta = radar_factory(num_features, frame=frame)
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='radar')
    if line_labels is None:
        line_labels = ['%d' % (x + 1) for x in range(num_samples)]
    ax.plotall(theta, data, feat_labels=feat_labels, line_labels=line_labels, colors=colors, **kwargs)
    # Attach the legend to the axes we drew on; plt.legend() would target
    # whichever axes happens to be current.
    legend = ax.legend(loc=(0.9, .95), labelspacing=0.1)
    plt.setp(legend.get_texts(), fontsize='small')
    return ax
def radar_factory(num_vars, frame='circle'):
    """Create a radar chart with `num_vars` axes.

    This function defines a ``RadarAxes`` projection (a ``PolarAxes``
    subclass named ``'radar'``) and registers it with matplotlib, so that
    ``fig.add_subplot(..., projection='radar')`` can be used afterwards.
    Each call defines and registers a fresh class under the same name.

    Parameters
    ----------
    num_vars : int
        Number of variables (spokes) for the radar chart.
    frame : {'circle' | 'polygon'}
        Shape of frame surrounding axes.

    Returns
    -------
    theta : ndarray
        Angles of the spokes in radians: `num_vars` evenly spaced values,
        shifted by pi/2 so the first spoke points to the top.

    Raises
    ------
    ValueError
        If `frame` is not 'circle' or 'polygon'.
    """
    # calculate evenly-spaced axis angles
    theta = 2 * np.pi * np.linspace(0, 1 - 1. / num_vars, num_vars)
    # rotate theta such that the first axis is at the top
    theta += np.pi / 2

    def draw_poly_patch(self):
        verts = unit_poly_verts(theta)
        return plt.Polygon(verts, closed=True, edgecolor='k')

    def draw_circle_patch(self):
        # unit circle centered on (0.5, 0.5)
        return plt.Circle((0.5, 0.5), 0.5)

    patch_dict = {'polygon': draw_poly_patch, 'circle': draw_circle_patch}
    if frame not in patch_dict:
        raise ValueError('unknown value for `frame`: %s' % frame)

    class RadarAxes(PolarAxes):

        name = 'radar'
        # use 1 line segment to connect specified points
        RESOLUTION = 1
        # factory for the axes background patch, chosen by `frame`
        draw_patch = patch_dict[frame]

        def fill(self, *args, **kwargs):
            """Override fill so that line is closed by default"""
            closed = kwargs.pop('closed', True)
            return super(RadarAxes, self).fill(closed=closed, *args, **kwargs)

        def plotall(self, theta, data, feat_labels=None, line_labels=None, colors=None, *args, **kwargs):
            """Plot (and optionally fill) every row of `data` as a closed line.

            Parameters
            ----------
            theta : spoke angles in radians (as returned by radar_factory)
            data : 2D array-like; each row becomes one line
            feat_labels : optional labels for the spokes
            line_labels : optional legend label per row; rows beyond the
                end of this sequence are plotted with label=None, and
                surplus labels are ignored
            colors : optional color per row; if non-empty it must have the
                same length as `data`
            fill : keyword, bool (default True) -- fill each line's area
            alpha : keyword, float (default .25) -- fill transparency
            Remaining positional/keyword arguments are forwarded to plot().

            Raises
            ------
            ValueError
                If `colors` is non-empty and its length differs from `data`.
            """
            feat_labels = [] if feat_labels is None else feat_labels
            line_labels = [] if line_labels is None else line_labels
            colors = [] if colors is None else colors
            fill = kwargs.pop('fill', True)
            alpha = kwargs.pop('alpha', .25)
            # len() rather than truthiness so numpy arrays / palettes work.
            if len(colors) > 0 and len(colors) != len(data):
                raise ValueError('color must be same length as data')
            for i, x in enumerate(data):
                color = colors[i] if i < len(colors) else None
                label = line_labels[i] if i < len(line_labels) else None
                lines = self.plot(theta, x, label=label, color=color, *args, **kwargs)
                if fill:
                    # Without an explicit color, reuse the one plot() picked.
                    c = color if color is not None else lines[-1].get_color()
                    self.fill(theta, x, facecolor=c, alpha=alpha)
            if len(feat_labels) > 0:
                self.set_varlabels(feat_labels)

        def plot(self, *args, **kwargs):
            """Override plot so that line is closed by default"""
            lines = super(RadarAxes, self).plot(*args, **kwargs)
            for line in lines:
                self._close_line(line)
            return lines

        def _close_line(self, line):
            """Append the first point to the end of `line` if not already closed."""
            x, y = line.get_data()
            # FIXME: markers at x[0], y[0] get doubled-up
            if len(x) > 0 and x[0] != x[-1]:
                x = np.concatenate((x, [x[0]]))
                y = np.concatenate((y, [y[0]]))
                line.set_data(x, y)

        def set_varlabels(self, line_labels):
            """Label the spokes; `line_labels` holds one label per spoke."""
            self.set_thetagrids(theta * 180 / np.pi, line_labels)

        def _gen_axes_patch(self):
            return self.draw_patch()

        def _gen_axes_spines(self):
            if frame == 'circle':
                return PolarAxes._gen_axes_spines(self)
            # The following is a hack to get the spines (i.e. the axes frame)
            # to draw correctly for a polygon frame.

            # spine_type must be 'left', 'right', 'top', 'bottom', or `circle`.
            spine_type = 'circle'
            verts = unit_poly_verts(theta)
            # close off polygon by repeating first vertex
            verts.append(verts[0])
            path = Path(verts)

            spine = Spine(self, spine_type, path)
            spine.set_transform(self.transAxes)
            return {'polar': spine}

    register_projection(RadarAxes)
    return theta
def unit_poly_verts(theta):
    """Return the vertices of the polygon used as a radar subplot's frame.

    Every angle in ``theta`` (radians) contributes one vertex lying on the
    circle of radius 0.5 centred at (0.5, 0.5); the polygon is therefore
    inscribed in that circle.

    The result is a list of ``(x, y)`` tuples, in the same order as
    ``theta``. It is a plain list so callers may append to it, e.g. to
    repeat the first vertex and close the outline.
    """
    center_x = center_y = radius = 0.5
    vertices = []
    for angle in theta:
        vertices.append((radius * np.cos(angle) + center_x,
                         radius * np.sin(angle) + center_y))
    return vertices
if __name__ == '__main__':
    # Demo: plot five sample profiles over nine features.
    print("Radar plot test")
    demo_feat_labels = ['Sulfate', 'Nitrate', 'EC', 'OC1', 'OC2', 'OC3', 'OP', 'CO', 'O3']
    demo_data = [[0.87, 0.01, 0.08, 0.00, 0.00, 0.04, 0.00, 0.00, 0.01],
                 [0.09, 0.95, 0.02, 0.03, 0.00, 0.01, 0.13, 0.06, 0.00],
                 [0.01, 0.02, 0.71, 0.24, 0.13, 0.16, 0.00, 0.50, 0.00],
                 [0.01, 0.03, 0.00, 0.28, 0.24, 0.23, 0.00, 0.44, 0.88],
                 [0.02, 0.00, 0.18, 0.45, 0.64, 0.55, 0.86, 0.00, 0.16]]
    # The 'radar' projection only exists after radar_factory() has registered
    # it, so register it before asking for an axes with projection='radar'.
    radar_factory(len(demo_feat_labels))
    fig = plt.figure()
    ax = fig.add_subplot(221, projection='radar')
    radar_plot(demo_data,
               feat_labels=demo_feat_labels,
               fill=True, alpha=.25,
               ax=ax)
    # Without show(), running this file as a script displays nothing.
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment