Skip to content

Instantly share code, notes, and snippets.

@grst
Created September 6, 2023 14:37
Show Gist options
  • Save grst/424e3e24bf244820000c33a823a47ec1 to your computer and use it in GitHub Desktop.
Save grst/424e3e24bf244820000c33a823a47ec1 to your computer and use it in GitHub Desktop.
from functools import partial
import math
from typing import Literal
from datashader.mpl_ext import dsshow, alpha_colormap
import datashader as ds
from joblib import parallel_backend
from mpl_toolkits.axes_grid1 import ImageGrid
import matplotlib.pyplot as plt
import pandas as pd
import scanpy as sc
from anndata import AnnData
import numpy as np
from matplotlib import patheffects
import datashader.transfer_functions as tf
def embedding(
    adata: AnnData,
    key: str = "X_umap",
    *,
    x="x",
    y="y",
    color,
    ncol=3,
    panel_size=4,
    cmap="viridis",
    show=True,
    wspace=1,
    hspace=0.5,
    legend_loc: Literal["on data", "right margin"] = "right margin",
    add_outline=True,
    spread=0,
    groups=None,
):
    """Plot an embedding as a grid of datashader-rendered panels, one per key in *color*.

    Parameters
    ----------
    adata
        Annotated data matrix. ``adata.strings_to_categoricals()`` is called on
        it, so string columns of ``adata.obs`` are converted in place.
    key
        Key in ``adata.obsm`` holding the coordinates. For a plain array the
        first two columns are used; a DataFrame must contain columns named
        *x* and *y*.
    x, y
        Column names used for the two coordinates.
    color
        One key, or a list of keys, passed to ``sc.get.obs_df``. Categorical
        keys are aggregated with ``ds.count_cat`` and get a legend; all other
        keys are aggregated with ``ds.mean`` and get a colorbar.
    ncol
        Maximum number of panels per row (capped at the number of panels).
    panel_size
        Width/height of one panel in inches.
    cmap
        Colormap for non-categorical keys.
    show
        Whether to call ``plt.show()`` at the end.
    wspace, hspace
        Horizontal/vertical padding between panels, passed as ``axes_pad``
        to ``ImageGrid``.
    legend_loc
        ``"right margin"`` or ``"on data"`` for categorical keys; any other
        value draws no legend.
    add_outline
        Draw a grey layer, spread by one extra pixel, beneath the points.
    spread
        ``px`` passed to ``tf.spread`` for the colored layer.
    groups
        Categories to keep colored; all other categories are drawn grey.

    Raises
    ------
    ValueError
        If *color* is empty.
    """
    adata.strings_to_categoricals()
    if isinstance(color, str):
        color = [color]
    color = list(color)
    if not color:
        raise ValueError("`color` must name at least one key to plot.")
    # Never lay out more columns than there are panels; otherwise the figure
    # is widened by empty, hidden axes.
    ncol = min(ncol, len(color))
    nrow = math.ceil(len(color) / ncol)
    fig = plt.figure(figsize=(ncol * panel_size, nrow * panel_size))
    grid = ImageGrid(
        fig,
        111,
        nrows_ncols=(nrow, ncol),
        axes_pad=(wspace, hspace),
        share_all=True,
        cbar_location="right",
        cbar_mode="each",
        cbar_size="5%",
        cbar_pad="2%",
    )

    obsm_df = adata.obsm[key]
    if not isinstance(obsm_df, pd.DataFrame):
        # Embeddings such as ``X_pca`` have more than two columns; plot the
        # first two instead of failing on a column-count mismatch.
        obsm_df = pd.DataFrame(
            np.asarray(obsm_df)[:, :2], columns=[x, y], index=adata.obs_names
        )
    obsm_df = sc.get.obs_df(adata, color).join(obsm_df)
    # Symmetric limits around the origin, applied identically to every panel.
    ax_limit = np.max(np.abs(obsm_df.loc[:, [x, y]].to_numpy()))

    for i, c in enumerate(color):
        ax = grid[i]
        if add_outline:
            # Grey layer spread one pixel further than the colored layer drawn
            # on top of it, producing an outline around the point cloud.
            dsshow(
                obsm_df,
                ds.Point(x, y),
                ds.count(),
                shade_hook=partial(tf.spread, px=spread + 1),
                ax=ax,
                cmap=["#AAAAAA", "#AAAAAA"],
            )
        if obsm_df[c].dtype.name == "category":
            artist = dsshow(
                obsm_df,
                ds.Point(x, y),
                ds.count_cat(c),
                ax=ax,
                shade_hook=partial(tf.spread, px=spread),
                color_key=_get_palette_from_anndata(adata, c, groups),
            )
            if legend_loc == "right margin":
                _add_legend_right_margin(ax, artist)
            elif legend_loc == "on data":
                _add_legend_on_data(ax, obsm_df, x, y, c)
            # Categorical panels use a legend instead of a colorbar.
            grid.cbar_axes[i].set_visible(False)
        else:
            artist = dsshow(
                obsm_df,
                ds.Point(x, y),
                ds.mean(c),
                cmap=cmap,
                aspect="equal",
                shade_hook=partial(tf.spread, px=spread),
                ax=ax,
            )
            plt.colorbar(artist, cax=grid.cbar_axes[i])
        ax.set_title(c)
        ax.set_xlim(-ax_limit, ax_limit)
        ax.set_ylim(-ax_limit, ax_limit)

    # Hide the unused cells (and their colorbar axes) in the last row.
    for i in range(len(color), nrow * ncol):
        grid[i].set_visible(False)
        grid.cbar_axes[i].set_visible(False)
    if show:
        plt.show()
def _add_legend_on_data(ax, obsm_df, x, y, c):
    """Label each category of column *c* at the median coordinates of its points.

    Parameters
    ----------
    ax
        Matplotlib axes to write the labels on.
    obsm_df
        DataFrame holding the coordinate columns *x* and *y* and the
        categorical column *c*. Any further columns are ignored.
    x, y
        Names of the coordinate columns.
    c
        Name of the categorical column whose categories are labelled.
    """
    all_pos = (
        # Aggregate only the coordinates: ``obsm_df`` may carry further color
        # columns (e.g. other categoricals), and ``median`` over a categorical
        # column raises a TypeError in pandas >= 2.0.
        obsm_df.groupby(c, observed=True)[[x, y]]
        .median()
        # Have to sort_index since if observed=True and categorical is unordered
        # the order of values in .index is undefined. Related issue:
        # https://github.com/pandas-dev/pandas/issues/25167
        .sort_index()
    )
    for label, row in all_pos.iterrows():
        ax.text(
            row[x],
            row[y],
            label,
            weight="bold",
            verticalalignment="center",
            horizontalalignment="center",
            fontsize=None,
            # White stroke around the glyphs keeps labels legible over points.
            path_effects=[patheffects.withStroke(linewidth=1, foreground="w")],
        )
def _add_legend_right_margin(ax, artist):
    """Attach a legend for *artist* just outside the right edge of *ax*.

    The handles come from ``artist.get_legend_elements()``; the legend's
    upper-left corner is anchored slightly right of the axes' top-right corner.
    """
    legend_handles = artist.get_legend_elements()
    # x = 1.04 (axes fraction) puts the legend in the margin, clear of the data.
    anchor = (1.04, 1)
    ax.legend(handles=legend_handles, bbox_to_anchor=anchor, loc="upper left")
def _get_palette_from_anndata(adata, column, groups):
    """Return a ``{category: color}`` mapping for ``adata.obs[column]``.

    Colors are read from ``adata.uns[f"{column}_colors"]``. If that entry is
    missing, or holds fewer colors than there are categories, scanpy code is
    used first to assign a default palette (written to ``adata.uns``).

    Parameters
    ----------
    adata
        AnnData object whose ``obs[column]`` is categorical.
    column
        Name of the categorical column in ``adata.obs``.
    groups
        A single category name or a collection of category names to keep
        colored; every other category is mapped to ``"#999999"``. ``None``
        keeps all colors.
    """
    color_key = f"{column}_colors"
    categories = adata.obs[column].cat.categories
    # A stale palette shorter than the category list would be silently
    # truncated by zip(), leaving categories without a color.
    if color_key not in adata.uns or len(adata.uns[color_key]) < len(categories):
        sc.pl._utils._set_default_colors_for_categorical_obs(adata, column)
    color_dict = dict(zip(categories, adata.uns[color_key]))
    # set groups that are not selected to grey
    if groups is not None:
        # A bare string would be tested with substring semantics
        # ("1" in "10" is True), so treat it as a single group name.
        if isinstance(groups, str):
            groups = [groups]
        selected = set(groups)
        color_dict = {
            k: (col if k in selected else "#999999") for k, col in color_dict.items()
        }
    return color_dict
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment