Skip to content

Instantly share code, notes, and snippets.

@Xparx
Created November 14, 2018 04:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Xparx/33026da63dabb1c200b2602bbae0b95c to your computer and use it in GitHub Desktop.
Save Xparx/33026da63dabb1c200b2602bbae0b95c to your computer and use it in GitHub Desktop.
Minor hack to get a clustered, grouped heatmap.
import scanpy.api as _sc
import scanpy.plotting as _scp
import seaborn as _sns
import numpy as _np
import matplotlib.pyplot as plt
def _dendrogram_leaf_order(frame):
    """Return a positional row order for *frame* taken from a dendrogram.

    The rows are clustered with ``scipy.cluster.hierarchy.linkage`` (scipy's
    default method/metric), and the leaf order of the resulting dendrogram is
    returned. ``linkage`` cannot cluster fewer than two observations, so in
    that case the current order ``[0, ..., n-1]`` is returned unchanged.
    """
    from scipy.cluster.hierarchy import dendrogram, linkage

    n_rows = frame.shape[0]
    if n_rows < 2:
        return list(range(n_rows))
    return dendrogram(linkage(frame), no_plot=True)['leaves']


def clustered_heatmap(data, var_names=None, groupby=None, figsize=(12, 6), cluster=True,
                      use_raw=None, log=False, num_categories=7, cbar=False,
                      vmin=None, vmax=None, **kwds):
    """Draw a heatmap split into side-by-side panels, one per observation group.

    Variables (rows of every panel) are ordered once by hierarchical clustering
    across all cells. Within each group, cells (columns of the panel) are
    ordered by clustering that group's cells on their own.

    Parameters
    ----------
    data : AnnData
        Annotated data matrix; anything else raises ``TypeError``.
    var_names, use_raw, log, num_categories
        Forwarded to ``scanpy.plotting.anndata._prepare_dataframe``. If
        ``use_raw`` is None and ``data.raw`` exists, ``use_raw`` becomes True.
    groupby : str or None
        Key in ``data.obs`` used to split cells into panels; an unknown key
        raises ``ValueError``.
    figsize : tuple
        Passed to ``plt.figure``.
    cluster : bool
        When False, skip both clustering steps and keep the input order.
    cbar : bool
        Draw a colorbar next to the last panel.
    vmin, vmax : number, None or False
        Colour limits shared by all panels. None means use the min/max of the
        plotted data. False means pass None to seaborn, which then picks the
        limits per panel.
    **kwds
        Extra keyword arguments forwarded to ``seaborn.heatmap``.

    Returns
    -------
    list
        One matplotlib Axes per plotted group, left to right.
    """
    from anndata import AnnData

    if not isinstance(data, AnnData):
        raise TypeError("input arguments must be of AnnData type")
    adata = data
    if groupby is not None and groupby not in adata.obs_keys():
        raise ValueError('groupby has to be a valid observation. Given value: {}, '
                         'valid observations: {}'.format(groupby, adata.obs_keys()))
    if use_raw is None and adata.raw is not None:
        use_raw = True

    _, obs_tidy = _scp.anndata._prepare_dataframe(
        adata, var_names=var_names, groupby=groupby, use_raw=use_raw,
        log=log, num_categories=num_categories)

    if cluster:
        # Order the variables (columns of obs_tidy) by clustering them over all cells.
        obs_tidy = obs_tidy.iloc[:, _dendrogram_leaf_order(obs_tidy.T)]

    # Group on the row index rather than by name. Presumably _prepare_dataframe
    # indexes rows by group label (the original grouped by the `groupby` key,
    # which is not a column). level=0 also works when groupby is None.
    # Empty groups are skipped because there is nothing to cluster or draw.
    groups = [(label, frame) for label, frame in obs_tidy.groupby(level=0)
              if not frame.empty]

    # Shared colour limits so panels are comparable; False defers to seaborn.
    if vmin is None:
        vmin = obs_tidy.min().min()
    elif vmin is False:
        vmin = None
    if vmax is None:
        vmax = obs_tidy.max().max()
    elif vmax is False:
        vmax = None

    # One grid column per plotted cell, a `gap` of empty columns after every
    # panel, and `gap` extra columns for the colorbar if requested.
    gap = 10
    n_groups = len(groups)
    n_cells = sum(frame.shape[0] for _, frame in groups)
    columns = n_cells + gap * n_groups + (gap if cbar else 0)

    plt.figure(figsize=figsize, dpi=75, facecolor='w', edgecolor='k')
    axes = []
    col = 0
    for i, (label, frame) in enumerate(groups):
        if cluster:
            frame = frame.iloc[_dendrogram_leaf_order(frame)[::-1]]
        frame = frame.T  # variables as rows, cells as columns

        # Only the last panel carries the colorbar. Its span is widened so the
        # heatmap keeps one grid column per cell instead of being squeezed.
        show_cbar = cbar and i == n_groups - 1
        span = frame.shape[1] + (gap if show_cbar else 0)

        ax = plt.subplot2grid((1, columns), (0, col), colspan=span, rowspan=1)
        _sns.heatmap(frame, xticklabels=False, ax=ax, cbar=show_cbar,
                     vmin=vmin, vmax=vmax, **kwds)
        ax.set_xlabel(label)
        ax.set_ylabel('')
        if i > 0:
            # Variable names are shown only on the leftmost panel.
            ax.set_yticks([])
        axes.append(ax)
        col += span + gap
    return axes
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment