# Gist page header (scraped from GitHub):
# @xigrug
# Forked from ruxi/jointplot_w_hue.py
# Created April 16, 2018 06:51
# Gist: xigrug/7c675cd3739122da2e91a16bd1f67ab4
# jointplot_w_hue
"""
Seaborn jointplots with hue groupings.

Minimum working example
-----------------------
iris = sns.load_dataset("iris")
jointplot_w_hue(data=iris, x='sepal_length', y='sepal_width', hue='species')['fig']

Changelog
---------
2018 Mar 5: added legends and colormap
2018 Feb 19: gist made
"""
__author__ = "lewis.r.liu@gmail.com"
__copyright__ = "Copyright 2018, github.com/ruxi"
__license__ = "MIT"
# Must be a string: the bare literal 0.0.1 is a SyntaxError.
__version__ = "0.0.1"
# updated: Mar 5, 2018
# created: Feb 19, 2018
# desc: seaborn jointplot with 'hue'
# prepared for issue: https://github.com/mwaskom/seaborn/issues/365
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
# Apply seaborn's "darkgrid" theme globally; this runs once, at import time.
sns.set_style(style="darkgrid")
def jointplot_w_hue(data, x, y, hue=None, colormap=None,
                    figsize=None, fig=None, scatter_kws=None):
    """Draw a jointplot whose scatter and marginal plots are split by ``hue``.

    The main panel is a scatter of ``y`` against ``x`` (``sns.regplot`` with
    ``fit_reg=False``). ``sns.distplot`` marginals of ``x`` and ``y`` sit above
    and to the right of it, one per hue group. The top-right cell holds a
    legend mapping colours to hue values.

    Parameters
    ----------
    data : pandas.DataFrame
        Table holding the ``x``, ``y`` and ``hue`` columns.
    x, y : str
        Column names for the horizontal and vertical axes.
    hue : str, optional
        Column whose distinct values define the colour groups, taken in order
        of first appearance. Rows with a missing hue value are skipped.
    colormap : optional
        Anything ``sns.color_palette`` accepts: a list of colours, a palette
        name, or None for the current colour cycle. The palette is cycled
        when there are more hue groups than colours.
    figsize : tuple, optional
        Size of the figure created when ``fig`` is None; default (5, 5).
    fig : matplotlib.figure.Figure, optional
        Figure to draw into. A new figure is created when omitted.
    scatter_kws : dict, optional
        Forwarded to ``sns.regplot``; default ``dict(alpha=0.4, lw=1)``.

    Returns
    -------
    dict or str
        ``{"fig": fig, "gridspec": grid}``. When ``hue`` is None, returns the
        string ``"use normal sns.jointplot"`` and creates no figure.
    """
    if hue is None:
        # Without grouping, plain sns.jointplot suffices. Return before any
        # figure is created so no empty figure is left open.
        return "use normal sns.jointplot"

    # defaults
    if figsize is None:
        figsize = (5, 5)
    if fig is None:
        fig = plt.figure(figsize=figsize)
    if scatter_kws is None:
        scatter_kws = dict(alpha=0.4, lw=1)

    # derived variables
    # dropna(): NaN never compares equal to itself, so a NaN group would
    # select an empty subset that distplot cannot draw.
    hue_groups = data[hue].dropna().unique()
    # color_palette cycles a short palette up to n_colors, so every group gets
    # a colour. Slicing the palette would silently drop the extra groups, and
    # slicing a palette-name string would yield single characters.
    palette = sns.color_palette(colormap, n_colors=len(hue_groups))
    colors = dict(zip(hue_groups, palette))
    subdata = {grp: data[data[hue] == grp] for grp in hue_groups}
    legend_handles = [mpatches.Patch(color=colors[grp], label=grp)
                      for grp in hue_groups]

    # canvas setup: big scatter bottom-left, marginals top and right,
    # legend in the top-right corner.
    grid = gridspec.GridSpec(2, 2,
                             width_ratios=[4, 1],
                             height_ratios=[1, 4],
                             hspace=0, wspace=0)
    # fig.add_subplot (not plt.subplot) so a caller-supplied figure is used
    # even when it is not pyplot's current figure.
    ax_main = fig.add_subplot(grid[1, 0])
    ax_xhist = fig.add_subplot(grid[0, 0], sharex=ax_main)
    # NOTE(review): not sharey'd with ax_main (the author left that disabled),
    # so the right marginal's y-limits are independent of the scatter's.
    ax_yhist = fig.add_subplot(grid[1, 1])

    ## plotting
    # histplot x-axis
    for grp in hue_groups:
        sns.distplot(subdata[grp][x], color=colors[grp], ax=ax_xhist)
    # histplot y-axis
    for grp in hue_groups:
        sns.distplot(subdata[grp][y], color=colors[grp], ax=ax_yhist,
                     vertical=True)
    # main scatterplot
    # note (original author): must come after the histplots, else ax_yhist
    # is drawn incorrectly.
    for grp in hue_groups:
        sns.regplot(data=subdata[grp], x=x, y=y, fit_reg=False, ax=ax_main,
                    color=colors[grp], scatter_kws=scatter_kws)

    # despine the marginals and hide their tick labels
    for marginal_ax in (ax_yhist, ax_xhist):
        sns.despine(ax=marginal_ax, bottom=False, top=True, left=False,
                    right=True, trim=False)
        plt.setp(marginal_ax.get_xticklabels(), visible=False)
        plt.setp(marginal_ax.get_yticklabels(), visible=False)

    # top-right: legend
    ax_legend = fig.add_subplot(grid[0, 1])
    plt.setp(ax_legend.get_xticklabels(), visible=False)
    plt.setp(ax_legend.get_yticklabels(), visible=False)
    ax_legend.legend(handles=legend_handles)

    # Close this specific figure. A bare plt.close() would close whatever
    # figure is current, which need not be `fig`. Presumably the intent is to
    # stop pyplot from auto-displaying it; the caller gets it via the result.
    plt.close(fig)
    return dict(fig=fig, gridspec=grid)
# (GitHub page footer) Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment