Skip to content

Instantly share code, notes, and snippets.

@ryanpeach
Created April 23, 2017 01:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ryanpeach/7fc3e82acf33894815ebd0b645fa8e63 to your computer and use it in GitHub Desktop.
Save ryanpeach/7fc3e82acf33894815ebd0b645fa8e63 to your computer and use it in GitHub Desktop.
Useful when a DataFrame has a limited number of unique categories or values in its x and y columns and you want to compare them visually against a third, continuous value.
def grid_plot(x_label, y_label, z_label, data, ax=None):
    """Draw a heat-map of a continuous value over a grid of two discrete keys.

    Useful when a DataFrame has a limited number of unique categories or
    values in the ``x_label`` and ``y_label`` columns and you want to compare
    them visually against a third, continuous value in ``z_label``.

    Each unique x value becomes an image column and each unique y value an
    image row (both sorted ascending; ``imshow`` draws row 0 at the top).
    When several data rows share the same (x, y) pair, the first one in
    ``data`` order supplies the cell's value. Pairs that never occur are left
    as NaN, so they render with the colormap's "bad" colour instead of
    borrowing an unrelated row's value.

    Parameters
    ----------
    x_label, y_label : hashable
        Keys of ``data`` holding the discrete horizontal / vertical
        coordinates.
    z_label : hashable
        Key of ``data`` holding the value each cell is coloured by; values
        must be assignable into a float array.
    data : mapping of key -> 1-D array-like (e.g. ``pandas.DataFrame``)
        Source data. Columns are read positionally, so a DataFrame does not
        need a default ``RangeIndex``.
    ax : matplotlib Axes, optional
        Axes to draw into; defaults to ``plt.gca()``.

    Returns
    -------
    (ax, cbar)
        The Axes drawn on and the colorbar created for the image.
    """
    if ax is None:
        ax = plt.gca()

    # Convert to plain arrays so every lookup below is positional, even when
    # a DataFrame carries a non-default (filtered / reordered) index.
    x = np.asarray(data[x_label]).ravel()
    y = np.asarray(data[y_label]).ravel()
    z = np.asarray(data[z_label]).ravel()

    # np.unique already returns sorted values; return_inverse gives, for each
    # data row, the grid column / row its value maps to.
    x_val, x_pos = np.unique(x, return_inverse=True)
    y_val, y_pos = np.unique(y, return_inverse=True)
    n_x, n_y = len(x_val), len(y_val)

    # One flat cell id per data row. return_index yields the FIRST data row
    # for each occupied cell, preserving "first matching row wins".
    cell = y_pos * n_x + x_pos
    occupied, first_row = np.unique(cell, return_index=True)

    # Unoccupied cells stay NaN instead of silently taking row 0's value.
    grid = np.full((n_y, n_x), np.nan)
    grid.flat[occupied] = z[first_row]

    img = ax.imshow(grid)
    ax.set_xticks(np.arange(n_x))
    ax.set_yticks(np.arange(n_y))
    ax.set_xticklabels(x_val)
    ax.set_yticklabels(y_val)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

    # Attach the colorbar to ``ax``'s own figure and take space from ``ax``;
    # plt.colorbar would use the *current* figure, which need not own ``ax``.
    cbar = ax.figure.colorbar(img, ax=ax)
    cbar.ax.set_ylabel(z_label)
    return ax, cbar
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment