Skip to content

Instantly share code, notes, and snippets.

@gokceneraslan
Last active September 20, 2023 20:54
Show Gist options
  • Save gokceneraslan/f953b697ea09b9a8ff890a8707775498 to your computer and use it in GitHub Desktop.
Save gokceneraslan/f953b697ea09b9a8ff890a8707775498 to your computer and use it in GitHub Desktop.
Running SingleR from Python
import scanpy as sc
# numpy et al.
import numpy as np
import scipy.sparse as sp
import pandas as pd
import gc
# R integration
from rpy2.robjects.packages import importr
from rpy2.robjects import pandas2ri, numpy2ri, r
from rpy2.robjects.vectors import StrVector, FloatVector, ListVector
import rpy2.robjects as ro
import anndata2ri
def predict_cell_types(adata,
                       use_raw=True,
                       species='mouse',
                       ref_name=None,
                       ref_label_column='label.fine',
                       cluster_key='leiden',
                       obs_out_key='predicted_cell_types',
                       uns_out_key='cell_type_prediction',
                       **kwds):
    """Annotate the clusters of ``adata`` with cell types via R's SingleR.

    The expression matrix is restricted to the genes shared with the
    reference, handed to ``SingleR`` in cluster mode, and the resulting
    pruned label of each cluster is written back onto its cells.

    Parameters
    ----------
    adata
        Annotated data object. ``adata.obs`` and ``adata.uns`` are modified
        in place.
    use_raw
        If true, expression values and gene names are taken from
        ``adata.raw``; otherwise from ``adata`` itself.
    species
        Only used when ``ref_name`` is None. ``'mouse'`` selects
        ``MouseRNAseqData`` + ``ImmGenData``; any other value selects
        ``HumanPrimaryCellAtlasData`` + ``MonacoImmuneData``.
    ref_name
        ``None`` to use the merged built-in references described above.
        A string is called as ``scRNAseq::<ref_name>()`` and passed through
        ``scater::logNormCounts``. Any other object is used as the R
        reference as-is.
    ref_label_column
        Column of the reference (accessed with R's ``$``) holding the
        labels to transfer.
    cluster_key
        Column of ``adata.obs`` holding the cluster assignment of each cell.
    obs_out_key
        Column of ``adata.obs`` that receives the per-cell predicted label.
    uns_out_key
        Key of ``adata.uns`` that receives the cluster -> label mapping.
    **kwds
        Forwarded unchanged to ``SingleR``.

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        If 1000 or fewer genes are shared between the data and the
        reference.
    """
    # Register a serial BiocParallel backend on the R side before calling
    # into SingleR.
    r('BiocParallel::register(BiocParallel::SerialParam())')
    singler = importr('SingleR')

    def _load_builtin_reference(name):
        # Newer SingleR releases moved the dataset getters to the celldex
        # package; fall back to it when SingleR no longer exports `name`.
        try:
            getter = singler.__dict__[name]
        except KeyError:
            getter = importr('celldex').__dict__[name]
        return getter()

    if ref_name is None:
        if species == 'mouse':
            builtin_names = ['MouseRNAseqData', 'ImmGenData']
        else:
            builtin_names = ['HumanPrimaryCellAtlasData', 'MonacoImmuneData']

        refs = [_load_builtin_reference(name) for name in builtin_names]
        # Keep only genes present in every reference so they can be merged.
        # Sorted so that the row order of the merged reference is
        # deterministic (a bare set has arbitrary order).
        ref_genes = StrVector(sorted(
            set.intersection(*[set(r['rownames'](d)) for d in refs])))
        ref = r('cbind')(*[r('`[`')(d, ref_genes) for d in refs])  # merge references
    else:
        if isinstance(ref_name, str):
            ref = r(f'scRNAseq::{ref_name}')()
            ref = r('scater::logNormCounts')(ref)
        else:
            ref = ref_name
        ref_genes = r['rownames'](ref)

    ad = adata.raw if use_raw else adata
    obs = adata.obs

    common_genes = sorted(set(ad.var_names) & set(ref_genes))
    # Raised explicitly instead of via `assert` so the check is not stripped
    # under `python -O`; the exception type and message are unchanged.
    if len(common_genes) <= 1000:
        raise AssertionError(
            'Not enough genes overlapping with ref SingleR datasets...')

    ref = r('`[`')(ref, StrVector(common_genes))

    # Build the subset view once; R expects genes in rows, cells in columns.
    subset = ad[:, common_genes]
    mat = subset.X.T.copy()
    if sp.issparse(mat):
        mat = anndata2ri.scipy2ri.py2rpy(sp.csr_matrix(mat))
    else:
        mat = numpy2ri.py2rpy(mat)

    mat = r("`rownames<-`")(mat, StrVector(subset.var_names))
    mat = r("`colnames<-`")(mat, StrVector(obs.index))  # TODO: really needed?

    # Cluster ids are passed to R as strings and come back as the row names
    # of the result, so compare against the string form on the way back too.
    cluster_ids = obs[cluster_key].astype(str)
    clusters = StrVector(cluster_ids.tolist())

    par = r('BiocParallel::SerialParam')()
    labels = singler.SingleR(test=mat,
                             ref=ref,
                             labels=r('`$`')(ref, ref_label_column),  # use label.main too
                             method='cluster',
                             clusters=clusters, BPPARAM=par, **kwds)
    labels = pandas2ri.rpy2py(r('as.data.frame')(labels))
    pruned = labels['pruned.labels']

    adata.obs[obs_out_key] = ''
    # Iterate over the clusters SingleR actually returned rather than over
    # the categorical's categories: unused categories have no row in
    # `labels`, and the cluster column need not be categorical at all.
    for cluster, label in pruned.items():
        adata.obs.loc[cluster_ids == str(cluster), obs_out_key] = label

    adata.uns[uns_out_key] = pruned.to_dict()
@xiachenrui
Copy link

I think you need to install the R package 'BiocParallel' first.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment