Last active
September 10, 2015 18:09
-
-
Save cfperez/ff0cd71418ab5e3b4151 to your computer and use it in GitHub Desktop.
radar_plot() wrapper function around matplotlib example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function | |
import matplotlib.pyplot as plt | |
from matplotlib.path import Path | |
from matplotlib.spines import Spine | |
from matplotlib.projections.polar import PolarAxes | |
from matplotlib.projections import register_projection | |
import numpy as np | |
def radar_plot(data, axes=None, colors=None, labels=None, ax=None, frame='circle', **kwargs):
    '''Create radar plot from 2D array. Each row is a closed line; each column is a feature/axis.

    Parameters
    ----------
    data    2D array-like with dimensions (samples, axes), or a pandas DataFrame.
            For a DataFrame, its columns are used as the axis names and its
            index as the line labels unless `axes` / `labels` are given.
    axes    names of axes (one per column of `data`)
    colors  list of mpl color codes, one per row of `data`
    labels  list of labels, one per row of `data`
    ax      existing axes created with projection='radar'; a new figure with a
            single radar subplot is created when None
    frame   'circle' or 'polygon': shape of the frame around the plot
    fill    bool: fill lines (defaults to True when `alpha` is given)
    alpha   float 0 to 1: fill transparency

    Returns
    -------
    The axes the lines were drawn on.

    Raises
    ------
    ValueError  if `data` is empty, or if `axes`, `labels` or `colors` do not
                match the shape of `data`, or `frame` is unknown.
    '''
    # pandas is optional: only the import itself is guarded, so errors raised
    # while plotting a DataFrame are not mistaken for a missing pandas.
    try:
        from pandas import DataFrame
    except ImportError:
        DataFrame = None
    if DataFrame is not None and isinstance(data, DataFrame):
        # Explicitly passed axis names / labels take precedence over the
        # DataFrame's own columns / index.
        if axes is None or len(axes) == 0:
            axes = data.columns
        if labels is None or len(labels) == 0:
            labels = data.index
        return radar_plot(data.values, axes=axes, labels=labels, colors=colors, ax=ax, frame=frame, **kwargs)

    # None defaults avoid shared mutable default lists; use len() (not
    # truthiness) below so pandas Index / numpy arrays are accepted too.
    axes = [] if axes is None else axes
    colors = [] if colors is None else colors
    labels = [] if labels is None else labels

    num_rows = len(data)
    if num_rows == 0:
        raise ValueError('data must contain at least one row')
    # rows may be ragged; the widest row determines the number of spokes
    num_features = max(map(len, data))
    num_axes, num_colors, num_labels = len(axes), len(colors), len(labels)
    if num_axes > 0 and num_features != num_axes:
        raise ValueError('Number of named axes (%d) must match width of array (%d)' % (num_axes, num_features))
    if num_labels > 0 and num_labels != num_rows:
        raise ValueError('Number of labels (%d) must equal length (number of rows) in data (%d)' % (num_labels, num_rows))
    if num_colors > 0 and num_colors != num_rows:
        raise ValueError('Number of colors (%d) must match number of rows in data (%d)' % (num_colors, num_rows))

    # registers (or re-registers) the 'radar' projection for this many spokes
    theta = radar_factory(num_features, frame=frame)
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='radar')
    ax.plotall(theta, data, labels=labels, axes=axes, colors=colors, **kwargs)
    return ax
def radar_factory(num_vars, frame='circle'):
    """Create a radar chart with `num_vars` axes.

    This function defines a ``RadarAxes`` projection (a ``PolarAxes``
    subclass named ``'radar'``) and registers it with matplotlib, so that
    ``add_subplot(..., projection='radar')`` works afterwards. Every call
    registers a new class under the same name.

    Parameters
    ----------
    num_vars : int
        Number of variables (spokes) for radar chart; must be at least 1.
    frame : {'circle' | 'polygon'}
        Shape of frame surrounding axes.

    Returns
    -------
    theta : numpy.ndarray
        The ``num_vars`` spoke angles in radians. They are evenly spaced
        around the circle and rotated so that the first spoke points up.

    Raises
    ------
    ValueError
        If `num_vars` is less than 1 or `frame` is not a known frame shape.
    """
    if num_vars < 1:
        raise ValueError('num_vars must be at least 1, got %r' % (num_vars,))
    # calculate evenly-spaced axis angles (2*pi itself is excluded so the
    # first and last spokes do not coincide)
    theta = 2*np.pi * np.linspace(0, 1-1./num_vars, num_vars)
    # rotate theta such that the first axis is at the top
    theta += np.pi/2

    def draw_poly_patch(self):
        # polygon with one vertex on each spoke
        verts = unit_poly_verts(theta)
        return plt.Polygon(verts, closed=True, edgecolor='k')

    def draw_circle_patch(self):
        # circle of radius 0.5 centered on (0.5, 0.5) in axes coordinates
        return plt.Circle((0.5, 0.5), 0.5)

    patch_dict = {'polygon': draw_poly_patch, 'circle': draw_circle_patch}
    if frame not in patch_dict:
        raise ValueError('unknown value for `frame`: %s' % frame)

    class RadarAxes(PolarAxes):
        """PolarAxes whose plotted lines are closed and whose frame is `frame`."""

        name = 'radar'
        # use 1 line segment to connect specified points
        # NOTE(review): newer matplotlib versions may ignore RESOLUTION — confirm
        RESOLUTION = 1
        # frame-patch factory selected by `frame`; becomes a bound method
        draw_patch = patch_dict[frame]

        def fill(self, *args, **kwargs):
            """Override fill so that line is closed by default"""
            closed = kwargs.pop('closed', True)
            return super(RadarAxes, self).fill(closed=closed, *args, **kwargs)

        def plotall(self, theta, data, colors=None, axes=None, labels=None, *args, **kwargs):
            """Plot every row of `data` as a closed line over the spokes `theta`.

            `colors` and `labels` are matched to rows positionally; missing
            entries are None. `axes`, if non-empty, labels the spokes. The
            extra keyword `fill` (default: True if `alpha` is given) also
            fills each line, using `alpha` (default 1.0) as the fill
            transparency. Remaining args/kwargs are passed to `plot`.
            Returns this axes object.
            """
            try:
                from itertools import zip_longest
            except ImportError:  # Python 2
                from itertools import izip_longest as zip_longest
            # None defaults avoid shared mutable default lists
            colors = [] if colors is None else colors
            axes = [] if axes is None else axes
            labels = [] if labels is None else labels
            # 'fill' must be resolved before 'alpha' is popped out of kwargs
            fill = kwargs.pop('fill', 'alpha' in kwargs)
            alpha = kwargs.pop('alpha', 1.0)
            if len(colors) and len(colors) != len(data):
                raise ValueError('color must be same length as data')
            for x, color, label in zip_longest(data, colors, labels):
                lines = self.plot(theta, x, label=label, color=color, *args, **kwargs)
                if fill:
                    # reuse the color matplotlib picked when none was given
                    c = color if color is not None else lines[-1].get_color()
                    self.fill(theta, x, facecolor=c, alpha=alpha)
            if len(axes) > 0:
                # label the spokes at the angles actually plotted above
                self.set_thetagrids(np.degrees(theta), axes)
            return self

        def plot(self, *args, **kwargs):
            """Override plot so that line is closed by default"""
            lines = super(RadarAxes, self).plot(*args, **kwargs)
            for line in lines:
                self._close_line(line)
            return lines

        def _close_line(self, line):
            # append the first point to the end so the line forms a loop
            x, y = line.get_data()
            # FIXME: markers at x[0], y[0] get doubled-up
            if len(x) and x[0] != x[-1]:
                x = np.concatenate((x, [x[0]]))
                y = np.concatenate((y, [y[0]]))
                line.set_data(x, y)

        def set_varlabels(self, labels):
            """Label the spokes created by this factory call with `labels`."""
            self.set_thetagrids(theta * 180/np.pi, labels)

        def _gen_axes_patch(self):
            return self.draw_patch()

        def _gen_axes_spines(self):
            if frame == 'circle':
                return PolarAxes._gen_axes_spines(self)
            # The following is a hack to get the spines (i.e. the axes frame)
            # to draw correctly for a polygon frame.

            # spine_type must be 'left', 'right', 'top', 'bottom', or `circle`.
            spine_type = 'circle'
            verts = unit_poly_verts(theta)
            # close off polygon by repeating first vertex
            verts.append(verts[0])
            path = Path(verts)

            spine = Spine(self, spine_type, path)
            spine.set_transform(self.transAxes)
            return {'polar': spine}

    register_projection(RadarAxes)
    return theta
def unit_poly_verts(theta):
    """Return polygon vertices, in axes coordinates, for the angles in `theta`.

    Each vertex lies on the circle of radius 0.5 centered at (0.5, 0.5), so
    the polygon is inscribed in the axes' bounding circle. The result is a
    new list of (x, y) tuples, one per angle, in the order of `theta`.
    """
    center = radius = 0.5
    return [
        (center + radius * np.cos(angle), center + radius * np.sin(angle))
        for angle in theta
    ]
if __name__ == '__main__':
    print("Radar plot test")
    demo_data = [[0.87, 0.01, 0.08, 0.00, 0.00, 0.04, 0.00, 0.00, 0.01],
                 [0.09, 0.95, 0.02, 0.03, 0.00, 0.01, 0.13, 0.06, 0.00],
                 [0.01, 0.02, 0.71, 0.24, 0.13, 0.16, 0.00, 0.50, 0.00],
                 [0.01, 0.03, 0.00, 0.28, 0.24, 0.23, 0.00, 0.44, 0.88],
                 [0.02, 0.00, 0.18, 0.45, 0.64, 0.55, 0.86, 0.00, 0.16]]
    demo_axes = ['Sulfate', 'Nitrate', 'EC', 'OC1', 'OC2', 'OC3', 'OP', 'CO', 'O3']
    # The 'radar' projection does not exist until radar_factory() registers
    # it, so register it before requesting a radar subplot.
    radar_factory(len(demo_axes))
    fig = plt.figure()
    ax = fig.add_subplot(221, projection='radar')
    radar_plot(demo_data, axes=demo_axes, fill=True, alpha=.25, ax=ax)
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment