Created
September 6, 2023 14:37
-
-
Save grst/424e3e24bf244820000c33a823a47ec1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial | |
import math | |
from typing import Literal | |
from datashader.mpl_ext import dsshow, alpha_colormap | |
import datashader as ds | |
from joblib import parallel_backend | |
from mpl_toolkits.axes_grid1 import ImageGrid | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
import scanpy as sc | |
from anndata import AnnData | |
import numpy as np | |
from matplotlib import patheffects | |
import datashader.transfer_functions as tf | |
def embedding(
    adata: AnnData,
    key: str = "X_umap",
    *,
    x="x",
    y="y",
    color,
    ncol=3,
    panel_size=4,
    cmap="viridis",
    show=True,
    wspace=1,
    hspace=0.5,
    legend_loc: Literal["on data", "right margin"] = "right margin",
    add_outline=True,
    spread=0,
    groups=None,
):
    """Plot an embedding with datashader, one panel per entry of ``color``.

    Coordinates are taken from ``adata.obsm[key]`` and the values to colour by
    are fetched with ``sc.get.obs_df(adata, color)``. Columns whose dtype is
    categorical are aggregated with ``ds.count_cat`` and coloured with the
    palette from ``_get_palette_from_anndata``; every other column is
    aggregated with ``ds.mean`` and gets a colorbar.

    Parameters
    ----------
    adata
        Annotated data object. ``adata.strings_to_categoricals()`` is called
        on it before plotting.
    key
        Key in ``adata.obsm`` holding the coordinates.
    x, y
        Column names of the two coordinates. If the ``obsm`` entry is a
        DataFrame these columns are read from it; otherwise the first two
        columns of the array are given these names.
    color
        A single key or a list of keys to plot, one panel each.
    ncol
        Number of columns in the panel grid.
    panel_size
        Width and height reserved per panel in the figure size.
    cmap
        Colormap for non-categorical panels.
    show
        If true, call ``plt.show()`` at the end.
    wspace, hspace
        Horizontal and vertical padding between panels (``axes_pad`` of
        ``ImageGrid``).
    legend_loc
        Where to put the legend of categorical panels: next to the panel
        (``"right margin"``) or as labels on the data (``"on data"``).
    add_outline
        If true, first draw all points in grey, spread by one pixel more than
        the coloured layer, so the coloured layer appears outlined.
    spread
        Number of pixels passed as ``px`` to ``tf.spread`` for the coloured
        layer.
    groups
        Categories that keep their colour in categorical panels; all other
        categories are drawn grey. ``None`` keeps every colour.

    Returns
    -------
    ``None`` if ``show`` is true, otherwise the ``ImageGrid`` holding the
    panels (and their colorbar axes) so the caller can customise them.
    """
    adata.strings_to_categoricals()
    if isinstance(color, str):
        color = [color]

    nrow = math.ceil(len(color) / float(ncol))
    fig = plt.figure(figsize=(ncol * panel_size, nrow * panel_size))
    grid = ImageGrid(
        fig,
        111,
        nrows_ncols=(nrow, ncol),
        axes_pad=(wspace, hspace),
        share_all=True,
        cbar_location="right",
        cbar_mode="each",
        cbar_size="5%",
        cbar_pad="2%",
    )

    obsm_df = adata.obsm[key]
    if not isinstance(obsm_df, pd.DataFrame):
        # Keep only the first two components: representations with more
        # columns (e.g. a PCA) would otherwise not fit the two column names.
        obsm_df = pd.DataFrame(
            np.asarray(obsm_df)[:, :2], columns=[x, y], index=adata.obs_names
        )
    obsm_df = sc.get.obs_df(adata, color).join(obsm_df)

    # Symmetric limits around the origin, shared by both axes of every panel.
    ax_limit = np.max(np.abs(obsm_df.loc[:, [x, y]].values))

    for i, c in enumerate(color):
        if add_outline:
            # Grey underlay, one pixel wider than the coloured layer on top.
            dsshow(
                obsm_df,
                ds.Point(x, y),
                ds.count(),
                shade_hook=partial(tf.spread, px=spread + 1),
                ax=grid[i],
                cmap=["#AAAAAA", "#AAAAAA"],
            )

        if obsm_df[c].dtype.name == "category":
            artist = dsshow(
                obsm_df,
                ds.Point(x, y),
                ds.count_cat(c),
                ax=grid[i],
                shade_hook=partial(tf.spread, px=spread),
                color_key=_get_palette_from_anndata(adata, c, groups),
            )
            if legend_loc == "right margin":
                _add_legend_right_margin(grid[i], artist)
            elif legend_loc == "on data":
                _add_legend_on_data(grid[i], obsm_df, x, y, c)
            # Categorical panels use a legend, so their colorbar slot is hidden.
            grid.cbar_axes[i].set_visible(False)
        else:
            artist = dsshow(
                obsm_df,
                ds.Point(x, y),
                ds.mean(c),
                cmap=cmap,
                aspect="equal",
                shade_hook=partial(tf.spread, px=spread),
                ax=grid[i],
            )
            plt.colorbar(artist, cax=grid.cbar_axes[i])

        grid[i].set_title(c)
        grid[i].set_xlim(-ax_limit, ax_limit)
        grid[i].set_ylim(-ax_limit, ax_limit)

    # do not show remaining axes in grid
    for i in range(len(color), nrow * ncol):
        grid[i].set_visible(False)
        grid.cbar_axes[i].set_visible(False)

    if show:
        plt.show()
        return None
    # Without this the caller has no handle on the axes it asked not to show.
    return grid
def _add_legend_on_data(ax, obsm_df, x, y, c):
    """Write each category of column ``c`` onto ``ax`` at its median position.

    Parameters
    ----------
    ax
        Axes to draw the labels on (its ``text`` method is used).
    obsm_df
        DataFrame containing the coordinate columns ``x`` and ``y`` and the
        grouping column ``c``; it may contain further columns.
    x, y
        Names of the coordinate columns.
    c
        Name of the column whose groups are labelled.
    """
    all_pos = (
        # Take the median of the coordinate columns only: obsm_df also carries
        # every other plotted column, and a median over a non-numeric one
        # (e.g. a second categorical) raises a TypeError in pandas >= 2.0.
        obsm_df.groupby(c, observed=True)[[x, y]].median()
        # Have to sort_index since if observed=True and categorical is unordered
        # the order of values in .index is undefined. Related issue:
        # https://github.com/pandas-dev/pandas/issues/25167
        .sort_index()
    )
    for label, row in all_pos.iterrows():
        ax.text(
            row[x],
            row[y],
            label,
            weight="bold",
            verticalalignment="center",
            horizontalalignment="center",
            fontsize=None,
            # White stroke around the glyphs keeps labels readable on the points.
            path_effects=[patheffects.withStroke(linewidth=1, foreground="w")],
        )
def _add_legend_right_margin(ax, artist):
    """Attach a legend for ``artist`` to ``ax``, placed beside the panel.

    The handles come from ``artist.get_legend_elements()``; the legend's
    upper-left corner is anchored at ``(1.04, 1)``.
    """
    legend_handles = artist.get_legend_elements()
    placement = {"loc": "upper left", "bbox_to_anchor": (1.04, 1)}
    ax.legend(handles=legend_handles, **placement)
def _get_palette_from_anndata(adata, column, groups):
    """Get a dictionary with the color key from adata.uns.

    The colors stored in ``adata.uns[f"{column}_colors"]`` are paired, in
    order, with the categories of ``adata.obs[column]``. If that entry doesn't
    exist yet, scanpy code is used to assign a default palette first.

    Parameters
    ----------
    adata
        Object providing ``obs`` (with the categorical ``column``) and ``uns``.
    column
        Name of the categorical column in ``adata.obs``.
    groups
        Categories that keep their color; every other category is set to
        grey (``"#999999"``). May be a single category name or a collection
        of names. ``None`` leaves all colors untouched.

    Returns
    -------
    dict mapping each category to its color.
    """
    uns_key = f"{column}_colors"
    if uns_key not in adata.uns:
        # Private scanpy helper; relied upon to fill in adata.uns[uns_key].
        sc.pl._utils._set_default_colors_for_categorical_obs(adata, column)
    color_dict = dict(zip(adata.obs[column].cat.categories, adata.uns[uns_key]))
    # set groups that are not selected to grey
    if groups is not None:
        # A bare string would turn `in` into a substring test (category "B"
        # would count as selected for groups="B cell"), so wrap it in a list.
        selected = [groups] if isinstance(groups, str) else groups
        for k in color_dict:
            if k not in selected:
                color_dict[k] = "#999999"
    return color_dict
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment