Created
December 8, 2022 11:21
-
-
Save zaemyung/4bee83b94a93e966261dd456a9924c7a to your computer and use it in GitHub Desktop.
Plotting Radar graph
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import textwrap | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import pandas as pd | |
class ComplexRadar():
    """
    Create a complex radar chart with different scales for each variable.

    One polar axes is stacked per variable (plus a duplicate of the first one),
    each carrying its own radial scale. Data is rescaled onto the range of the
    first variable and drawn on the duplicate axes.

    Parameters
    ----------
    fig : figure object
        A matplotlib figure object to add the axes on
    variables : list
        A list of variables
    ranges : list
        A list of tuples (min, max) for each variable
    n_ring_levels: int, defaults to 5
        Number of ordinate or ring levels to draw
    show_scales: bool, defaults to True
        Indicates if the ranges for each variable are plotted
    format_cfg: dict, defaults to None
        A dictionary with formatting configurations. A key present here
        replaces the matching default entry as a whole (nested dictionaries
        are not merged); keys that have no default are ignored.
    """

    def __init__(self, fig, variables, ranges, n_ring_levels=5, show_scales=True, format_cfg=None):
        # Default formatting
        defaults = {
            # Axes
            # https://matplotlib.org/stable/api/figure_api.html
            'axes_args': {},
            # Tick labels on the scales
            # https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.rgrids.html
            'rgrid_tick_lbls_args': {'fontsize':8},
            # Radial (circle) lines
            # https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.grid.html
            'rad_ln_args': {},
            # Angle lines
            # https://matplotlib.org/3.2.2/api/_as_gen/matplotlib.lines.Line2D.html#matplotlib.lines.Line2D
            'angle_ln_args': {},
            # Include last value (endpoint) on scale
            'incl_endpoint':True,
            # Variable labels (ThetaTickLabel)
            'theta_tick_lbls':{'va':'top', 'ha':'center'},
            'theta_tick_lbls_txt_wrap':15,
            'theta_tick_lbls_brk_lng_wrds':False,
            'theta_tick_lbls_pad':25,
            # Outer ring
            # https://matplotlib.org/stable/api/spines_api.html
            'outer_ring':{'visible':True, 'color':'#d6d6d6'}
        }
        if format_cfg is None:
            self.format_cfg = defaults
        else:
            # Per-key override: only known keys are taken from the caller.
            self.format_cfg = {
                key: (format_cfg[key] if key in format_cfg else default)
                for key, default in defaults.items()
            }

        # One angle (in degrees) per variable. linspace with endpoint=False
        # yields exactly len(variables) values, whereas arange with a float
        # step can produce one extra element due to rounding.
        n_variables = len(variables)
        angles = np.linspace(0, 360, num=n_variables, endpoint=False)

        # Create one axes per variable.
        # Consider here the trick with having the first axes element twice (len+1)
        axes = [
            fig.add_axes([0.1, 0.1, 0.9, 0.9],
                         polar=True,
                         label=f"axes{i}",
                         **self.format_cfg['axes_args'])
            for i in range(n_variables + 1)
        ]

        # Ensure clockwise rotation (first variable at the top N)
        for ax in axes:
            ax.set_theta_zero_location('N')
            ax.set_theta_direction(-1)
            ax.set_axisbelow(True)

        # Writing the ranges on each axes
        for i, ax in enumerate(axes):
            # Here we do the trick by repeating the first iteration:
            # axes 0 and 1 both use variable 0, axes k (k >= 2) uses variable k-1.
            j = max(i - 1, 0)
            ax.set_ylim(*ranges[j])
            # Set endpoint to True if you like to have values right before the last circle
            grid = np.linspace(*ranges[j], num=n_ring_levels,
                               endpoint=self.format_cfg['incl_endpoint'])
            gridlabel = [f"{round(x, 2)}" for x in grid]
            gridlabel[0] = ""  # remove values from the center
            ax.set_rgrids(grid,
                          labels=gridlabel,
                          angle=angles[j],
                          **self.format_cfg['rgrid_tick_lbls_args'])
            # The limits are applied a second time after the radial grid is set.
            ax.set_ylim(*ranges[j])
            ax.spines["polar"].set_visible(False)
            ax.grid(visible=False)
            if not show_scales:
                ax.set_yticklabels([])

        # Set all axes except the first one invisible
        for ax in axes[1:]:
            ax.patch.set_visible(False)
            ax.xaxis.set_visible(False)

        # Setting the attributes
        # The first angle is appended again so plotted lines close the polygon.
        self.angle = np.deg2rad(np.r_[angles, angles[0]])
        self.ranges = ranges
        self.ax = axes[0]
        self.ax1 = axes[1]
        self.plot_counter = 0

        # Draw (inner) circles and lines
        self.ax.yaxis.grid(**self.format_cfg['rad_ln_args'])
        # Draw outer circle
        self.ax.spines['polar'].set(**self.format_cfg['outer_ring'])
        # Draw angle lines
        self.ax.xaxis.grid(**self.format_cfg['angle_ln_args'])

        # ax1 is the duplicate of axes[0] (self.ax)
        # Remove everything from ax1 except the plot itself
        self.ax1.axis('off')
        self.ax1.set_zorder(9)

        # Create the outer labels for each variable
        self.ax.set_thetagrids(angles, labels=variables)

        # Beautify them: wrap long labels onto several lines
        wrap_width = self.format_cfg['theta_tick_lbls_txt_wrap']
        break_long_words = self.format_cfg['theta_tick_lbls_brk_lng_wrds']
        wrapped_labels = [
            '\n'.join(textwrap.wrap(tick.get_text(), wrap_width,
                                    break_long_words=break_long_words))
            for tick in self.ax.get_xticklabels()
        ]
        self.ax.set_xticklabels(wrapped_labels, **self.format_cfg['theta_tick_lbls'])

        # Align each label depending on which side of the chart it sits on
        for tick, angle_deg in zip(self.ax.get_xticklabels(), angles):
            tick.set_ha(self._horizontal_alignment(angle_deg))

        self.ax.tick_params(axis='both', pad=self.format_cfg['theta_tick_lbls_pad'])

    @staticmethod
    def _horizontal_alignment(angle_deg):
        """Return the horizontal alignment for a label placed at ``angle_deg``.

        Labels at 0 and 180 degrees are centered, labels strictly between
        0 and 180 degrees are left-aligned, all others are right-aligned.
        The angles are computed floats, so 0 and 180 are matched with a small
        absolute tolerance instead of exact equality.
        """
        tolerance = 1e-9
        for centered_at in (0.0, 180.0):
            if np.isclose(angle_deg, centered_at, rtol=0.0, atol=tolerance):
                return 'center'
        return 'left' if 0.0 < angle_deg < 180.0 else 'right'

    def _scale_data(self, data, ranges):
        """Scales data[1:] to ranges[0]

        Parameters
        ----------
        data : sequence
            One value per variable. The first value is passed through
            unchanged and is not range-checked.
        ranges : sequence
            One (min, max) tuple per variable; a range may be given in
            descending order.

        Returns
        -------
        list
            ``data[0]`` followed by the remaining values mapped linearly from
            their own range onto ``ranges[0]``.

        Raises
        ------
        ValueError
            If any of ``data[1:]`` lies outside its range. This is an explicit
            check rather than an ``assert`` so that it still runs under
            ``python -O``.
        """
        x1, x2 = ranges[0]
        sdata = [data[0]]
        for position, (d, (y1, y2)) in enumerate(zip(data[1:], ranges[1:]), start=1):
            # Accept both ascending and descending ranges.
            if not ((y1 <= d <= y2) or (y2 <= d <= y1)):
                raise ValueError(
                    f"data[{position}]={d!r} is outside its range ({y1!r}, {y2!r})"
                )
            sdata.append((d - y1) / (y2 - y1) * (x2 - x1) + x1)
        return sdata

    def plot(self, data, *args, **kwargs):
        """Plots a line

        Extra positional and keyword arguments are forwarded to the axes'
        ``plot`` call. Each call increments ``plot_counter``.
        """
        sdata = self._scale_data(data, self.ranges)
        # Repeat the first point so the line closes.
        self.ax1.plot(self.angle, np.r_[sdata, sdata[0]], *args, **kwargs)
        self.plot_counter += 1

    def fill(self, data, *args, **kwargs):
        """Plots an area

        Extra positional and keyword arguments are forwarded to the axes'
        ``fill`` call.
        """
        sdata = self._scale_data(data, self.ranges)
        self.ax1.fill(self.angle, np.r_[sdata, sdata[0]], *args, **kwargs)

    def use_legend(self, *args, **kwargs):
        """Shows a legend (arguments are forwarded to the axes' ``legend``)"""
        self.ax1.legend(*args, **kwargs)

    def set_title(self, title, pad=25, **kwargs):
        """Set a title"""
        self.ax.set_title(title, pad=pad, **kwargs)
if __name__ == '__main__':
    # Demo: draw one radar line per revision model from its SARI scores
    # and save the chart to "sari_results.png".
    variables = ["Clarity", "Coherence", "Fluency", "Style", "Iterater"]
    sari_scores = {
        "No Edit": [23.6, 31.3, 25.96, 15.51, 29.88],
        "Iterater": [27.5, 33.79, 36.39, 18.9, 43.22],
        "Deliterater(-DA)": [51.48, 51.08, 61.49, 36.99, 62.54],
        "Deliterater(+DA)": [58.7, 81.23, 73.95, 61.44, 64.09],
    }
    # Rows are models, columns are the variables.
    data = pd.DataFrame.from_dict(sari_scores, columns=variables, orient="index")
    print(data)

    # Per-variable axis range: the observed minimum rounded down and the
    # observed maximum rounded up to whole numbers. math.floor is used for the
    # lower bound because int() truncates toward zero, which would round a
    # negative minimum up and leave the smallest value outside the range.
    variables = data.columns
    ranges = [
        (math.floor(lowest), math.ceil(highest))
        for lowest, highest in zip(data.min(), data.max())
    ]
    print(ranges)

    format_cfg = {
        'rad_ln_args': {'visible':False},
        'outer_ring': {'visible':False},
        'rgrid_tick_lbls_args': {'fontsize': 10},
        'theta_tick_lbls': {'fontsize': 10},
        'theta_tick_lbls_pad': 5
    }
    print(variables)

    fig1 = plt.figure(figsize=(3, 3), dpi=100)
    radar = ComplexRadar(fig1, variables, ranges, show_scales=True, format_cfg=format_cfg)

    custom_colors = ['#f9b4ab', '#264e70', '#679186', '#f58231']
    for model_name, color in zip(data.index, custom_colors):
        radar.plot(data.loc[model_name].values, label=f"{model_name}", color=color)
        # To shade the area as well, additionally call:
        # radar.fill(data.loc[model_name].values, alpha=0.5, color=color)

    radar.set_title("SARI Test Scores for Revision Models", pad=50)
    # Alternative: a single-row legend centered below the chart, e.g.
    # loc='lower center', bbox_to_anchor=(0.5, -0.25), ncol=radar.plot_counter
    radar.use_legend(loc='center left', bbox_to_anchor=(0.5, -0.25), ncol=1)
    plt.savefig("sari_results.png", bbox_inches='tight', dpi=200)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment