Created November 3, 2020 12:12
-
-
Save lucienne999/be11058e2cf13aea8f4a66a04c1d977c to your computer and use it in GitHub Desktop.
plot rank
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from matplotlib import pyplot as plt | |
# Per-candidate scores from two evaluations of the same candidates.
# Position i in both lists presumably refers to the same candidate
# (supernet-estimated vs. post-retraining score -- confirm with the author).
# The two lists must have the same length to be passed to pplot together.
supernet = [
    0.524, 0.569, 0.588, 0.633, 0.438,
]
retrain = [
    0.734, 0.562, 0.556, 0.581, 0.373,
]
def make_patch_spines_invisible(ax):
    """Hide the background patch and every spine of *ax*.

    The frame flag is explicitly left on, so a caller can afterwards
    re-enable just the individual spine(s) it wants to show.

    Args:
        ax: A matplotlib Axes; modified in place. Returns None.
    """
    background = ax.patch
    all_spines = ax.spines.values()
    ax.set_frame_on(True)
    background.set_visible(False)
    for spine in all_spines:
        spine.set_visible(False)
def _descending_ranks(values):
    """Return the 1-based rank of each entry of *values* (rank 1 = largest).

    A stable sort on the negated values is used, so tied scores are ranked
    in order of appearance (the earlier index receives the better rank).
    """
    order = np.argsort(-values, kind="stable")  # item indices, best first
    ranks = np.empty(len(values), dtype=int)
    # Invert the permutation: the item at position k of `order` has rank k+1.
    ranks[order] = np.arange(1, len(values) + 1)
    return ranks


def pplot(x, y, title, *, x_label="SuperNet", y_label="Retrain", show=True):
    """Draw a rank-comparison ("bump") chart for two scorings of the same items.

    Each item i is one line segment from its rank under ``x`` (left, at
    x-coordinate 0) to its rank under ``y`` (right, at x-coordinate 1).
    Rank 1 is the highest score ("Top-1 means best"). Segments that do not
    cross mean the two scorings agree on the relative order of those items.

    Args:
        x: 1-D sequence of scores shown on the left axis.
        y: 1-D sequence of scores shown on the right axis; ``y[i]`` must
            belong to the same item as ``x[i]``.
        title: Title of the figure.
        x_label: Label of the left (``x``) rank axis.
        y_label: Label of the right (``y``) rank axis.
        show: If True (the default), call ``plt.show()`` before returning.

    Returns:
        The created matplotlib Figure.

    Raises:
        ValueError: if ``x`` or ``y`` is not one-dimensional, if they differ
            in length, or if they are empty.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.ndim != 1 or y.ndim != 1:
        raise ValueError("x and y must be one-dimensional sequences of scores")
    if len(x) != len(y):
        raise ValueError(
            f"x and y must have the same length, got {len(x)} and {len(y)}"
        )
    if len(x) == 0:
        raise ValueError("x and y must not be empty")
    nums = len(x)

    ranks_x = _descending_ranks(x)
    ranks_y = _descending_ranks(y)

    fig, ax_left = plt.subplots()
    ax_right = ax_left.twinx()  # shares the x-axis, independent right y-axis
    # Give both y-axes the same integer rank scale so each end of a segment
    # can be read off the axis on its own side.
    for axis in (ax_left, ax_right):
        axis.set_xlim(-0.05, 1.05)
        axis.set_ylim(0.5, nums + 0.5)
        axis.set_yticks(range(1, nums + 1))
    ax_left.set_xticks([])  # x only distinguishes the left/right columns
    ax_left.set_ylabel(x_label)
    ax_right.set_ylabel(y_label)
    ax_left.set_title(title)

    # One segment per item: its rank under x -> its rank under y.
    for rank_left, rank_right in zip(ranks_x, ranks_y):
        ax_left.plot([0, 1], [rank_left, rank_right])

    if show:
        plt.show()
    return fig
if __name__ == "__main__":
    # Compare how the two evaluations order the candidates.
    # Top-1 means best: rank 1 is the highest score.
    pplot(x=supernet, y=retrain, title="RANK")
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.