Skip to content

Instantly share code, notes, and snippets.

@ruxi
Last active March 13, 2021 23:39
Show Gist options
  • Save ruxi/ff0e9255d74a3c187667627214e1f5fa to your computer and use it in GitHub Desktop.
Save ruxi/ff0e9255d74a3c187667627214e1f5fa to your computer and use it in GitHub Desktop.
jointplot_w_hue
# --- package metadata -------------------------------------------------------
__version__ = "0.0.2"
__license__ = "MIT"
__author__ = "lewis.r.liu@gmail.com"
__copyright__ = "Copyright 2020, 2018, https://gist.github.com/ruxi/ff0e9255d74a3c187667627214e1f5fa"

# What this is: a seaborn jointplot that supports a 'hue' grouping column.
# History: first written Feb 19, 2018; last revised June 13, 2020.
# Written for seaborn issue https://github.com/mwaskom/seaborn/issues/365;
# upstream addressed it on 22 Aug 2020 in https://github.com/mwaskom/seaborn/pull/2210.
import seaborn as sns
import matplotlib.pyplot as plt
# import matplotlib.lines as mlines
# import matplotlib.patches as mpatches
def _cycle_pick(options, idx, default):
    """Return options[idx], wrapping around, or default if options is not a non-empty list/tuple."""
    if isinstance(options, (list, tuple)) and options:
        return options[idx % len(options)]
    return default


def plot_jointgrid_hue(data, x, y, hue,
                       cmap=None,  # e.g. ['green', 'orange'] or a seaborn palette name
                       alphas: list = None,  # e.g. [0.2, 0.5]
                       alpha=None,
                       marker_map: list = None,  # e.g. ['x', '+']
                       marker=None,
                       map_plot_margin_x=sns.distplot,
                       map_plot_margin_y=sns.distplot,
                       map_plot_joint=sns.scatterplot,
                       kw_jointgrid=None,
                       kw_margins=None,
                       kw_scatter=None,
                       ):
    """
    seaborn jointgrid with hue

    Each group produced by ``data.groupby(hue)`` is drawn separately on a
    shared ``sns.JointGrid``: once on each marginal axis and once on the
    joint axis, with its own colour, marker and alpha. A figure-level legend
    with one entry per group is added.

    Parameters
    ----------
    data : pandas.DataFrame
        Source table; must contain the ``x``, ``y`` and ``hue`` columns.
    x, y : str
        Column names plotted on the horizontal / vertical axis.
    hue : str
        Column to group by; also used as the legend title.
    cmap : None, str, or sequence of colours
        None -> ``sns.color_palette()``. A string is resolved with
        ``sns.color_palette(cmap, n_colors=<number of groups>)``. Any other
        sequence is used as given. Colours are cycled when there are more
        groups than colours.
    alphas : list or tuple, optional
        Per-group alpha values (cycled). When omitted, ``alpha`` is used.
    alpha : float, optional
        Alpha for every group when ``alphas`` is not given (default 0.5).
    marker_map : list or tuple, optional
        Per-group markers (cycled). When omitted, ``marker`` is used.
    marker : str, optional
        Marker for every group when ``marker_map`` is not given (default "o").
    map_plot_margin_x, map_plot_margin_y : callable
        Called as ``f(a=<Series>, ax=<axes>, color=<colour>, **kw_margins)``.
        The y-margin call additionally receives ``vertical=True``.
        NOTE(review): ``sns.distplot`` is deprecated in newer seaborn
        releases; confirm against the installed version.
    map_plot_joint : callable
        Called as ``f(data=subset, x=<Series>, y=<Series>, ax=grid.ax_joint,
        marker=..., alpha=..., color=..., **kw_scatter)``.
    kw_jointgrid : dict, optional
        Extra keyword arguments for ``sns.JointGrid``.
    kw_margins : dict, optional
        Extra keyword arguments for both margin plotters
        (default ``{"kde": True}``).
    kw_scatter : dict, optional
        Extra keyword arguments for ``map_plot_joint``.

    returns
    -------
    seaborn.axisgrid.JointGrid
        The populated grid. Its figure is closed with ``plt.close`` before
        returning, so access it through ``grid.fig``.

    raises
    ------
    ValueError
        If ``cmap`` is an empty sequence.

    minimum working example
    -----------------------
    iris = sns.load_dataset("iris")
    g = plot_jointgrid_hue(data=iris, x = 'sepal_length', y = 'sepal_width', hue = 'species')
    g.fig

    changelog
    ---------
    2020 June 13: Returns JointGrid as a base instead of GridSpec.
                  Include the option to use different alphas and markers
                  for each hue group.
    2018 Mar 5: added legends and colormap
    2018 Feb 19: gist made
    """
    # Proxy artists for the legend; they are never added to any axes.
    from matplotlib.lines import Line2D

    #+------------------+
    #| default mappings |
    #+------------------+
    if marker is None:
        marker = "o"
    if alpha is None:
        alpha = 0.5
    # None sentinels instead of shared mutable dict defaults.
    if kw_jointgrid is None:
        kw_jointgrid = {}
    if kw_margins is None:
        kw_margins = {"kde": True}
    if kw_scatter is None:
        kw_scatter = {}

    groups = list(data.groupby(hue))

    if cmap is None:
        palette = sns.color_palette()
    elif isinstance(cmap, str):
        # Palette names (e.g. "viridis") are resolved to one colour per group.
        palette = sns.color_palette(cmap, n_colors=len(groups))
    else:
        palette = list(cmap)
    if not palette:
        raise ValueError("cmap must contain at least one colour")

    #+------------------+
    #| initialize grid  |
    #+------------------+
    grid = sns.JointGrid(data=data, x=x, y=y, **kw_jointgrid)

    legend_handles = []
    for i, (key, subset) in enumerate(groups):
        if subset.empty:
            # e.g. unused categories of a categorical hue column: nothing to
            # draw. The index i still advances, so colours stay aligned with
            # the group order.
            continue
        mapped_params = dict(
            marker=_cycle_pick(marker_map, i, marker),
            alpha=_cycle_pick(alphas, i, alpha),
            color=palette[i % len(palette)],
        )
        map_plot_margin_x(a=subset[x],
                          ax=grid.ax_marg_x,
                          color=mapped_params["color"],
                          **kw_margins)
        map_plot_margin_y(a=subset[y],
                          ax=grid.ax_marg_y,
                          vertical=True,
                          color=mapped_params["color"],
                          **kw_margins)
        map_plot_joint(data=subset,
                       x=subset[x],
                       y=subset[y],
                       ax=grid.ax_joint,
                       **mapped_params,
                       **kw_scatter)
        #+----------------+
        #| legend handles |
        #+----------------+
        # https://matplotlib.org/tutorials/intermediate/legend_guide.html
        # A detached Line2D avoids drawing a stray point onto the current axes
        # and accepts any marker, not just format-string markers.
        legend_handles.append(Line2D([], [],
                                     linestyle="None",
                                     marker=mapped_params["marker"],
                                     color=mapped_params["color"],
                                     label=str(key)))

    #+-----------------+
    #| populate legend |
    #+-----------------+
    grid.fig.legend(title=hue, handles=legend_handles)
    # Close this grid's figure explicitly rather than "whatever is current".
    plt.close(grid.fig)
    return grid
@xigrug
Copy link

xigrug commented Apr 14, 2018

Nice work! Can a trend line be added to the figure? Or could each group use a different marker?

@prubbens
Copy link

prubbens commented Oct 22, 2018

I get the following error when choosing a colormap myself:
File "plot_tsne_growthphase_kde.py", line 37, in jointplot_w_hue active_colormap = colormap[0:len(hue_groups)] TypeError: unhashable type: 'slice'

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment