Created
November 14, 2018 04:00
-
-
Save Xparx/33026da63dabb1c200b2602bbae0b95c to your computer and use it in GitHub Desktop.
Minor hack to get a clustered, grouped heatmap
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scanpy.api as _sc | |
import scanpy.plotting as _scp | |
import seaborn as _sns | |
import numpy as _np | |
import matplotlib.pyplot as plt | |
def _dendrogram_leaf_order(matrix):
    """Return the row order of *matrix* produced by hierarchical clustering.

    The rows are clustered with ``scipy.cluster.hierarchy.linkage`` using its
    defaults (single linkage, Euclidean distance), and the left-to-right leaf
    order of the resulting dendrogram is returned as a list of row positions.
    ``linkage`` cannot cluster fewer than two rows, so in that case the
    identity order is returned instead of raising.
    """
    from scipy.cluster.hierarchy import dendrogram, linkage

    n_rows = len(matrix)
    if n_rows < 2:
        return list(range(n_rows))
    return dendrogram(linkage(matrix), no_plot=True)['leaves']


def clustered_heatmap(data, var_names=None, groupby=None, figsize=(12, 6), cluster=True,
                      use_raw=None, log=False, num_categories=7, cbar=False,
                      vmin=None, vmax=None, **kwds):
    """Draw one heatmap panel per ``groupby`` category, side by side, with clustered ordering.

    Variables are the heatmap rows and observations are the columns. The panels
    share the variable order. The left-most panel shows the variable names, and
    each panel is labelled on its x-axis with its category.

    Parameters
    ----------
    data : AnnData
        Annotated data matrix to plot.
    var_names
        Variables to plot. Forwarded unchanged to scanpy's
        ``plotting.anndata._prepare_dataframe``.
    groupby : str or None
        Key of ``data.obs`` used to split observations into panels. If None,
        all observations are drawn in a single panel.
    figsize : tuple
        Size of the new figure, which is created with ``dpi=75`` and a white
        background.
    cluster : bool
        If True (default), order the variables and, inside each panel, the
        observations by hierarchical clustering. If False, keep the order
        returned by ``_prepare_dataframe``.
    use_raw : bool or None
        Read values from ``data.raw``. If None, it defaults to True whenever
        ``data.raw`` is set.
    log, num_categories
        Forwarded unchanged to ``_prepare_dataframe``. NOTE(review): scanpy
        presumably uses them for a log transform and for binning a
        non-categorical ``groupby`` column; confirm for your scanpy version.
    cbar : bool
        If true, draw a colour bar next to the right-most panel.
    vmin, vmax : float, None or False
        Colour-scale limits passed to every panel. None (default) uses the
        global minimum/maximum of the plotted values, so all panels share one
        scale. False passes None to seaborn, which then picks the limit from
        each panel's own data.
    **kwds
        Extra keyword arguments for ``seaborn.heatmap``.

    Returns
    -------
    list of matplotlib Axes
        One axes per non-empty group, left to right.

    Raises
    ------
    TypeError
        If ``data`` is not an AnnData object.
    ValueError
        If ``groupby`` is given but is not a key of ``data.obs``.
    """
    from anndata import AnnData

    if not isinstance(data, AnnData):
        raise TypeError("input arguments must be of AnnData type")
    adata = data
    if groupby is not None and groupby not in adata.obs_keys():
        raise ValueError('groupby has to be a valid observation. Given value: {}, '
                         'valid observations: {}'.format(groupby, adata.obs_keys()))
    if use_raw is None and adata.raw is not None:
        use_raw = True

    _, obs_tidy = _scp.anndata._prepare_dataframe(
        adata, var_names=var_names, groupby=groupby, use_raw=use_raw,
        log=log, num_categories=num_categories)

    if cluster:
        # Every column of obs_tidy is a plotted variable; cluster them (as rows
        # of the transpose) so that similar variables end up on adjacent rows.
        obs_tidy = obs_tidy.iloc[:, _dendrogram_leaf_order(obs_tidy.T)]

    # New figure; the panels below are placed on it via plt.subplot2grid.
    plt.figure(figsize=figsize, dpi=75, facecolor='w', edgecolor='k')

    if vmin is None:
        vmin = obs_tidy.min().min()
    elif vmin is False:
        vmin = None
    if vmax is None:
        vmax = obs_tidy.max().max()
    elif vmax is False:
        vmax = None

    # The group labels are the (single-level) row index of obs_tidy: all of its
    # columns are numeric variables. Grouping by level 0 therefore also works
    # when groupby is None, where grouping by name would fail. Empty
    # (unobserved) categories are dropped because they cannot be clustered or drawn.
    groups = [(label, frame) for label, frame in obs_tidy.groupby(level=0) if len(frame) > 0]

    # Layout: one grid column per observation, plus `gap` blank grid columns
    # after each panel, plus `gap` more when a colour bar is requested.
    gap = 10
    n_columns = sum(len(frame) for _, frame in groups) + gap * len(groups)
    if cbar:
        n_columns += gap

    axes = []
    start = 0
    last = len(groups) - 1
    for i, (label, frame) in enumerate(groups):
        if cluster:
            # Cluster this group's observations; the leaf order is reversed.
            frame = frame.iloc[_dendrogram_leaf_order(frame)[::-1]]
        # Transpose: variables become heatmap rows, observations become columns.
        panel = frame.T

        width = panel.shape[1]
        draw_cbar = bool(cbar) and i == last
        if draw_cbar:
            # Without an explicit cbar_ax, the colour bar takes its space out of
            # this axes, so widen the panel by one gap to keep it proportional.
            width += gap

        ax = plt.subplot2grid((1, n_columns), (0, start), colspan=width, rowspan=1)
        _sns.heatmap(panel, xticklabels=False, ax=ax, cbar=draw_cbar,
                     vmin=vmin, vmax=vmax, **kwds)
        start += width + gap

        ax.set_xlabel(label)
        ax.set_ylabel('')
        if i > 0:
            # Only the left-most panel keeps the variable names on its y-axis.
            ax.set_yticks([])
        axes.append(ax)
    return axes
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment